#!/usr/bin/python3
# -*- coding: utf-8 -*-

"""
Special loader for veeamsnap precompiled kernel module
"""

__author__  = "Veeam Software Group GmbH"
__license__ = "GPL-2"

import os
import sys
import platform


class CKernelVersion:
    """Parsed kernel release string such as "3.10.0-862.6.3.el7.x86_64.debug".

    The string is split at the first dash into a revision part ("3.10.0")
    and a patch part ("862.6.3.el7.x86_64.debug"). The leading numeric
    dot-separated words of each part are kept as integer lists; the remaining
    words of the patch part form the flavour list (["el7", "x86_64", "debug"]).

    Ordering and equality compare the revision numbers first and the patch
    numbers second; flavour words do not take part in comparisons.
    """

    # Flavour prefixes for which only the first three characters must match.
    _EL_PREFIXES = ("el7", "el8", "el9")
    # Flavour words that identify an Oracle UEK kernel.
    _UEK_FLAVOURS = ("el7uek", "el8uek", "el9uek")

    def __init__(self, str):
        # NOTE: the parameter name shadows the builtin `str`; it is kept as is
        # so that the constructor signature stays unchanged for callers.

        # common version string like "3.10.0-862.6.3.el7.x86_64.debug"
        self._versionString = str

        # two halves like "3.10.0" and "862.6.3.el7.x86_64.debug".
        # Split at the FIRST dash only, so a release string with no dash or
        # with several dashes is parsed instead of raising ValueError.
        self._revisionString, dash, self._patchString = str.partition("-")

        # first half numbers collection [3, 10, 0]
        self._revisionNumbers = self._LeadingNumbers(self._revisionString.split("."))

        # second half numbers collection [862, 6, 3]
        patchWords = self._patchString.split(".") if dash else []
        self._patchNumbers = self._LeadingNumbers(patchWords)

        # flavour words collection ["el7", "x86_64", "debug"]
        self._flavour = patchWords[len(self._patchNumbers):]

    @staticmethod
    def _LeadingNumbers(words):
        """Return the leading all-digit words of `words` converted to int."""
        numbers = []
        for word in words:
            if not word.isdigit():
                break
            numbers.append(int(word))
        return numbers

    def GetAccurateVersion(self):
        """Return the original, unparsed version string."""
        return self._versionString

    def IsSameFlavour(self, other):
        """Return True when both versions have matching flavour words.

        Words starting with "el7", "el8" or "el9" are compared by their first
        three characters only; every other word must match exactly.
        """
        if len(self._flavour) != len(other._flavour):
            return False

        for mine, theirs in zip(self._flavour, other._flavour):
            if mine[0:3] in self._EL_PREFIXES:
                if mine[0:3] != theirs[0:3]:
                    return False
            elif mine != theirs:
                return False
        return True

    def IsSameRelease(self, other):
        """Return True when `other` is an older build of the same release.

        That means: identical revision string, identical first patch number,
        matching flavour, and `self` strictly greater than `other`.
        Versions without any numeric patch component never match.
        """
        if not self._patchNumbers or not other._patchNumbers:
            return False
        return ((self._revisionString == other._revisionString) and
                (self._patchNumbers[0] == other._patchNumbers[0]) and
                self.IsSameFlavour(other) and
                self > other)

    def IsUek(self):
        """Return True when one of the flavour words marks a UEK kernel."""
        return any(flavour in self._UEK_FLAVOURS for flavour in self._flavour)

    @staticmethod
    def _CompareNumbers(first, second):
        """Compare two number sequences lexicographically.

        Returns -1, 0 or 1. When one sequence is a prefix of the other, the
        shorter one is considered smaller.
        """
        for left, right in zip(first, second):
            if left < right:
                return -1
            if left > right:
                return 1

        if len(first) < len(second):
            return -1
        if len(first) > len(second):
            return 1
        return 0

    def _Compare(self, other):
        """Return -1, 0 or 1: revision numbers first, then patch numbers."""
        res = self._CompareNumbers(self._revisionNumbers, other._revisionNumbers)
        if res == 0:
            res = self._CompareNumbers(self._patchNumbers, other._patchNumbers)
        return res

    def __eq__(self, other):
        if not isinstance(other, CKernelVersion):
            return NotImplemented
        return self._Compare(other) == 0

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash exactly
        # the data that __eq__ compares.
        return hash((tuple(self._revisionNumbers), tuple(self._patchNumbers)))

    def __lt__(self, other):
        if not isinstance(other, CKernelVersion):
            return NotImplemented
        return self._Compare(other) < 0

    def __le__(self, other):
        if not isinstance(other, CKernelVersion):
            return NotImplemented
        return self._Compare(other) <= 0

    def __gt__(self, other):
        if not isinstance(other, CKernelVersion):
            return NotImplemented
        return self._Compare(other) > 0

    def __ge__(self, other):
        if not isinstance(other, CKernelVersion):
            return NotImplemented
        return self._Compare(other) >= 0

    def __str__(self):
        return self._versionString

    def __repr__(self):
        return self._versionString

class CKernelModule:
    """Pairs a kernel version with the path of a module binary built for it."""

    def __init__(self, _kernelVersion, _modulePath):
        # Path of the module file
        self._modulePath = _modulePath
        # Kernel version the module belongs to
        self._kernelVersion = _kernelVersion

def ExecuteCommand(cmd):
    """Echo `cmd`, run it through os.system and fail on a non-zero status.

    Raises:
        RuntimeError: when os.system returns anything other than 0.
    """
    bracketed = "[" + cmd + "]"
    print("Execute " + bracketed)
    # Flush so the echo is written out before the command itself runs.
    sys.stdout.flush()
    status = os.system(cmd)
    if status == 0:
        return
    raise RuntimeError("Failed to execute " + bracketed)

def FindSameReleaseModule(kver):
    """Find a veeamsnap module built for an older kernel of the same release.

    Scans every directory under /lib/modules for "extra/veeamsnap.ko", lists
    the binaries found (newest kernel version first) and returns the path of
    the first one for which kver.IsSameRelease() is true.

    Args:
        kver: CKernelVersion of the kernel the module is wanted for.

    Returns:
        Path of the matching module file, or None when nothing matches.
    """
    modulesRoot = "/lib/modules"
    modulesList = []
    for dirName in os.listdir(modulesRoot):
        fullDirPath = os.path.join(modulesRoot, dirName)
        if not os.path.isdir(fullDirPath):
            continue
        fileFullPath = os.path.join(fullDirPath, "extra/veeamsnap.ko")
        if not os.path.exists(fileFullPath):
            continue
        try:
            moduleVersion = CKernelVersion(dirName)
        except ValueError:
            # The directory name is not a parseable kernel release string;
            # skip it instead of aborting the whole search.
            continue
        modulesList.append(CKernelModule(moduleVersion, fileFullPath))

    # Newest kernel version first, so the first match is the closest one.
    modulesList.sort(key=lambda module: module._kernelVersion, reverse=True)

    print("Kmod binaries were found:")
    for module in modulesList:
        print("\t"+str(module._modulePath))

    for module in modulesList:
        if kver.IsSameRelease(module._kernelVersion):
            return module._modulePath

    return None

def FindSuitableModule(kver):
    """Find the newest veeamsnap module whose kernel version is <= kver.

    Scans every directory under /lib/modules for "extra/veeamsnap.ko", lists
    the binaries found (newest kernel version first) and returns the path of
    the first one whose kernel version is not greater than `kver`.

    Args:
        kver: CKernelVersion of the kernel the module is wanted for.

    Returns:
        Path of the chosen module file, or None when nothing qualifies.
    """
    modulesRoot = "/lib/modules"
    modulesList = []
    for dirName in os.listdir(modulesRoot):
        fullDirPath = os.path.join(modulesRoot, dirName)
        if not os.path.isdir(fullDirPath):
            continue
        fileFullPath = os.path.join(fullDirPath, "extra/veeamsnap.ko")
        if not os.path.exists(fileFullPath):
            continue
        try:
            moduleVersion = CKernelVersion(dirName)
        except ValueError:
            # The directory name is not a parseable kernel release string;
            # skip it instead of aborting the whole search.
            continue
        modulesList.append(CKernelModule(moduleVersion, fileFullPath))

    # Newest kernel version first, so the first hit is the closest one.
    modulesList.sort(key=lambda module: module._kernelVersion, reverse=True)

    print("Veeamsnap kmod binaries were found:")
    for module in modulesList:
        print("\t"+str(module._modulePath))

    for module in modulesList:
        if module._kernelVersion <= kver:
            return module._modulePath

    return None

def CheckModuleExist(modulePath):
    """Ensure `modulePath` refers to an existing regular file.

    Raises:
        RuntimeError: when os.path.isfile() is false for the path.
    """
    if os.path.isfile(modulePath):
        return
    raise RuntimeError("Module [{0}] is not exist".format(modulePath))

def TryLoadModule(modulePath, params):
    """Load the module at `modulePath` with insmod.

    Args:
        modulePath: path of the module file; it must exist.
        params: iterable of strings appended to the insmod command line,
            separated by single spaces.

    Raises:
        RuntimeError: when the file does not exist or the command fails.
    """
    CheckModuleExist(modulePath)

    # Quote the path the same way ModInfo does, so that a path containing
    # spaces reaches insmod as a single argument.
    cmd = " ".join(["insmod \"" + modulePath + "\""] + list(params))

    ExecuteCommand(cmd)

def ModInfo(modulePath):
    """Run 'modinfo' on the module file after checking that it exists.

    Raises:
        RuntimeError: when the file does not exist or the command fails.
    """
    CheckModuleExist(modulePath)

    quotedPath = "\"" + modulePath + "\""
    ExecuteCommand("modinfo " + quotedPath)


def TryProcess(modulePath, params, isCheck, isModinfo):
    """Perform the requested action on the module at `modulePath`.

    `isCheck` takes precedence: only the existence of the file is verified.
    Otherwise `isModinfo` selects a 'modinfo' call, and when neither flag is
    set the module is loaded with `params`.
    """
    if isCheck:
        CheckModuleExist(modulePath)
        return
    if isModinfo:
        ModInfo(modulePath)
        return
    TryLoadModule(modulePath, params)

def dist_os_release(path="/etc/os-release"):
    """Read the distribution ID and version from an os-release file.

    Blank lines, lines starting with "#" and lines without "=" are ignored.
    Each remaining line is split at the FIRST "=" only, so values that
    themselves contain "=" (for example URLs with query strings) are read
    correctly.

    Args:
        path: file to read; defaults to "/etc/os-release".

    Returns:
        Tuple (ID, VERSION_ID) with surrounding quotes removed.

    Raises:
        KeyError: when the file defines no ID or no VERSION_ID.
        OSError: when the file cannot be opened.
    """
    values = {}
    with open(path) as file:
        for line in file:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            key, separator, val = stripped.partition("=")
            if not separator:
                # Not a KEY=VALUE assignment; nothing to record.
                continue
            values[key.strip()] = val.strip()
    # Values may be wrapped in double or single quotes.
    return values["ID"].strip("\"'"), values["VERSION_ID"].strip("\"'")

def Usage():
    """Print the command-line help text to standard output."""
    helpLines = (
        "Usage:",
        "Load a compatible veeamsnap kernel module for the current kernel.",
        "",
        "\t--help, -h - Show this usage.",
        "\t--check - Checks the existence of a compatible kernel module for the current kernel.",
        "\t--modinfo - Call 'modinfo' for a compatible kernel module.",
        "\t--kmodAnyKernel - Try to find a module with a suitable release. In this case the system kernel may be incompatible with the module.",
    )
    # The text itself ends with a newline and print() appends one more.
    print("\n".join(helpLines) + "\n")

def main(params):
    """Find a veeamsnap kernel module for the running kernel and process it.

    Leading options ("--check", "--modinfo", "--kmodAnyKernel") are consumed;
    "--help"/"-h" prints the usage and returns immediately. Everything from
    the first unrecognized word on is passed to insmod as module parameters.

    The module is searched in three steps: the exact kernel version, an older
    kernel of the same release (skipped for UEK kernels), and - only when
    "--check", "--modinfo" or "--kmodAnyKernel" was given - the newest module
    whose kernel version is not greater than the running one.

    Args:
        params: command-line arguments without the program name. The list
            passed in is not modified.

    Returns:
        0 when the exact or same-release module was processed, 1 when only
        a "suitable" module was processed (the kernel may be incompatible
        with it), -1 when no module could be processed.

    Raises:
        RuntimeError: for an unsupported platform, distribution or
            distribution version. Any other unexpected error is reported on
            stdout and re-raised.
    """
    isCheck = False
    isModinfo = False
    isKmodAnyKernel = False

    try:
        # Work on a copy so the caller's list stays untouched.
        params = list(params)

        # Process parameters: consume leading options, stop at the first
        # word that is not one of them.
        while params:
            option = params[0]
            if option in ("--help", "-h"):
                Usage()
                return 0
            if option == "--check":
                isCheck = True
                isKmodAnyKernel = True
            elif option == "--modinfo":
                isModinfo = True
                isKmodAnyKernel = True
            elif option == "--kmodAnyKernel":
                isKmodAnyKernel = True
            else:
                break
            del params[0]

        if platform.system() != "Linux":
            raise RuntimeError("Found unsupported platform [{0}]".format(platform.system()) )

        distName, distVersion = dist_os_release()
        unameKernelVersion = platform.uname()[2]

        if distName not in ("oracle", "ol", "redhat", "rhel", "centos", "rocky", "almalinux"):
            raise RuntimeError("Found unsupported distribution [{0}]".format(distName) )

        # Compare the major version numerically: as strings "10" < "6" is
        # True, which would reject release 10 and later. A non-numeric major
        # version keeps the previous string comparison.
        majorVersion = distVersion.split(".")[0]
        if majorVersion.isdecimal():
            isTooOld = int(majorVersion) < 6
        else:
            isTooOld = majorVersion < "6"
        if isTooOld:
            raise RuntimeError("Found unsupported platform version [{0}]".format(distVersion) )

        kver = CKernelVersion(unameKernelVersion)

        #try find direct module
        strAccurateKernelVersion = kver.GetAccurateVersion()
        kmodPath = "/lib/modules/"+strAccurateKernelVersion+"/extra/veeamsnap.ko"
        print("Accurate kernel version: "+strAccurateKernelVersion)
        try:
            TryProcess(kmodPath, params, isCheck, isModinfo)
            return 0
        except Exception:
            # Best effort: fall through to the next search step.
            print("Failed to load module for accurate kernel version [{0}]".format(strAccurateKernelVersion) )

        #try find module with same release
        if kver.IsUek():
            print("Skip an attempt to load a same release kernel for UEK")
        else:
            #try find module with same release kernel
            kmodPath = FindSameReleaseModule(kver)
            if kmodPath is None:
                print("Failed to find same release module for kernel [{0}]".format(strAccurateKernelVersion) )
            else:
                print("Same release kernel module path: "+kmodPath)
                try:
                    TryProcess(kmodPath, params, isCheck, isModinfo)
                    return 0
                except Exception:
                    # Best effort: fall through to the next search step.
                    print("Failed to load module for same release kernel version [{0}]".format(strAccurateKernelVersion) )

        if not isKmodAnyKernel:
            return -1

        #try find module with suitable release
        kmodPath = FindSuitableModule(kver)
        if kmodPath is None:
            print("Failed to find suitable module for kernel [{0}]".format(strAccurateKernelVersion) )
            return -1

        print("Suitable kernel module path: "+kmodPath)
        try:
            TryProcess(kmodPath, params, isCheck, isModinfo)
            return 1 # This mean that the module was loaded, but the system kernel may be incompatible with the module
        except Exception as ex:
            print("Failed to find suitable veeamsnap kernel module. {0}".format(ex))
            return -1
    except BaseException:
        # Report anything unexpected, then let it propagate unchanged.
        print("Unexpected error:"+str(sys.exc_info()[0]))
        raise

# Run the loader only when executed as a script, so that importing this file
# does not try to load a kernel module. sys.exit is used because the bare
# `exit` name is provided by the `site` module and is not always available.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
