#!/usr/bin/python3

import argparse
import os
import socket
import subprocess
import sys

import htcondor2 as htcondor

# Port appended by main() to the -pool value when that value contains no ":"
# and does not start with "<".  Kept as a string because it is concatenated
# directly onto the pool address.
DEFAULT_COLLECTOR_PORT="9619"

def parse_args():
    """Parse the command-line options of the registration tool.

    Options are spelled with a single leading dash (``-pool``, ``-name``,
    ``-debug``).

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv`` with attributes
        ``pool`` (defaults to ``collector.opensciencegrid.org``), ``name``
        (defaults to this host's fully-qualified domain name) and ``debug``
        (``False`` unless ``-debug`` is given).
    """
    # Declarative table of (flag, add_argument keyword arguments); the order
    # here is the order the options are registered with the parser.
    option_table = (
        ("-pool", {
            "help": "Remote collector to query",
            "metavar": "coll",
            "default": "collector.opensciencegrid.org",
        }),
        ("-name", {
            "help": "CE name to register as",
            "metavar": "name",
            # Register under the local host's FQDN unless told otherwise.
            "default": socket.getfqdn(),
        }),
        ("-debug", {
            "action": "store_true",
            "help": "Enable debug mode",
        }),
    )

    cli = argparse.ArgumentParser(description="Register a HTCondor-CE with a collector")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def request_token(pool, name, debug=False):
    """Request an ``ADVERTISE_SCHEDD`` token from a collector and store it.

    A token request for the identity ``<name>@users.htcondor.org`` is
    submitted to the collector located through ``pool``.  The user is told
    where to approve it, and the function then waits for the result.  If
    waiting raises ``RuntimeError`` a fresh request is submitted and waited
    on, for up to 10 attempts in total.

    Args:
        pool: Collector address handed to ``htcondor.Collector``.  The part
            before the first ":" is used as the host in the approval URL and
            in the name of the written token file.
        name: CE name; used as the user part of the requested identity.
        debug: Accepted for interface compatibility; not used by this
            function at present.

    Returns:
        ``True`` if a token was obtained and written, ``False`` if no token
        was obtained after all attempts.
    """
    max_attempts = 10
    # Passed to TokenRequest.result(); presumably seconds -- confirm against
    # the bindings' documentation.
    result_timeout = 600

    coll = htcondor.Collector(pool)
    coll_ad = coll.locate(htcondor.DaemonTypes.Collector)

    identity = '%s@users.htcondor.org' % name

    # Must be set before the token is written; presumably token.write()
    # places the file in this directory -- confirm against the bindings.
    htcondor.param['SEC_TOKEN_DIRECTORY'] = '/etc/condor-ce/tokens.d'

    def submit_request():
        # Every attempt uses a brand-new request object.
        request = htcondor.TokenRequest(identity, bounding_set=['ADVERTISE_SCHEDD'])
        request.submit(coll_ad)
        return request

    req = submit_request()

    pool_without_port = pool.split(":")[0]
    print("Request is queued with collector %s; please approve it by visiting"
          " https://%s/registry/code" % (pool, pool_without_port))

    token = None
    for attempt in range(1, max_attempts + 1):
        print("Request ID is %s" % req.request_id)
        try:
            token = req.result(result_timeout)
        except RuntimeError as exc:
            if attempt == max_attempts:
                # Out of attempts: do not leave another request pending on
                # the collector that nobody will ever wait for.
                print("Failure occurred: %s" % exc)
            else:
                print("Failure occurred: %s; will retry with a new request" % exc)
                req = submit_request()
        else:
            # Only stop looping once result() returned without raising.
            break

    if token is None:
        print("Token request failed after 10 retries.")
        return False

    token.write("50-%s-registration" % pool_without_port)
    print("CE registration is complete!")
    return True


def reconfig():
    """Run the ``condor_ce_reconfig`` command.

    The command's exit status is not inspected and nothing is returned.
    """
    command = ["condor_ce_reconfig"]
    subprocess.call(command)


def main():
    """Entry point: register this CE with a collector and reconfigure it.

    Parses the command line, refuses to run unless the effective user is
    root, appends the default collector port to the pool address when no
    port was given, requests the token and, on success, triggers a CE
    reconfig.

    Exits with status 1 if not run as root or if the token could not be
    obtained, so callers and scripts can detect a failed registration.
    """
    args = parse_args()

    if os.geteuid() != 0:
        print("This command must be run as root")
        sys.exit(1)

    pool = args.pool
    # Addresses starting with "<" are left untouched; anything else without
    # a ":" gets the default port appended.
    if not pool.startswith("<") and ":" not in pool:
        pool += ":" + DEFAULT_COLLECTOR_PORT

    if not request_token(pool, args.name, debug=args.debug):
        # request_token() has already printed the reason for the failure.
        sys.exit(1)

    print("Successfully acquired token; sending a reconfig to the CE.")
    reconfig()

# Run the registration only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
