#!/usr/bin/python3

import os
import sys
import json
import tempfile
import time
import optparse
import collections
import multiprocessing

# Default CONDOR_CONFIG to the condor-ce configuration file; a value already
# present in the environment is left untouched.
os.environ.setdefault("CONDOR_CONFIG", "/etc/condor-ce/condor_config")

# The htcondor2 and classad2 modules are imported inside the functions that
# need them, so that HTCondor code is not loaded in the parent process before
# it forks. Otherwise, issues arise in the child processes.

from htcondorce import rrd, web_utils

def parse_opts():
    """Parse the command line and return the resulting options object.

    The only recognized option is -s/--spool, stored as the ``spool``
    attribute (None when not given).  Positional arguments are discarded.
    """
    opt_parser = optparse.OptionParser()
    opt_parser.add_option("-s", "--spool", dest="spool", help="Spool directory")

    options, _positional = opt_parser.parse_args()
    return options

def secure_json_write(fname, contents):
    """Serialize *contents* as JSON into *fname* by way of a temporary file.

    The JSON is first written to a temporary file created in the same
    directory as *fname*; once that file has been completely written and
    closed it is renamed over *fname*, so *fname* never refers to partially
    written data.

    This is best-effort: a failure to serialize, write or rename is reported
    on stderr and swallowed rather than raised, and the temporary file is
    removed.  Errors while creating the temporary file itself still propagate.
    """
    temp_fd, temp_fname = tempfile.mkstemp(dir=os.path.dirname(fname))
    try:
        # Close (and therefore flush) the temporary file *before* the rename,
        # otherwise fname would briefly point at an empty or truncated file.
        with os.fdopen(temp_fd, 'w') as temp_file:
            json.dump(contents, temp_file)
        os.rename(temp_fname, fname)
    except Exception as exc:
        print(f"Error writing json file {fname}: {exc}", file=sys.stderr)
        # Do not leave an orphaned temporary file behind next to fname.
        try:
            os.unlink(temp_fname)
        except OSError:
            pass


def _newstat():
    """Return a fresh counter dict with every job statistic set to zero."""
    return dict.fromkeys(("Running", "Idle", "Held", "Jobs"), 0)

def process_one_schedd(ad):
    """Count the jobs of one schedd, grouped several ways.

    ``ad`` is passed to ``classad.ClassAd`` and the resulting ad is used to
    build the ``htcondor.Schedd`` that is queried for jobs matching
    ``RoutedJob is UNDEFINED``.  Every job adds to the "Jobs" counter, and to
    "Idle", "Running" or "Held" when its JobStatus is 1, 2 or 5.  Jobs with a
    truthy RequestGPUs are additionally counted in the GPU tallies.

    A RuntimeError raised while querying is reported on stderr; whatever has
    been counted up to that point is still returned.

    Returns a six-element list:
        [totals, per-VO totals, per-(DN, VO, VOMS) totals,
         GPU totals, GPU per-(DN, VO, VOMS) totals, str(ad)]
    """
    # Imported here rather than at module level; see the note near the top
    # of this file.
    import htcondor2 as htcondor
    import classad2 as classad

    ad = classad.ClassAd(ad)
    schedd = htcondor.Schedd(ad)

    status_names = {1: "Idle", 2: "Running", 5: "Held"}
    wanted_attrs = [
        "JobStatus",
        'x509UserProxyVOName',
        'x509UserProxyFirstFQAN',
        'x509userproxysubject',
        'RequestGPUs',
    ]

    totals = _newstat()
    gpu_totals = _newstat()
    per_vo = collections.defaultdict(_newstat)
    per_key = collections.defaultdict(_newstat)
    gpu_per_key = collections.defaultdict(_newstat)

    print(f"Processing CE {ad.get('Name', 'Unknown')}.", file=sys.stderr)
    try:
        for job in schedd.query("RoutedJob is UNDEFINED", wanted_attrs):
            dn = job.get("x509userproxysubject") or 'Unknown'
            vo = job.get('x509UserProxyVOName') or 'Unknown'
            # Drop the NULL capability/role suffixes from the FQAN.
            fqan = job.get('x509UserProxyFirstFQAN', '')
            fqan = fqan.replace("/Capability=NULL", "").replace("/Role=NULL", "")
            voms = fqan or 'Unknown'
            gpus = job.get('RequestGPUs')

            job_key = (dn, vo, voms)
            vo_stats = per_vo[vo]
            key_stats = per_key[job_key]

            # Every job counts towards "Jobs", whatever its status.
            for counters in (vo_stats, totals, key_stats):
                counters["Jobs"] += 1

            # Statuses other than Idle/Running/Held only count as "Jobs".
            status = status_names.get(job.get("JobStatus"))
            if status:
                for counters in (vo_stats, key_stats, totals):
                    counters[status] += 1

            if not gpus:
                continue

            gpu_key_stats = gpu_per_key[job_key]
            gpu_key_stats["Jobs"] += 1
            gpu_totals['Jobs'] += 1
            if status:
                gpu_key_stats[status] += 1
                gpu_totals[status] += 1

    except RuntimeError as e:
        print(f"Failure to query CE {ad.get('Name', 'Unknown')}: {e}", file=sys.stderr)

    return [totals, per_vo, per_key, gpu_totals, gpu_per_key, str(ad)]


def get_ads(spooldir):
    """Return ``[ads, environ]`` for the schedds to be processed.

    ``ads`` holds the string form of every ad returned by
    ``web_utils.get_schedd_ads({})``.  ``environ`` is a dict carrying the
    spool directory under the key 'htcondorce.spool'; when ``spooldir`` is
    empty or None, ``web_utils.get_spooldir()`` supplies the value.
    """
    spool = spooldir or web_utils.get_spooldir()

    schedd_ads = []
    for schedd_ad in web_utils.get_schedd_ads({}):
        schedd_ads.append(str(schedd_ad))

    return [schedd_ads, {'htcondorce.spool': spool}]


def dict_sum1(src, dst):
    """ Accumulate src into dst in place: dst[key] += src[key] for each key.

        dst must already hold every key of src, or supply a default for
        missing keys (e.g. a defaultdict). """
    for name in src:
        dst[name] += src[name]


def dict_sum2(src, dst):
    """ Two-level variant of dict_sum1: for every outer key of src, add the
        inner dict src[key] into dst[key] in place.

        dst must already hold every outer key of src, or supply a default
        for missing ones; the same applies to each inner dict. """
    for outer_key in src:
        dict_sum1(src[outer_key], dst[outer_key])


def list_job_results_final(job_results):
    """Turn the keyed job statistics into a flat list of dicts.

    Each key of ``job_results`` is indexed as (DN, VO, VOMS); those three
    values are stored into the corresponding statistics dict under the keys
    'DN', 'VO' and 'VOMS'.  The dicts are modified in place and the very same
    objects are returned, in iteration order.
    """
    labels = ('DN', 'VO', 'VOMS')
    flattened = []
    for identity, stats in job_results.items():
        for position, label in enumerate(labels):
            stats[label] = identity[position]
        flattened.append(stats)
    return flattened


def main():
    """Collect per-schedd job statistics and publish them.

    Steps, in order:
      1. Exit with status 0 if CEVIEW is not in DAEMON_LIST.
      2. In a worker pool, fetch the schedd ads and then run
         process_one_schedd() on each of them.
      3. For every schedd, update its "jobs" and per-VO "vos" RRD files and
         add its counters to the overall totals.
      4. Write the "pilots", "totals" and "vos.json" JSON files into the
         spool, update the overall RRD files, and print the
         GPUPilotsTotal* attributes followed by "- update:true" on stdout.

    Progress details are printed to stderr only when stderr is a terminal.
    """
    opts = parse_opts()

    # NOTE(review): this loads htcondor2 in the parent before the pool is
    # created, which looks at odds with the module-level note about keeping
    # HTCondor imports out of the parent -- confirm this is intentional.
    htcondor = __import__("htcondor2")
    # Check if the CEVIEW daemon is in the daemon list, exit if it is not
    if htcondor.param["DAEMON_LIST"].find("CEVIEW") == -1:
        print("CEVIEW is not in DAEMON_LIST, exiting...", file=sys.stderr)
        sys.exit(0)

    pool = multiprocessing.Pool(10)
    try:
        # get_ads is run through the pool as well, so it executes in a worker
        # process rather than in this one.
        ads, environ = pool.map(get_ads, [opts.spool])[0]

        # chunksize=1: each schedd is dispatched to a worker individually.
        map_results = pool.map(process_one_schedd, ads, 1)
    finally:
        # Release the worker processes explicitly instead of leaving them to
        # interpreter shutdown.
        pool.close()
        pool.join()

    total_results = _newstat()
    vo_results = collections.defaultdict(_newstat)
    job_results = collections.defaultdict(_newstat)
    gpu_job_results = collections.defaultdict(_newstat)
    gpu_total_results = _newstat()

    import classad2 as classad

    for (schedd_total_results, schedd_vo_results, schedd_job_results,
         schedd_gpu_total_results, schedd_gpu_job_results, ad_str) in map_results:

        # Fold this schedd's counters into the overall ones.
        dict_sum1(schedd_total_results,     total_results)
        dict_sum1(schedd_gpu_total_results, gpu_total_results)

        dict_sum2(schedd_vo_results,        vo_results)
        dict_sum2(schedd_job_results,       job_results)
        dict_sum2(schedd_gpu_job_results,   gpu_job_results)

        # The ad comes back from the worker as a string; rebuild it to read
        # the schedd name.
        ad = classad.ClassAd(ad_str)

        total_fname = rrd.check_rrd(environ, ad['Name'], "jobs")
        if os.isatty(2):
            print(f"Schedd {ad['Name']} totals: {schedd_total_results}", file=sys.stderr)
        rrd.update(total_fname, "N:%d:%d:%d" % (schedd_total_results["Running"], schedd_total_results["Idle"],
                                                    schedd_total_results["Held"]))

        for vo, info in schedd_vo_results.items():
            vo_fname = rrd.check_rrd(environ, ad['Name'], "vos", vo)
            if os.isatty(2):
                print(f"Schedd {ad['Name']} VO {vo} totals: {info}", file=sys.stderr)
            rrd.update(vo_fname, "N:%d:%d:%d:%d" % (info['Running'], info['Idle'], info['Held'], info['Jobs']))

    job_results_final = list_job_results_final(job_results)
    pilots_fname = rrd.path_with_spool(environ, "pilots")
    secure_json_write(pilots_fname, job_results_final)

    # someday we might want to use these gpu_pilots totals, but not yet
    #gpu_job_results_final = list_job_results_final(gpu_job_results)
    #gpu_pilots_fname = rrd.path_with_spool(environ, "gpu_pilots")
    #secure_json_write(gpu_pilots_fname, gpu_job_results_final)

    # update schedd ad with GPU totals (Jobs, Idle, Running, Held)
    for key, val in gpu_total_results.items():
        print(f"GPUPilotsTotal{key}={val}")
    # send update to collector right away, too
    print("- update:true")

    total_results['UpdateTime'] = time.time()
    totals_fname = rrd.path_with_spool(environ, "totals")
    secure_json_write(totals_fname, total_results)

    # Passing None as the schedd name selects the overall (all-schedd) RRD.
    # NOTE(review): inferred from usage here -- confirm against rrd.check_rrd.
    total_fname = rrd.check_rrd(environ, None, "jobs")
    if os.isatty(2):
        print(f"Totals: {total_results}", file=sys.stderr)
    rrd.update(total_fname, "N:%d:%d:%d" % (total_results["Running"], total_results["Idle"], total_results["Held"]))


    vos_fname = rrd.path_with_spool(environ, "vos.json")
    secure_json_write(vos_fname, vo_results)
    for vo, info in vo_results.items():
        vo_fname = rrd.check_rrd(environ, None, "vos", vo)
        if os.isatty(2):
            print(f"VO {vo} totals: {info}", file=sys.stderr)
        rrd.update(vo_fname, "N:%d:%d:%d:%d" % (info['Running'], info['Idle'], info['Held'], info['Jobs']))


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
