#!/bin/python3

# Copyright (c) Veeam Software Group GmbH

import argparse
import errno
import fcntl
import os
import signal
import subprocess
import sys
import traceback

def conflict_exit_code(arg):
    """Argparse ``type=`` converter for ``--conflict-exit-code``.

    Parses *arg* as an integer and accepts it only when it fits into a
    process exit status, i.e. lies in the inclusive range 0..255.

    Raises:
        ValueError: if *arg* is not an integer or is outside 0..255.
    """
    code = int(arg)
    if not 0 <= code <= 255:
        raise ValueError("invalid exit code")
    return code

def get_lockpath():
    """Return the path of the rpm database lock file.

    Asks ``rpm`` to expand the ``_rpmlock_path`` macro; the conditional
    ``%{?...}`` form expands to an empty string when the macro is undefined,
    in which case ``/var/lib/rpm/.rpm.lock`` is used instead.

    Raises:
        subprocess.CalledProcessError: if ``rpm`` exits with a non-zero status.
        OSError: if ``rpm`` cannot be executed (e.g. it is not on ``PATH``).
    """
    completed = subprocess.run(
        ["rpm", "-E", "%{?_rpmlock_path}"],
        capture_output=True,
        shell=False,
        check=True,
    )
    expanded = completed.stdout.decode('utf8').strip()
    if expanded:
        return expanded
    return "/var/lib/rpm/.rpm.lock"

def main(argv):
    """Lock the rpm database, then exec ``dnf`` or ``rpm`` while holding the lock.

    *argv* is a full argument vector (``argv[0]`` is used as the program
    name).  The lock file path comes from ``get_lockpath()``; the file is
    opened (created with mode 0644 if missing) and locked with
    ``fcntl.lockf`` -- exclusively by default, shared with ``-s``.  On
    success the current process is replaced by the command via
    ``os.execlp``, so this function does not return.

    If the lock is held elsewhere, the process exits with
    ``--conflict-exit-code`` (default 1): immediately under ``-n`` or
    ``-w 0``, otherwise after ``-w`` seconds (default 900).  A negative
    ``-w`` waits indefinitely.  Other lock errors propagate as ``OSError``.

    Because ``args`` is a ``nargs="*"`` positional, option-like arguments
    meant for the command generally have to follow ``--`` so that argparse
    does not try to interpret them.
    """
    parser = argparse.ArgumentParser(prog=argv[0],
                                     description="execute command with locking rpm database")
    excgrp = parser.add_mutually_exclusive_group(required=False)
    parser.add_argument("cmd", type=str, choices=["dnf", "rpm"],
                        help="command to be executed")
    parser.add_argument("args", type=str, nargs="*",
                        help="arguments passed to the command")
    excgrp.add_argument("-x", "--exclusive", action='store_true',
                        help="get an exclusive lock (default)")
    excgrp.add_argument("-s", "--shared", action='store_true',
                        help="get a shared lock")
    parser.add_argument("-n", "--nonblock", action='store_true',
                        help="fail rather than wait")
    parser.add_argument("-E", "--conflict-exit-code", type=conflict_exit_code, default=1,
                        help="exit code after conflict or timeout")
    parser.add_argument("-w", "--timeout", type=int, default=15*60,
                        help="wait for a limited amount of time")
    args = parser.parse_args(argv[1:])

    lockpath = get_lockpath()
    fd = os.open(lockpath, os.O_RDWR | os.O_CREAT, 0o644)
    # Python creates descriptors non-inheritable (close-on-exec).  Closing
    # *any* descriptor of a file drops the process's POSIX record locks on
    # it, so this fd must stay open across execlp() for the lock to survive.
    os.set_inheritable(fd, True)

    # -x is the default; the flag only exists to be mutually exclusive with -s.
    lockmode = fcntl.LOCK_SH if args.shared else fcntl.LOCK_EX
    if args.nonblock or args.timeout == 0:
        lockmode |= fcntl.LOCK_NB
    elif args.timeout > 0:
        # SIGALRM interrupts the blocking lockf() below; the SystemExit raised
        # by the handler propagates out of that call.
        def on_timeout(signum, frame):
            print(f"timeout for locking \"{lockpath}\"", file=sys.stderr)
            sys.exit(args.conflict_exit_code)
        signal.signal(signal.SIGALRM, on_timeout)
        signal.alarm(args.timeout)

    try:
        fcntl.lockf(fd, lockmode)
    except OSError as exc:
        signal.alarm(0)
        # With LOCK_NB a conflicting lock is reported as EAGAIN or EACCES
        # depending on the OS (BlockingIOError only covers EAGAIN); any other
        # errno is a genuine failure and is re-raised.
        if exc.errno not in (errno.EACCES, errno.EAGAIN):
            raise
        print(f"failed to lock \"{lockpath}\"", file=sys.stderr)
        sys.exit(args.conflict_exit_code)
    signal.alarm(0)

    os.execlp(args.cmd, args.cmd, *args.args)


if __name__ == '__main__':
    try:
        main(sys.argv)
    except Exception as exc:
        # Compact error report: the message, then one "module:line:function"
        # entry per traceback frame, showing the module name instead of the
        # file path when the frame's file belongs to a loaded module.
        print(f"ERROR: {exc}", file=sys.stderr)
        print("* Traceback:", file=sys.stderr)
        # Snapshot sys.modules: reading attributes of lazily loaded modules
        # could otherwise mutate the dict while it is being iterated.
        modules = {m.__file__: n for n, m in list(sys.modules.items())
                   if getattr(m, "__file__", None)}
        for frame in traceback.extract_tb(exc.__traceback__):
            where = modules.get(frame.filename, frame.filename)
            print(f"* {where}:{frame.lineno}:{frame.name}", file=sys.stderr)
        sys.exit(1)
