#!/usr/bin/env python

from __future__ import division, print_function, unicode_literals

# Command-line loader: for every HGNC gene name given on the command line,
# fetch the gene and its NM_ transcripts through the locus.ncbi helpers and
# insert gene, transcript, and exon rows into a PostgreSQL database via SQLSoup.
__doc__ = """write """
__version__ = '1.0.0'

import os, sys
from filecache import filecache, WEEK
import logging

import IPython
import sqlalchemy.exc
import sqlsoup

from locus.core.exceptions import LocusNCBIError
import locus.ncbi.gene
import locus.ncbi.refseq
import locus.ncbi.helpers as lnh


def _strand_as_signed_int(plus_minus):
    if plus_minus == 'plus': return 1
    if plus_minus == 'minus': return -1
    raise RuntimeError('strand "'+plus_minus+'" not understood')

def load_transcript(tx_db,hgnc,g,nm):
    """Stage one transcript and its exons for insertion into `tx_db`.

    Fetches the RefSeq record for accession `nm` (cached on disk for a week
    via filecache), then inserts one `transcript` row, one `genomic_exon`
    row per genomic exon, and one `transcript_exon` row per transcript exon.
    Exon `ord` values are 1-based.

    Nothing is committed here: the transcript row is flushed so that it is
    written before the exon rows, but the caller owns the transaction and is
    expected to commit on success or roll back on failure, so a transcript
    is never stored without its exons.

    Parameters:
      tx_db -- SQLSoup database handle exposing the `transcript`,
               `genomic_exon` and `transcript_exon` tables
      hgnc  -- gene name stored in the transcript's `gene` column
      g     -- gene object providing `grch37p10_product_exons(nm)`
               (the script passes a locus.ncbi.gene.Gene)
      nm    -- transcript accession

    Returns a tuple (number of genomic exons, number of transcript exons).

    Raises AssertionError if the RefSeq record's exon list and exon name
    list differ in length.  Exceptions raised by the NCBI helpers or by the
    database layer propagate unchanged.
    """
    @filecache(WEEK)
    def _refseq(nm):
        return lnh.efetch_nuccore_by_ac(nm)

    t = locus.ncbi.refseq.RefSeq( _refseq(nm) )
    t_exons = t.exons
    t_ex_names = t.exon_names
    if len(t_exons) != len(t_ex_names):
        # Explicit raise instead of `assert` so the check survives
        # `python -O`; AssertionError is kept because the caller catches it.
        raise AssertionError(
            nm+": exon <s,e> list and exon name list have different numbers of elemens")

    g_exons = g.grch37p10_product_exons(nm)

    logging.debug('loading %s, gene %s, %d g/%d tx exons',
                  nm, hgnc, len(g_exons), len(t_exons))

    # load transcript
    cds_start_i,cds_end_i = t.cds_start_end_i
    tx_db.transcript.insert( ac = nm, gene = hgnc, cds_start_i = cds_start_i,
                             cds_end_i = cds_end_i, seq = t.seq )
    # Flush (not commit): the transcript INSERT is emitted ahead of the exon
    # rows, yet stays inside the caller's transaction.  Committing here left
    # an exon-less transcript behind whenever a later insert failed, and it
    # ended the transaction in which the caller had deferred constraints.
    tx_db.flush()

    # genomic exons
    for n,se in enumerate(g_exons, 1):
        tx_db.genomic_exon.insert(ac = nm, start_i = se[0], end_i = se[1], ord = n)

    # transcript exons
    for n,(se,name) in enumerate(zip(t_exons,t_ex_names), 1):
        tx_db.transcript_exon.insert(ac = nm,  name = name, ord = n,
                                     start_i = se[0], end_i = se[1])

    return (len(g_exons), len(t_exons))


if __name__ == '__main__':
    # Script entry point.  For every HGNC gene name on the command line:
    #   1. fetch the gene record (cached on disk for a week) and insert a
    #      `gene` row unless one already exists;
    #   2. load every NM_ product of the gene that is not yet in the
    #      `transcript` table, or re-point an existing transcript whose
    #      stored gene differs from the one NCBI reports;
    #   3. report NM_ accessions that are still absent from the database.
    # Failures for one gene or transcript are logged and skipped so the
    # remaining arguments are still processed.
    @filecache(WEEK)
    def _fetch_gene(hgnc):
        return lnh.efetch_gene_by_hgnc_name(hgnc)

    logging.basicConfig(level=logging.DEBUG)

    # database schema defined in db.sql
    # 'postgresql:///' == local socket connection
    tx_db = sqlsoup.SQLSoup('postgresql:///reece')

    # In-memory view of what the database already holds.  Both structures
    # are kept up to date as rows are written below, so a gene name repeated
    # on the command line is recognised instead of being inserted twice.
    # A set gives constant-time membership tests.
    existing_genes = set(
        r[0] for r in
        tx_db.connection().execute('select gene from gene').fetchall())
    existing_nm_gene = dict(tx_db.connection().execute('select ac,gene from transcript').fetchall())
    logging.info('%d existing genes, %d existing transcripts',
                 len(existing_genes), len(existing_nm_gene))

    for hgnc in sys.argv[1:]:
        # start each gene from a clean transaction
        tx_db.rollback()

        try:
            g = locus.ncbi.gene.Gene( _fetch_gene(hgnc) )
            if hgnc in existing_genes:
                logging.info('gene %s: already in database', hgnc)
            else:
                m = g.grch37p10_mapping()
                tx_db.gene.insert(
                    gene = hgnc,
                    chr = m['chr'], strand = _strand_as_signed_int(m['strand']),
                    start_i = m['start_i'], end_i = m['end_i'], maploc = g.maploc,
                    descr = g.desc, summary = g.summary
                    )
                tx_db.commit()
                existing_genes.add(hgnc)
                logging.info('gene %s: loaded', hgnc)
        except LocusNCBIError as e:
            tx_db.rollback()
            # logging.warning: logging.warn is a deprecated alias
            logging.warning('gene %s: %s', hgnc, str(e))
            continue

        try:
            products = g.grch37p10_products()
        except LocusNCBIError as e:
            logging.warning('gene %s: %s', hgnc, str(e))
            continue

        # only NM_ accessions are loaded
        nms = set(p for p in products if p.startswith('NM_'))
        if not nms:
            logging.warning("Gene {gene} has no NMs ({np} total)".format(
                    gene = hgnc, np = len(products)))
        logging.debug('loading Gene %s, %d tx: %s',
                      hgnc, len(nms), ', '.join(sorted(nms)))
        for nm in nms:
            if nm in existing_nm_gene:
                if hgnc != existing_nm_gene[nm]:
                    # transcript is stored under a different gene: re-point it
                    tx_db.connection().execute('update transcript set gene=%s where ac=%s',hgnc,nm)
                    tx_db.commit()
                    logging.warning("%s: associated with %s in db, ncbi says %s; updated",
                                    nm, existing_nm_gene[nm], hgnc)
                    existing_nm_gene[nm] = hgnc
                else:
                    logging.debug("%s: transcript already in database; skipping", nm)
                continue
            try:
                tx_db.connection().execute('SET CONSTRAINTS all DEFERRED')
                n_g_exons, n_t_exons = load_transcript(tx_db,hgnc,g,nm)
                tx_db.commit()
                existing_nm_gene[nm] = hgnc
                logging.info("Gene %s, transcript %s: loaded %d genomic/%d transcript exons",
                             hgnc, nm, n_g_exons, n_t_exons)
            except (AssertionError,LocusNCBIError,sqlalchemy.exc.SQLAlchemyError) as e:
                tx_db.rollback()
                logging.warning('Gene %s, transcript %s: %s', hgnc, nm, str(e))

        # verify against the database itself, not the in-memory cache
        loaded_nms = set(
            r[0] for r in
            tx_db.connection().execute('select ac from transcript where gene=%s',hgnc).fetchall())
        missing_nms = nms - loaded_nms
        if missing_nms:
            logging.error('gene %s: missing %s', hgnc, ','.join(sorted(missing_nms)))

## <LICENSE>
## Copyright 2014 UTA Contributors (https://bitbucket.org/invitae/uta)
## 
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
## 
##     http://www.apache.org/licenses/LICENSE-2.0
## 
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
## </LICENSE>
