#!/usr/bin/env python
#
# compares cigar strings for uta0, uta1 databases:
# (1) queries uta0 for {tx_ac:cigars}
# (2) queries uta1 for {tx_ac:[alt_ac,cigars]} where tx_ac is in the uta0 set and alt_aln_method='splign'
# (3) for each tx_ac:
# (3a) get uta0 cigar
# (3b) find the uta1 cigar for tx_ac whose alt_ac is a GRCh37 chromosome accession.
# (3c) translate the uta1 cigar to uta0 format (map D->I, I->D, =->M)
# (3d) compare
# (4) write all results + mismatches to /tmp
#
from __future__ import with_statement
import collections
import csv
import string

import psycopg2, psycopg2.extras

# just need the keys
NC_to_chr_dict = {
    'NC_000001.10':  '1', 'NC_000002.11':  '2', 'NC_000003.11': '3',
    'NC_000004.11':  '4', 'NC_000005.9' :  '5', 'NC_000006.11': '6',
    'NC_000007.13':  '7', 'NC_000008.10':  '8', 'NC_000009.11': '9',
    'NC_000010.10': '10', 'NC_000011.9' : '11', 'NC_000012.11': '12',
    'NC_000013.10': '13', 'NC_000014.8' : '14', 'NC_000015.9' : '15',
    'NC_000016.9' : '16', 'NC_000017.10': '17', 'NC_000018.9' : '18',
    'NC_000019.9' : '19', 'NC_000020.10': '20', 'NC_000021.8' : '21',
    'NC_000022.10': '22', 'NC_000023.10':  'X', 'NC_000024.9' : 'Y',
    }

GRChr37NCs = NC_to_chr_dict.keys()

translation = string.maketrans('DI=', 'IDM')

def get_uta0_cigars():
    dsn="host=uta.invitae.com dbname=uta user=uta_public password=uta_public"
    conn = psycopg2.connect(dsn)
    sel_cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)

    sel_sql = "SELECT ac, cigars from transcript_cigars_v;"
    sel_cur.execute(sel_sql)
    uta0_cigars = {}
    for row in sel_cur.fetchall():
        uta0_cigars[row['ac']] = row['cigars']
    return uta0_cigars


def  get_uta1_cigars(txs):
    dsn="host=uta.invitae.com dbname=uta_dev user=uta_public password=uta_public"
    conn = psycopg2.connect(dsn)
    sel_cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    txs_string = str(txs)[1:-1]
    sel_sql = "SELECT tx_ac, alt_ac, cigars from tx_aln_cigar_dv WHERE alt_aln_method='splign' AND tx_ac in ({})".\
        format(txs_string)
    sel_cur.execute(sel_sql)
    uta1_cigars = collections.defaultdict(list)
    for row in sel_cur.fetchall():
        uta1_cigars[row['tx_ac']].append({'alt_ac': row['alt_ac'], 'cigars': row['cigars']})
    return uta1_cigars

def uta1_cigar_as_uta0(uta1_cigar):
    return uta1_cigar.translate(translation)


def main():

    uta0_cigars = get_uta0_cigars()
    uta1_cigars = get_uta1_cigars(uta0_cigars.keys())

    mismatches = []
    with open('/tmp/uta0_ut1_compare.tsv', 'w') as f:
        header = ['tx_ac', 'uta0_cigar', 'uta1_cigar']
        writer = csv.DictWriter(f, header, delimiter='\t', lineterminator='\n')
        writer.writeheader()
        for tx_ac in uta0_cigars:
            uta0_cigar = uta0_cigars[tx_ac]
            uta1_cigar_candidates = uta1_cigars[tx_ac]

            for candidate in uta1_cigar_candidates:
                uta1_cigar = 'MISSING'
                if candidate['alt_ac'] in GRChr37NCs:
                    uta1_cigar = candidate['cigars']
                    break

            row = {'tx_ac': tx_ac, 'uta0_cigar': uta0_cigar, 'uta1_cigar': uta1_cigar}
            writer.writerow(row)

            # comparison
            if uta0_cigar != uta1_cigar_as_uta0(uta1_cigar):
                mismatches.append(row)

    with open('/tmp/uta0_ut1_mismatches.tsv', 'w') as f:
        writer = csv.DictWriter(f, header, delimiter='\t', lineterminator='\n')
        writer.writeheader()
        writer.writerows(mismatches)






# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

