#!/usr/bin/env python

from __future__ import division

import argparse
import gzip
import logging, logging.config
import os
import pkg_resources
import re
import sys

import Bio.SeqIO
from bdi.utils.aminoacids import seq_md5
import eutils.client

from uta.formats.exonset import ExonSet,ExonSetWriter
from uta.formats.geneinfo import GeneInfo,GeneInfoWriter
from uta.formats.txinfo import TxInfo,TxInfoWriter

transcript_origin='NCBI RefSeq'

def parse_args(argv):
    """Parse the command-line arguments in *argv*.

    Recognized arguments:

    * ``GENES`` -- zero or more positional values.
    * ``--prefix`` / ``-p`` -- a string, defaulting to ``'ncbi'``.

    :param argv: list of argument strings (without the program name)
    :returns: the ``argparse.Namespace`` with ``GENES`` and ``prefix`` set
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('GENES', nargs='*')
    parser.add_argument('--prefix', '-p', default='ncbi')
    return parser.parse_args(argv)


def write_exonsets_for_gene(e_gene,esw,hgnc):
    """Write one ExonSet per unique (reference, NM_ transcript) pair of a gene.

    Products whose accession does not start with ``NM_`` are skipped, as
    are repeated (reference accession, product accession) pairs; both
    cases are logged.

    :param e_gene: gene record exposing ``references``; each reference has
        ``acv`` and ``products``, and each product has ``acv`` and
        ``genomic_coords`` (with ``strand`` and ``intervals``).
        ``e_gene.hgnc`` is read for the skip log message.
    :param esw: writer whose ``write()`` receives each ExonSet
    :param hgnc: gene symbol, used only in the duplicate-pair warning
    :returns: set of product accessions for which an exon set was written
    """
    # Resolve the logger here instead of relying on a module-level `logger`
    # name, which is only bound when this file is executed as a script.
    log = logging.getLogger(__name__)

    rp_combos_seen = set()
    tx_processed = set()
    for e_ref in e_gene.references:
        for e_prd in e_ref.products:
            if not e_prd.acv.startswith('NM_'):
                log.info("skipping non-NM transcript {e_prd.acv} for {e_gene.hgnc} on {e_ref.acv}".format(
                    e_gene=e_gene,e_ref=e_ref,e_prd=e_prd))
                continue

            # e.g., ABCC6, (AC_000148.1,NM_001171.5) is duplicated (with
            # very different genomic coordinates). It is unclear what the
            # duplication means, so only the first occurrence is written.
            rp = (e_ref.acv,e_prd.acv)
            if rp in rp_combos_seen:
                log.warning("found duplicate ref,prd pair {rp} in {hgnc}; skipping".format(
                    rp=rp, hgnc=hgnc))
                continue

            # Serialize intervals as "a,b;c,d;...": values within an
            # interval are comma-joined, intervals are semicolon-joined.
            exons_se_i = ";".join(
                ",".join(str(c) for c in ex)
                for ex in e_prd.genomic_coords.intervals)
            es = ExonSet(tx_ac=e_prd.acv,
                         alt_ac=e_ref.acv,
                         method='splign',
                         strand=e_prd.genomic_coords.strand,
                         exons_se_i=exons_se_i,
                )
            esw.write(es)
            tx_processed.add(e_prd.acv)
            rp_combos_seen.add(rp)

    return tx_processed


if __name__ == '__main__':
    # Configure logging from the packaged config file, then force the root
    # logger to INFO.
    logging_conf_fn = pkg_resources.resource_filename('uta', 'etc/logging.conf')
    logging.config.fileConfig(logging_conf_fn)
    logging.getLogger().setLevel(logging.INFO)
    # Module-level logger for this script.
    logger = logging.getLogger(__name__)

    # Parse arguments first so usage errors exit before any other setup.
    opts = parse_args(sys.argv[1:])
    ec = eutils.client.Client()

    es_fn = opts.prefix + '.exonset.gz'
    fa_fn = opts.prefix + '.fasta.gz'
    gi_fn = opts.prefix + '.geneinfo.gz'
    ti_fn = opts.prefix + '.txinfo.gz'

    if opts.GENES:
        logger.info('reading genes from command line')
        hgncs = opts.GENES
    else:
        logger.info('reading genes from stdin')
        # Drop blank lines so they are not looked up as gene symbols.
        hgncs = [line.strip() for line in sys.stdin if line.strip()]

    # Pass 1: fetch each gene, write its gene info, and (for protein-coding
    # genes) its exon sets; collect the transcript accessions that were written.
    # NOTE(review): closing the gzip handles here assumes the writers do not
    # buffer rows beyond the handle they were given -- confirm.
    tx_set = set()
    n = len(hgncs)
    with gzip.open(gi_fn+'.tmp','w') as gi_fh, gzip.open(es_fn+'.tmp','w') as es_fh:
        giw = GeneInfoWriter(gi_fh)
        esw = ExonSetWriter(es_fh)
        for i,hgnc in enumerate(hgncs):
            # Report progress every 25 genes and on the last one.
            if i % 25 == 0 or i+1 == n:
                logger.info("\n{i}/{n} ({p:.1f}%): {hgnc}...".format(
                    i=i+1, n=n, p=(i+1)/n*100, hgnc=hgnc))

            try:
                e_gene = ec.fetch_gene_by_hgnc(hgnc)
            except Exception as e:
                # Best effort: log the failure and move on to the next gene.
                logger.exception(e)
                continue

            gi = GeneInfo(hgnc=hgnc,
                          maploc=e_gene.maploc,
                          aliases=','.join(e_gene.synonyms),
                          type=e_gene.type,
                          summary=e_gene.summary,
                          descr=e_gene.description)
            giw.write(gi)

            if e_gene.type != 'protein-coding':
                logger.info("Skipping {e_gene.hgnc} (not protein coding)".format(e_gene=e_gene))
                continue

            try:
                tx_set.update(write_exonsets_for_gene(e_gene,esw,hgnc))
            except Exception as e:
                logger.exception(e)

    # Pass 2: fetch every collected transcript and write its transcript info
    # and FASTA record.
    # NOTE(review): mode 'w' is binary for gzip; writing str to `faw` assumes
    # Python 2 -- confirm before porting.
    n = len(tx_set)
    with gzip.open(fa_fn+'.tmp','w') as faw, gzip.open(ti_fn+'.tmp','w') as ti_fh:
        tiw = TxInfoWriter(ti_fh)
        for i,ac in enumerate(tx_set):
            logger.info("="*70+"\n{i}/{n} ({p:.1f}%): {ac}...".format(
                i=i+1, n=n, p=(i+1)/n*100, ac=ac))
            try:
                e_tx = ec.fetch_nuccore_by_ac(ac)
                ti = TxInfo(ac=ac,
                            origin=transcript_origin,
                            hgnc=e_tx.gene,
                            cds_se_i=",".join(str(c) for c in e_tx.cds_se_i),
                            exons_se_i=";".join(
                                ",".join(str(c) for c in ex)
                                for ex in e_tx.exons_se_i),
                    )
                tiw.write(ti)
                faw.write(">"+e_tx.acv+"\n"+e_tx.seq+"\n")
            except Exception as e:
                # Best effort: log the failure and continue with the rest.
                logger.exception(e)

    # All handles are closed at this point, so the temporary files are
    # complete; move them to their final names.
    for fn in [gi_fn,ti_fn,es_fn,fa_fn]:
        os.rename(fn+'.tmp',fn)
