#!/usr/bin/python3
# -*- coding: utf-8 -*-

"""
Special loader for blksnap precompiled kernel module
"""

__author__  = "Veeam Software Group GmbH"
__license__ = "GPL-2"

import os
import sys
import time
import platform


class CKernelVersion:
	"""
	Parsed Linux kernel version string with ordering support.

	A string such as "3.10.0-862.6.3.el7.x86_64.debug" is split at the first
	"-" into a revision part ("3.10.0") and a patch part
	("862.6.3.el7.x86_64.debug"). The leading numeric components of each part
	are kept as integer lists, and the remaining dot-separated words of the
	patch part are kept as the flavour (["el7", "x86_64", "debug"]).

	Equality and ordering compare the revision numbers first and the patch
	numbers second; the flavour does not take part in comparisons.
	"""

	def __init__(self, str):
		# NOTE: the parameter name shadows the builtin inside this method only;
		# it is kept unchanged so that existing callers are not affected.

		# Common version string like "3.10.0-862.6.3.el7.x86_64.debug"
		self._versionString = str

		# Two halves like "3.10.0" and "862.6.3.el7.x86_64.debug".
		# partition() is used instead of a two-way unpacking of split() so
		# that a version without "-" (empty patch part) or with several "-"
		# (the rest stays in the patch part) does not raise ValueError.
		(self._revisionString, _, self._patchString) = str.partition("-")

		# First half numbers collection [3, 10, 0]
		self._revisionNumbers = self._LeadingNumbers(self._revisionString.split("."))

		# Second half numbers collection [862, 6, 3]
		patchParts = self._patchString.split(".") if self._patchString else []
		self._patchNumbers = self._LeadingNumbers(patchParts)

		# Flavour words collection ["el7", "x86_64", "debug"]: everything
		# that follows the leading numeric components of the patch part
		self._flavour = patchParts[len(self._patchNumbers):]

	@staticmethod
	def _LeadingNumbers(parts):
		"""Return the leading all-digit items of parts converted to int."""
		numbers = []
		for part in parts:
			if not part.isdigit():
				break
			numbers.append(int(part))
		return numbers

	def GetAccurateVersion(self):
		"""Return the version string exactly as it was passed to the constructor."""
		return self._versionString

	def IsSameFlavour(self, other):
		"""
		Return True when both versions have pairwise matching flavour words.

		Words of this version that start with "el7", "el8", "el9" or "el10"
		are compared by that prefix only (so "el8_6" matches "el8"); all other
		words must be equal.
		"""
		if len(self._flavour) != len(other._flavour):
			return False

		for mine, theirs in zip(self._flavour, other._flavour):
			if mine[0:3] in ["el7", "el8", "el9"]:
				same = mine[0:3] == theirs[0:3]
			elif mine[0:4] in ["el10"]:
				same = mine[0:4] == theirs[0:4]
			else:
				same = mine == theirs
			if not same:
				return False
		return True

	def IsSameRelease(self, other):
		"""
		Return True when other is an older version of the same release.

		The revision string, the count of patch numbers, the first patch
		number and the flavour must match, and this version must be strictly
		greater than other. Returns NotImplemented when other is not a
		CKernelVersion.
		"""
		if not isinstance(other, CKernelVersion):
			return NotImplemented

		if len(self._patchNumbers) != len(other._patchNumbers):
			return False

		if self._revisionString != other._revisionString:
			return False

		if (len(self._patchNumbers) > 0) and (self._patchNumbers[0] != other._patchNumbers[0]):
			return False

		return (self.IsSameFlavour(other)) and (self > other)

	def IsUek(self):
		"""Return True when the version has patch numbers and a UEK flavour word."""
		if len(self._patchNumbers) == 0:
			return False

		return any(flavour in ["el7uek", "el8uek", "el9uek", "el10uek"] for flavour in self._flavour)

	@staticmethod
	def _CompareNumbers(first, second):
		"""
		Three-way compare of two integer lists; returns -1, 0 or 1.

		Items are compared pairwise; when one list is a prefix of the other
		the shorter one is smaller. This is exactly the ordering Python
		defines for lists, so the builtin comparison is used.
		"""
		return (first > second) - (first < second)

	def _Compare(self, other):
		"""Three-way compare by revision numbers, then by patch numbers."""
		res = self._CompareNumbers(self._revisionNumbers, other._revisionNumbers)
		if res == 0:
			res = self._CompareNumbers(self._patchNumbers, other._patchNumbers)
		return res

	def __eq__(self, other):
		if not isinstance(other, CKernelVersion):
			return NotImplemented
		return self._Compare(other) == 0

	def __lt__(self, other):
		if not isinstance(other, CKernelVersion):
			return NotImplemented
		return self._Compare(other) < 0

	def __le__(self, other):
		if not isinstance(other, CKernelVersion):
			return NotImplemented
		return self._Compare(other) <= 0

	def __gt__(self, other):
		if not isinstance(other, CKernelVersion):
			return NotImplemented
		return self._Compare(other) > 0

	def __ge__(self, other):
		if not isinstance(other, CKernelVersion):
			return NotImplemented
		return self._Compare(other) >= 0

	def __str__(self):
		return self._versionString

	def __repr__(self):
		return self._versionString

class CKernelModule:
	"""Plain record pairing a kernel version with a module path."""

	def __init__(self, _kernelVersion, _modulePath):
		# Both values are stored as given; nothing is validated or copied
		self._modulePath = _modulePath
		self._kernelVersion = _kernelVersion

def ExecuteCommand(cmd):
	"""
	Run cmd through os.system and fail loudly on a non-zero status.

	The command is echoed to stdout (and flushed) before it runs.
	Raises RuntimeError when os.system returns anything other than 0.
	"""
	banner = "Execute [" + cmd + "]"
	print(banner, flush=True)

	status = os.system(cmd)
	if status == 0:
		return
	raise RuntimeError("Failed to execute [" + cmd + "]")

def TryExecuteCommand(cmd):
	"""
	Run cmd through os.system on a best-effort basis.

	The command is echoed to stdout (and flushed) before it runs; its exit
	status is deliberately ignored.
	"""
	banner = "Try execute [" + cmd + "]"
	print(banner, flush=True)
	os.system(cmd)

def EnumerateAvaildableKernels(moduleNames):
	"""
	Collect kernels under /lib/modules that ship every requested module.

	A kernel directory qualifies when it has an "extra" subdirectory that
	contains "<name>.ko" for each name in moduleNames. Returns a list of
	CKernelModule objects (kernel version parsed from the directory name,
	path of the "extra" directory), in os.listdir order.
	"""
	root = "/lib/modules"
	available = []

	for kernelDir in os.listdir(root):
		extraDir = os.path.join(root, kernelDir, "extra")
		if not os.path.isdir(extraDir):
			continue

		# all() stops at the first missing module, like an early break
		hasEveryModule = all(
			os.path.exists(os.path.join(extraDir, name+".ko"))
			for name in moduleNames
		)
		if hasEveryModule:
			available.append(CKernelModule(CKernelVersion(kernelDir), extraDir))

	return available

def FindSameReleaseModule(kver, moduleNames):
	"""
	Find the newest available module directory of the same release as kver.

	The candidates are sorted by kernel version, highest first, and listed
	on stdout. Returns the module path of the first candidate for which
	kver.IsSameRelease() is true, None when there is no such candidate, or
	NotImplemented when kver is not a CKernelVersion.
	"""
	if not isinstance(kver, CKernelVersion):
		return NotImplemented

	candidates = EnumerateAvaildableKernels(moduleNames)
	candidates.sort(key=lambda entry: entry._kernelVersion, reverse=True)

	print("Kmod binaries were found:")
	for entry in candidates:
		print("\t"+str(entry._modulePath))

	matching = (
		entry._modulePath
		for entry in candidates
		if kver.IsSameRelease(entry._kernelVersion)
	)
	return next(matching, None)

def FindSuitableModule(strFoundVersion, moduleNames):
	"""
	Find the newest available module directory built for a kernel that is
	not newer than strFoundVersion.

	The candidates are sorted by kernel version, highest first, and listed
	on stdout. Returns the module path of the first candidate whose version
	is less than or equal to the given one, or None when there is none.
	"""
	foundVersion = CKernelVersion(strFoundVersion)
	print("Linux kernel found:"+str(foundVersion))

	candidates = EnumerateAvaildableKernels(moduleNames)
	candidates.sort(key=lambda entry: entry._kernelVersion, reverse=True)

	print("Blksnap kmod binaries were found:")
	for entry in candidates:
		print("\t"+str(entry._modulePath))

	notNewer = (
		entry._modulePath
		for entry in candidates
		if entry._kernelVersion <= foundVersion
	)
	return next(notNewer, None)

def CheckModuleExist(modulePath):
	"""Raise RuntimeError unless modulePath names an existing regular file."""
	if os.path.isfile(modulePath):
		return
	raise RuntimeError("Module [{0}] is not exist".format(modulePath))

def TryUnloadModules():
	"""
	Unload veeamblksnap, blksnap and bdevfilter if they appear to be loaded.

	Each step runs only when its marker path exists. Returns False as soon
	as a step fails or when the bdevfilter livepatch entry is still present
	after the wait; returns True otherwise. The final "rmmod bdevfilter" is
	best-effort and its result is ignored.
	"""
	livepatchDir = "/sys/kernel/livepatch/bdevfilter"

	# (marker path, command to run when the marker exists, failure message)
	steps = (
		("/dev/veeamblksnap", "rmmod veeamblksnap", "Failed to unload veeamblksnap. {0}"),
		("/dev/blksnap", "rmmod blksnap", "Failed to unload blksnap. {0}"),
		(livepatchDir, "echo 0 > /sys/kernel/livepatch/bdevfilter/enabled", "Failed to disable bdevfilter. {0}"),
	)
	for marker, command, failure in steps:
		if not os.path.exists(marker):
			continue
		try:
			ExecuteCommand(command)
		except Exception as ex:
			print(failure.format(ex))
			return False

	# Poll once per second until the livepatch entry disappears; give up
	# after 11 checks that each found it still present.
	for _ in range(11):
		if not os.path.exists(livepatchDir):
			break
		time.sleep(1)
	else:
		print("bdevfilter cannot be disabled.")
		return False

	TryExecuteCommand("rmmod bdevfilter")

	return True

def TryLoadBdevFilter(modulePath):
	"""
	Load bdevfilter.ko from the modulePath directory with insmod.

	When the bdevfilter livepatch "enabled" file is present, the currently
	loaded modules are unloaded first via TryUnloadModules().

	Raises RuntimeError when the module file does not exist, when the
	previously loaded modules cannot be unloaded, or when insmod fails.
	"""
	modulePath = modulePath+'/'+"bdevfilter.ko"

	CheckModuleExist(modulePath)

	if os.path.isfile("/sys/kernel/livepatch/bdevfilter/enabled"):
		# The original code called an undefined UnloadModule() here, which
		# raised NameError; TryUnloadModules() is the unload routine this
		# file actually provides.
		if not TryUnloadModules():
			raise RuntimeError("Failed to unload the already loaded bdevfilter")

	ExecuteCommand("insmod " + modulePath)

def TryLoadBlksnap(modulePath, params):
	"""
	Load veeamblksnap.ko from the modulePath directory with insmod.

	Every item of params is appended to the insmod command line, separated
	by single spaces. Raises RuntimeError when the module file does not
	exist or when insmod fails.
	"""
	koPath = modulePath+'/'+"veeamblksnap.ko"

	CheckModuleExist(koPath)

	ExecuteCommand(" ".join(["insmod " + koPath] + list(params)))

def ModInfo(modulePath):
	"""
	Run "modinfo" on the given module file.

	Raises RuntimeError when the file does not exist or modinfo fails.
	"""
	CheckModuleExist(modulePath)

	# The path is wrapped in double quotes on the command line
	ExecuteCommand('modinfo "' + modulePath + '"')


def TryProcess(modulePath, params, isCheck, isModinfo):
	"""
	Perform the requested action on the modules in the modulePath directory.

	isCheck takes precedence: only the presence of veeamblksnap.ko is
	verified. Otherwise, with isModinfo, modinfo is run on veeamblksnap.ko.
	With neither flag, the loaded modules are unloaded and then bdevfilter
	and veeamblksnap are loaded, the latter with params.
	"""
	if isCheck:
		CheckModuleExist(modulePath + "/veeamblksnap.ko")
		return

	if isModinfo:
		ModInfo(modulePath + "/veeamblksnap.ko")
		return

	# The result of the unload attempt is not checked here
	TryUnloadModules()
	TryLoadBdevFilter(modulePath)
	TryLoadBlksnap(modulePath, params)

def dist_os_release(path="/etc/os-release"):
	"""
	Read the distribution ID and version from an os-release file.

	path defaults to "/etc/os-release". Blank lines, comment lines and lines
	without "=" are skipped. Each remaining line is split at the first "="
	only, so values that themselves contain "=" (URLs, for example) are
	parsed correctly. Surrounding double or single quotes are removed from
	the two returned values.

	Returns a tuple (ID, VERSION_ID). A missing ID yields "linux" and a
	missing VERSION_ID yields "", following os-release(5), where ID defaults
	to "linux" and VERSION_ID is optional.
	"""
	d = {}
	with open(path, encoding="utf-8") as file:
		for line in file:
			stripped = line.strip()
			if not stripped or stripped.startswith("#"):
				continue

			key, separator, val = stripped.partition("=")
			if not separator:
				# Not a KEY=VALUE assignment
				continue
			d[key.strip()] = val.strip()

	quotes = "\"'"
	return d.get("ID", "linux").strip(quotes), d.get("VERSION_ID", "").strip(quotes)

def Usage():
	"""Print the command line help text to stdout."""
	helpLines = [
		"Usage:",
		"Load a compatible blksnap kernel module for the current kernel.",
		"",
		"\t--help, -h - Show this usage.",
		"\t--check - Checks the existence of a compatible kernel module for the current kernel.",
		"\t--modinfo - Call 'modinfo' for a compatible kernel module.",
		"\t--kmodAnyKernel - Try to find a module with a suitable release. In this case the system kernel may be incompatible with the module.",
		"\t--unload - Try to unload kernel module.",
	]
	# Every line, including the last, ends with a newline; print adds one more
	print("\n".join(helpLines) + "\n")

def main(params):
	"""
	Entry point: pick and process a kernel module matching the running kernel.

	params holds the command line arguments without the program name.
	Leading arguments that are recognised options are consumed; everything
	from the first unrecognised argument on is passed to insmod as
	parameters of the veeamblksnap module. The caller's list is not
	modified.

	Module directories are tried in this order:
	  1. /lib/modules/<running kernel version>/extra;
	  2. a module of the same release (skipped for UEK kernels);
	  3. with --kmodAnyKernel, --check or --modinfo only: the newest module
	     built for a kernel that is not newer than the running one.

	Returns 0 on success (also after --help and --unload), 1 when step 3
	succeeded (the kernel may be incompatible with the module) and -1 when
	nothing suitable was found or processed. Unexpected exceptions are
	reported on stdout and re-raised.
	"""
	isCheck = False
	isModinfo = False
	isKmodAnyKernel = False

	# Work on a copy so that the caller's list is left untouched
	params = list(params)

	try:
		# Process parameters
		while len(params) > 0:
			option = params[0]
			if option == "--help" or option == "-h":
				Usage()
				return 0
			if option == "--check":
				isCheck = True
				isKmodAnyKernel = True
			elif option == "--modinfo":
				isModinfo = True
				isKmodAnyKernel = True
			elif option == "--kmodAnyKernel":
				isKmodAnyKernel = True
			elif option == "--unload":
				TryUnloadModules()
				return 0
			else:
				# First non-option argument: the rest are module parameters
				break

			del params[0]

		if platform.system() != "Linux":
			raise RuntimeError("Found unsupported platform [{0}]".format(platform.system()) )

		# Only the distribution name is needed here
		distName = dist_os_release()[0]
		unameKernelVersion = platform.uname()[2]

		if distName not in ["oracle", "ol", "redhat", "rhel", "centos", "rocky", "almalinux"]:
			raise RuntimeError("Found unsupported distribution [{0}]".format(distName) )

		kver = CKernelVersion(unameKernelVersion)

		#try find direct module
		strAccurateKernelVersion = kver.GetAccurateVersion()
		kmodPath = "/lib/modules/"+strAccurateKernelVersion+"/extra"
		print("Accurate kernel version: "+strAccurateKernelVersion)
		try:
			TryProcess(kmodPath, params, isCheck, isModinfo)
			return 0
		except Exception:
			print("Failed to load module for accurate kernel version [{0}]".format(strAccurateKernelVersion) )

		#try find module with same release
		if kver.IsUek():
			print("Skip an attempt to load a same release kernel for UEK")
		else:
			kmodPath = FindSameReleaseModule(kver, ["veeamblksnap", "bdevfilter"])
			if kmodPath is None:
				print("Failed to find same release modules for kernel [{0}]".format(strAccurateKernelVersion) )
			else:
				print("Same release kernel module path: "+kmodPath)
				try:
					TryProcess(kmodPath, params, isCheck, isModinfo)
					return 0
				except Exception:
					print("Failed to load module for same release kernel version")

		if not isKmodAnyKernel:
			return -1

		kmodPath = FindSuitableModule(strAccurateKernelVersion, ["veeamblksnap", "bdevfilter"])
		if kmodPath is None:
			print("Failed to find suitable module for kernel [{0}]".format(strAccurateKernelVersion) )
			return -1
		print("Suitable kernel module path: "+kmodPath)
		try:
			TryProcess(kmodPath, params, isCheck, isModinfo)
			return 1 # This mean that the module was loaded, but the system kernel may be incompatible with the module

		except Exception as ex:
			print("Failed to find suitable kernel module. {0}".format(ex))
			return -1
	except Exception as ex:
		# The original concatenated a str with the exception type, which
		# raised TypeError here and hid the real error. Format it instead,
		# and let KeyboardInterrupt/SystemExit pass through untouched.
		print("Unexpected error: {0}: {1}".format(type(ex).__name__, ex))
		raise

if __name__ == "__main__":
	# sys.argv[0] is the script path; the remaining items are the loader
	# options followed by the module parameters. sys.exit is used instead of
	# the site-provided exit(), which is meant for the interactive shell and
	# is absent when the interpreter runs without the site module.
	sys.exit(main(sys.argv[1:]))
