#!/usr/bin/python3

import argparse
import base64
import binascii
import json
import os
import subprocess
import sys
import time

# Point CONDOR_CONFIG at the HTCondor-CE configuration unless the caller has
# already set it. This deliberately sits ahead of the htcondor import below --
# presumably so the bindings read the CE config when loaded (TODO confirm).
os.environ.setdefault("CONDOR_CONFIG", "/etc/condor-ce/condor_config")

import htcondor2 as htcondor
import htcondorce.tools as ce


# Port that main() appends to the --pool address when that address contains
# no ':' and does not start with '<'. Kept as a string because it is
# concatenated directly onto the host name.
DEFAULT_COLLECTOR_PORT="9619"

def parse_args():
    """Parse the command line of this script.

    Recognised options:
      -p/--pool      collector to query (defaults to the OSG collector)
      -n/--name      CE to exchange the SciToken with (required)
      -s/--scitoken  file holding the SciToken (required)
      -d/--debug     boolean flag enabling debug mode

    Returns the ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Create a token based off of a SciToken for authentication with a remote CE",
    )
    cli.add_argument(
        "-p", "--pool",
        metavar="coll",
        default="collector.opensciencegrid.org",
        help="Remote collector to query",
    )
    cli.add_argument(
        "-n", "--name",
        metavar="name",
        required=True,
        help="CE name to exchange SciToken with",
    )
    cli.add_argument(
        "-s", "--scitoken",
        metavar="token_file",
        required=True,
        help="File containing SciToken to exchange",
    )
    cli.add_argument(
        "-d", "--debug",
        action="store_true",
        help="Enable debug mode",
    )
    return cli.parse_args()


def acquire_ce_token(pool, name, token_dir, debug=False):
    """Fetch a READ token for *name* from the collector at *pool*.

    Runs ``condor_token_fetch -remote`` with a 3600-second lifetime and
    captures both of its output streams.

    Args:
        pool: Collector address passed to ``-pool``.
        name: CE name passed to ``-name``.
        token_dir: Exported to the child as ``_condor_SEC_TOKEN_DIRECTORY``
            unless that variable is already present in the environment.
        debug: When true, ``-debug`` is appended to the command line.

    Returns:
        The command's standard output converted with ``ce.to_str``, or an
        empty string when the command exits with a non-zero status.
    """
    fetch_env = dict(os.environ)
    fetch_env.setdefault('_condor_SEC_TOKEN_DIRECTORY', token_dir)
    cmd = ['condor_token_fetch', '-pool', pool, '-name', name,
        "-lifetime", "3600", '-remote', '-authz', 'READ']
    if debug:
        cmd.append('-debug')
    result = subprocess.run(cmd, stderr=subprocess.PIPE,
        stdout=subprocess.PIPE, env=fetch_env)
    if result.returncode:
        # Diagnostics belong on stderr: this script's stdout carries the
        # resulting token and must not be polluted with error text.
        print("condor_token_fetch failed with error %d:" % result.returncode,
            file=sys.stderr)
        # errors='replace' so that reporting a failure can never itself fail
        # on non-ASCII output from the child.
        print(result.stderr.decode('ascii', errors='replace'),
            file=sys.stderr, end='')
        # Never hand back the output of a failed command as if it were a token.
        return ''
    return ce.to_str(result.stdout)


def exchange_scitoken(pool, name, scitoken_fname, token_dir, debug=False):
    """Exchange the SciToken in *scitoken_fname* with the CE *name*.

    Runs ``condor_scitoken_exchange`` and captures both of its output
    streams.

    Args:
        pool: Collector address passed to ``-pool``.
        name: CE name passed to ``-name``.
        scitoken_fname: Path of the SciToken file passed to ``-scitoken``.
        token_dir: Exported to the child as ``_condor_SEC_TOKEN_DIRECTORY``
            unless that variable is already present in the environment.
        debug: When true, ``-debug`` is appended to the command line.

    Returns:
        The command's standard output converted with ``ce.to_str``, or an
        empty string when the command exits with a non-zero status.
    """
    exchange_env = dict(os.environ)
    exchange_env.setdefault('_condor_SEC_TOKEN_DIRECTORY', token_dir)
    cmd = ['condor_scitoken_exchange', '-scitoken', scitoken_fname,
        "-pool", pool, "-name", name]
    if debug:
        cmd.append("-debug")
    result = subprocess.run(cmd, stderr=subprocess.PIPE,
        stdout=subprocess.PIPE, env=exchange_env)
    if result.returncode:
        # Diagnostics belong on stderr: the caller prints this function's
        # return value on stdout as the exchanged token.
        print("condor_scitoken_exchange failed with error %d:" % result.returncode,
            file=sys.stderr)
        # errors='replace' so that reporting a failure can never itself fail
        # on non-ASCII output from the child.
        print(result.stderr.decode('ascii', errors='replace'),
            file=sys.stderr, end='')
        # Never hand back the output of a failed command as if it were a token.
        return ''
    return ce.to_str(result.stdout)


def token_is_valid(contents):
    """Return True if *contents* looks like a JWT that has not yet expired.

    The string must consist of exactly three dot-separated segments whose
    middle segment decodes (base64url, padding optional) to a JSON object
    with an integer ``exp`` claim strictly later than the current time.
    The signature is NOT verified; this is only a cheap check used to
    decide whether a cached token is worth reusing.

    Args:
        contents: Text of the candidate token (``str``).

    Returns:
        ``True`` when the token parses and is unexpired, ``False`` for any
        malformed, undecodable or expired input.
    """
    segments = contents.split(".")
    if len(segments) != 3:
        return False

    # JWT segments are base64url-encoded with the trailing '=' stripped;
    # restore the padding the decoder expects.
    payload = segments[1]
    payload += "=" * (-len(payload) % 4)

    try:
        # urlsafe_b64decode understands the '-'/'_' alphabet used by JWTs;
        # plain b64decode silently drops those characters and corrupts the
        # payload.  ValueError covers binascii.Error (bad base64),
        # non-ASCII input, UnicodeDecodeError (payload bytes are not valid
        # text) and json.JSONDecodeError (not JSON).
        payload_obj = json.loads(base64.urlsafe_b64decode(payload))
    except ValueError:
        return False
    if not isinstance(payload_obj, dict):
        return False

    expiry = payload_obj.get('exp')
    if not isinstance(expiry, int):
        return False

    return int(time.time()) < expiry


def _to_text(data):
    """Return *data* as ``str``, decoding it as ASCII if it is ``bytes``."""
    if isinstance(data, bytes):
        return data.decode('ascii')
    return data


def _write_private_file(path, text):
    """(Re)write *path* with *text*; a newly created file is owner-only (0600)."""
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as fp:
        fp.write(text)


def main():
    """Obtain a collector token if needed, then exchange the SciToken.

    Steps:
      1. Pick the token directory: ``SEC_TOKEN_DIRECTORY`` from the HTCondor
         configuration, else ``SEC_TOKEN_SYSTEM_DIRECTORY`` when running as
         root, else ``~/.condor-ce/tokens.d``; create it if missing.
      2. Read the cached token stored under the CE's name in that directory.
         If it is missing, malformed or expired, fetch a new one with
         ``acquire_ce_token`` and store it.
      3. Run ``exchange_scitoken`` and print its output on stdout.

    Exits with status 2 when no token could be acquired from the collector
    and with status 3 when the SciToken exchange produced no output.
    """
    args = parse_args()

    token_dir = htcondor.param.get('SEC_TOKEN_DIRECTORY')
    if not token_dir:
        if os.geteuid() == 0:
            token_dir = htcondor.param['SEC_TOKEN_SYSTEM_DIRECTORY']
        else:
            token_dir = os.path.expanduser("~/.condor-ce/tokens.d")
    # The directory holds bearer credentials, so a newly created one is
    # restricted to its owner.  An existing directory is left untouched.
    os.makedirs(token_dir, mode=0o700, exist_ok=True)

    if args.debug:
        print("Will acquire a token from pool=%s name=%s from scitoken in %s" %
            (args.pool, args.name, args.scitoken),
            file=sys.stderr)

    exchange_token_name = os.path.join(token_dir, args.name)
    try:
        with open(exchange_token_name, "r") as fp:
            contents = fp.read()
    except FileNotFoundError:
        # Create the (empty) token file up front, as before, so an unwritable
        # directory is detected before contacting the collector.
        _write_private_file(exchange_token_name, '')
        contents = ''

    pool = args.pool
    # Append the default collector port unless the address already carries
    # one or starts with '<'.
    if not pool.startswith("<") and ":" not in pool:
        pool += ":" + DEFAULT_COLLECTOR_PORT

    if not token_is_valid(contents):
        if args.debug:
            print("No pre-existing token found -- will acquire a new one from %s.\n" % pool, file=sys.stderr)
        # The helper's result has already been through ce.to_str; calling
        # .decode() on it unconditionally fails if that is a str, so accept
        # either bytes or str here.
        contents = _to_text(acquire_ce_token(pool, args.name, token_dir, debug=args.debug))
        if not contents:
            print("Failed to acquire a token from collector %s to authenticate with CE %s" % (pool, args.name), file=sys.stderr)
            sys.exit(2)
        _write_private_file(exchange_token_name, contents + "\n")

    result = _to_text(exchange_scitoken(pool, args.name, args.scitoken, token_dir, debug=args.debug))
    if result:
        print(result, end='')
    else:
        print("Failed to exchange SciToken in %s with CE at %s" % (args.scitoken, args.name), file=sys.stderr)
        sys.exit(3)

# Run the token exchange only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
