#!/usr/bin/env python
#
# Convert LRG data into UTA txinfo and exonset files
#
# LRG provides the following downloads:
# zip file of all XML files - includes TX coords in TX space -> use for tx_info
# zip file of all fasta files
# tsv file of summary data w/ genomic coords -> use for exon_set
#
# http://www.lrg-sequence.org
#
from __future__ import with_statement
import argparse
import operator
import csv
import operator
import os
import xml.etree.ElementTree

import uta.formats.exonset
import uta.formats.txinfo
import uta.luts

def _genomic_exons_string(exons_coords, strand):
    """Convert an LRG ``EXONS_COORDS`` value into an exon-set coordinate string.

    ``exons_coords`` is a comma-separated list of ``start-end`` pairs.  Each
    pair is converted to interbase form (1 is subtracted from the start) and
    rendered as ``start,end``; pairs are joined with ``;``.  The exon order is
    reversed when ``strand`` is ``'-1'``.

    Raises ValueError if a pair is not exactly ``start-end`` with integer
    fields.
    """
    pairs = [pair.split('-') for pair in exons_coords.split(',')]
    intervals = ['{},{}'.format(int(start) - 1, int(end)) for (start, end) in pairs]
    if strand == '-1':
        intervals.reverse()
    return ';'.join(intervals)


def _exon_coords(exon, coord_system):
    """Return the attribute dict of ``exon``'s <coordinates> child for ``coord_system``."""
    return exon.find("./coordinates[@coord_system='{}']".format(coord_system)).attrib


def _locate_cds(exons, cds_start_lrg, cds_end_lrg, lrg_root, tx_ac):
    """Map CDS bounds from LRG coordinates to interbase transcript coordinates.

    Each exon carries coordinates in both the LRG system (``lrg_root``) and the
    transcript system (``tx_ac``); the exon containing a CDS bound supplies the
    offset between the two.  Returns ``(cds_start_i, cds_end_i)``; an element is
    ``None`` when no exon contains the corresponding bound.
    """
    cds_start_i = cds_end_i = None
    for exon in exons:
        lrg_coords = _exon_coords(exon, lrg_root)
        exon_lrg_start = int(lrg_coords['start'])
        exon_lrg_end = int(lrg_coords['end'])
        exon_tx_start = int(_exon_coords(exon, tx_ac)['start'])

        if exon_lrg_start <= cds_start_lrg <= exon_lrg_end:
            # interbase start: subtract 1
            cds_start_i = cds_start_lrg - exon_lrg_start + exon_tx_start - 1

        if exon_lrg_start <= cds_end_lrg <= exon_lrg_end:
            cds_end_i = cds_end_lrg - exon_lrg_start + exon_tx_start

        if cds_start_i is not None and cds_end_i is not None:
            break
    return cds_start_i, cds_end_i


def _build_txinfo(xmldir, tx_ac, hgnc):
    """Build the TxInfo record for transcript ``tx_ac`` from its LRG XML file.

    Returns a ``uta.formats.txinfo.TxInfo`` on success.  On any problem (XML
    file missing, transcript not in the file, no coding region, CDS bounds not
    inside any exon) a message is printed and ``None`` is returned so the
    caller can skip the transcript.
    """
    # TX are LRG_7t1; XML file name is LRG_7.xml
    t_pos = tx_ac.find('t')
    lrg_root = tx_ac[:t_pos]
    tx_name = tx_ac[t_pos:]

    xml_filepath = os.path.join(xmldir, "{}.xml".format(lrg_root))
    if not os.path.exists(xml_filepath):
        print("WARNING: xml file {} not found".format(xml_filepath))
        return None

    root = xml.etree.ElementTree.parse(xml_filepath).getroot()
    tx_element = next((el for el in root.findall("./fixed_annotation/transcript")
                       if el.attrib['name'] == tx_name), None)
    # NB: compare with None explicitly; a childless Element is falsy.
    if tx_element is None:
        print("WARNING: tx {} not found for accession {}".format(tx_name, tx_ac))
        return None

    exons = tx_element.findall('./exon')

    # setup CDS coords - convert from LRG coords to interbase
    cds_tag = tx_element.find('./coding_region/coordinates')
    if cds_tag is None:
        # find() returns None when the element is absent; report and skip
        # rather than dying with AttributeError part-way through the output.
        print("WARNING: no coding_region found for accession {}; skipping".format(tx_ac))
        return None

    cds_start_i, cds_end_i = _locate_cds(exons,
                                         int(cds_tag.attrib['start']),
                                         int(cds_tag.attrib['end']),
                                         lrg_root, tx_ac)
    if cds_start_i is None or cds_end_i is None:  # should never happen
        print("ERROR: failed to find cds start and/or end for accession {}".format(tx_ac))
        return None
    tx_cds_se_i = '{},{}'.format(cds_start_i, cds_end_i)

    # setup exon coords - convert to interbase (subtract 1 from start)
    exon_coords = []
    for exon in exons:
        tx_coords = _exon_coords(exon, tx_ac)
        exon_coords.append((int(tx_coords['start']) - 1, int(tx_coords['end'])))
    exon_coords.sort(key=operator.itemgetter(0))
    tx_exon_se_i = ";".join("{},{}".format(start, end) for (start, end) in exon_coords)

    return uta.formats.txinfo.TxInfo('LRG', tx_ac, hgnc, tx_cds_se_i, tx_exon_se_i)


def main():
    """Convert LRG downloads into ``lrg.exonset`` and ``lrg.txinfo`` files.

    Command-line options:
      -t  tab-delimited LRG summary file (source of genomic exon coordinates)
      -x  directory of LRG XML files (source of transcript-space coordinates)
      -o  output directory; created if it does not exist

    Transcripts whose XML data cannot be used are reported on stdout and
    omitted from ``lrg.txinfo``.
    """
    parser = argparse.ArgumentParser(description='convert LRG files to format to import to UTA')
    # All three options are needed; let argparse report a missing one instead
    # of failing later on a None path.
    parser.add_argument('-t', action="store", dest="tabfile", required=True,
                        help="input tab-delimited text file from LRG")
    parser.add_argument('-x', action="store", dest="xmldir", required=True,
                        help="input XML dir from LRG")
    parser.add_argument('-o', action="store", dest="outdir", required=True,
                        help="directory to write output tx and exonset files")

    args_in = parser.parse_args()

    if not os.path.exists(args_in.outdir):
        os.makedirs(args_in.outdir)

    tx_acs_hgnc = {}    # transcript accession -> HGNC symbol; saved for the TX file
    exonset_lines = set()    # a set, so repeated rows collapse

    # read the summary file
    with open(args_in.tabfile, 'r') as f:
        f.readline()   # discard the 1st line - timestamp
        for row in csv.DictReader(f, delimiter='\t'):
            tx_ac = row['# LRG_TRANSCRIPT']
            exonset_lines.add((tx_ac,
                               uta.luts.chr_to_NC[row['CHROMOSOME']],
                               'LRG',
                               row['STRAND'],
                               _genomic_exons_string(row['EXONS_COORDS'], row['STRAND'])))
            tx_acs_hgnc[tx_ac] = row['HGNC_SYMBOL']

    # write exon set file
    with open(os.path.join(args_in.outdir, 'lrg.exonset'), 'w') as f:
        esw = uta.formats.exonset.ExonSetWriter(f)
        # Sort on the whole tuple: accession first, remaining fields break ties
        # so the output order does not depend on set iteration order.
        for exonset_line in sorted(exonset_lines):
            esw.write(uta.formats.exonset.ExonSet(*exonset_line))

    # write tx info file
    with open(os.path.join(args_in.outdir, 'lrg.txinfo'), 'w') as f_out:
        tw = uta.formats.txinfo.TxInfoWriter(f_out)
        for tx_ac in sorted(tx_acs_hgnc):
            tx_row = _build_txinfo(args_in.xmldir, tx_ac, tx_acs_hgnc[tx_ac])
            if tx_row is not None:
                tw.write(tx_row)

# Run the conversion only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

