#!/usr/bin/env python

# N.B. This script currently prints
# Exception TypeError: "'NoneType' object is not callable" in  ignored
# four times on every execution. This results from a shelve-related bug
# in Python, described here:
# http://re.runcode.us/q/really-weird-issue-with-shelve-python


from __future__ import division, print_function, unicode_literals

__doc__ = """load NCBI genes and/or transcripts into UTA"""
__version__ = '1.0.0'

import os, sys
import filecache
import logging

import psycopg2
import psycopg2.extras

from locus.core.exceptions import LocusNCBIError
import locus.ncbi.gene
import locus.ncbi.refseq
import locus.ncbi.helpers as lnh


@filecache.filecache(filecache.MONTH)
def _fetch_gene(hgnc):
    """Fetch the raw XML of the NCBI Gene record for an HGNC symbol.

    The result is cached by the ``filecache`` decorator (expiry:
    ``filecache.MONTH``), so repeated calls avoid re-querying NCBI.
    """
    raw_xml = lnh.efetch_gene_by_hgnc_name(hgnc)
    return raw_xml


@filecache.filecache(filecache.MONTH)
def _fetch_refseq(nm):
    """Fetch the raw XML of the NCBI nuccore (RefSeq) record for accession `nm`.

    The result is cached by the ``filecache`` decorator (expiry:
    ``filecache.MONTH``), so repeated calls avoid re-querying NCBI.
    """
    raw_xml = lnh.efetch_nuccore_by_ac(nm)
    return raw_xml


def _strand_as_signed_int(plus_minus):
    if plus_minus == 'plus': return 1
    if plus_minus == 'minus': return -1
    raise RuntimeError('strand "'+plus_minus+'" not understood')


def _gen_upsert_stmts(tbl,pks,cols):
    """return update/insert pair.  Assumes cols[0] is primary key"""
    upd = 'UPDATE %s SET %s WHERE %s' % (
        tbl,
            ','.join([ '%s=%%(%s)s' % (c,c) for c in cols ]),
        ' AND '.join([ '%s=%%(%s)s' % (c,c) for c in pks  ]),
        )
    ins = 'INSERT INTO %s (%s) VALUES (%s)' % (
        tbl,
        ','.join(pks+cols),
        ','.join([ '%('+c+')s' for c in pks+cols ])
        )
    return upd,ins


def _upsert(cur,tbl,cols,data):
    """Update the `tbl` row matching `data`'s key; insert it if nothing matched.

    `cols[0]` names the key: either one column name, or a list of column
    names for a composite key.  `cols[1:]` are the non-key columns.  `data`
    maps every column name to its value.

    NOTE: UPDATE-then-INSERT is not atomic; a concurrent insert of the same
    key between the two statements can cause a duplicate-key error.
    """
    head, non_key_cols = cols[0], cols[1:]
    key_cols = head if isinstance(head, list) else cols[:1]
    update_sql, insert_sql = _gen_upsert_stmts(tbl, key_cols, non_key_cols)
    cur.execute(update_sql, data)
    if cur.rowcount != 0:
        # an existing row was updated (or the driver could not tell); done
        return
    cur.execute(insert_sql, data)


def _upsert_gene(cur,g):
    """Insert or update the `gene` row for a locus.ncbi.gene.Gene.

    Coordinates come from g.grch37p10_mapping(). Its 'strand' word is
    converted to +1/-1 by _strand_as_signed_int. The row is keyed on the
    `gene` column (g.hgnc).
    """
    loc = g.grch37p10_mapping()
    gene_cols = ['gene','chr','strand','start_i','end_i','maploc','descr','summary']
    row = {
        'gene': g.hgnc,
        'chr': loc['chr'],
        'strand': _strand_as_signed_int(loc['strand']),
        'start_i': loc['start_i'],
        'end_i': loc['end_i'],
        'maploc': g.maploc,
        'descr': g.desc,
        'summary': g.summary,
        }
    _upsert(cur, 'gene', gene_cols, row)
    logging.debug('gene %s: inserted/updated' % g.hgnc)


def _upsert_transcript(cur,t,g):
    """Insert or update one RefSeq transcript and its exons.

    :param cur: DB cursor
    :param t: locus.ncbi.refseq.RefSeq for the transcript
    :param g: locus.ncbi.gene.Gene the transcript belongs to
    :raises AssertionError: if t.exons and t.exon_names differ in length

    Writes three kinds of rows:
    - the `transcript` row, keyed by t.acv;
    - one `genomic_exon` row per element of g.grch37p10_product_exons(t.acv);
    - one `transcript_exon` row per (exon, name) pair.
    Exon rows are keyed by (ac, ord), with ord numbered from 1.
    """
    genomic_exons = g.grch37p10_product_exons(t.acv)

    tx_exons = t.exons
    tx_exon_names = t.exon_names
    assert len(tx_exons) == len(tx_exon_names), t.acv+": exon <s,e> list and exon name list have different numbers of elements"

    # the transcript itself
    cds_start_i, cds_end_i = t.cds_start_end_i
    transcript_row = {
        'ac': t.acv,
        'gene': g.hgnc,
        'cds_start_i': cds_start_i,
        'cds_end_i': cds_end_i,
        'seq': t.seq,
        }
    _upsert(cur, 'transcript', ['ac', 'gene', 'cds_start_i', 'cds_end_i', 'seq'], transcript_row)

    # genomic exons; each element is indexed as [start, end]
    for ordinal, bounds in enumerate(genomic_exons, 1):
        _upsert(cur, 'genomic_exon', [['ac','ord'], 'start_i', 'end_i'], {
            'ac': t.acv,
            'start_i': bounds[0],
            'end_i': bounds[1],
            'ord': ordinal,
            })

    # transcript exons; each [start, end] is paired with its exon name
    for ordinal, (bounds, exon_name) in enumerate(zip(tx_exons, tx_exon_names), 1):
        _upsert(cur, 'transcript_exon', [ ['ac', 'ord'], 'name', 'start_i', 'end_i'], {
            'ac': t.acv,
            'name': exon_name,
            'ord': ordinal,
            'start_i': bounds[0],
            'end_i': bounds[1],
            })

    logging.debug('upserted transcript %s' % t.acv)


def load_transcript_and_gene(conn,nm,g=None):
    """load transcript and, if needed, relevant gene

    Fetches the RefSeq record `nm`. When `g` is None, it also fetches the NCBI
    gene record for the transcript's single associated gene. It then upserts
    the gene and the transcript in two separate transactions.

    :param conn: psycopg2 connection
    :param nm: RefSeq accession; must equal the fetched record's `acv`
    :param g: optional, already-built locus.ncbi.gene.Gene
    :returns: tuple ``(transcript accession, gene HGNC symbol)``
    :raises RuntimeError: if the transcript is associated with more than one
        gene, or (when `g` is None) with no gene at all

    Error handling in each transaction:
    - duplicate-key IntegrityErrors are rolled back and ignored;
    - other IntegrityErrors are rolled back and re-raised;
    - AssertionError or LocusNCBIError while loading the transcript is rolled
      back and logged as a warning.
    """
    t = locus.ncbi.refseq.RefSeq( _fetch_refseq(nm) )
    assert nm == t.acv, "asked for transcript %s, got %s" % (nm,t.acv)
    if len(t.gene) > 1:
        raise RuntimeError('%s is associated with %d genes (%s)' % (nm,len(t.gene),', '.join(t.gene)))
    if g is None:
        if not t.gene:
            # avoid an opaque IndexError on t.gene[0]
            raise RuntimeError('%s is not associated with any gene' % nm)
        g = locus.ncbi.gene.Gene( _fetch_gene(t.gene[0]) )

    cur = conn.cursor()

    # SET CONSTRAINTS only affects the current transaction, so it is issued
    # at the start of each of the two transactions below.  Otherwise the
    # transcript transaction would run without deferred constraints.
    try:
        cur.execute('SET CONSTRAINTS ALL DEFERRED')
        _upsert_gene(cur, g)
        conn.commit()
    except psycopg2.IntegrityError as e:
        # always roll back so the connection is not left in an aborted
        # transaction, then re-raise anything but a duplicate key
        conn.rollback()
        if 'duplicate key' not in str(e):
            raise

    try:
        cur.execute('SET CONSTRAINTS ALL DEFERRED')
        _upsert_transcript(cur, t, g)
        conn.commit()
        logging.info( 'committed transcript %s, gene %s' % (t.acv,g.hgnc) )
    except psycopg2.IntegrityError as e:
        conn.rollback()
        if 'duplicate key' not in str(e):
            raise
    except (AssertionError,LocusNCBIError) as e:
        conn.rollback()
        logging.warning('Transcript %s: %s' % (nm,str(e)))

    return (t.acv, g.hgnc)


def load_gene_and_transcripts(conn,hgnc):
    """load gene and all transcripts for it

    Fetches the NCBI gene record for `hgnc`. It then upserts the gene and
    every NM_ product from g.grch37p10_products() in a single transaction.

    :param conn: psycopg2 connection
    :param hgnc: HGNC gene symbol

    Error handling:
    - on AssertionError or LocusNCBIError, the whole transaction (gene and all
      transcripts) is rolled back and a warning is logged;
    - any other exception is rolled back and re-raised, so the connection is
      not left in an aborted transaction.
    The success message is logged only after a successful commit.
    """
    g = locus.ncbi.gene.Gene( _fetch_gene(hgnc) )
    products = g.grch37p10_products()
    nms = { p for p in products if p.startswith('NM_') }
    if not nms:
        logging.warning("Gene {gene} has no NMs ({np} total)".format(
                gene = hgnc, np = len(products)))

    cur = conn.cursor()
    try:
        # SET CONSTRAINTS applies only to the current transaction, so issue it
        # inside the transaction that performs the upserts.
        cur.execute('SET CONSTRAINTS ALL DEFERRED')
        _upsert_gene(cur,g)
        for nm in nms:
            t = locus.ncbi.refseq.RefSeq( _fetch_refseq(nm) )
            _upsert_transcript(cur, t, g)
        conn.commit()
    except (AssertionError,LocusNCBIError) as e:
        conn.rollback()
        logging.warning('Gene %s: %s' % (hgnc,str(e)))
    except Exception:
        conn.rollback()
        raise
    else:
        # only claim success when the transaction actually committed
        logging.info('upserted gene %s w/%d transcripts (%s)' % (hgnc, len(nms), ', '.join(sorted(nms))))
        
    
############################################################################

class DummyException(Exception):
    """Placeholder exception type; nothing in this script raises it.

    The ``__main__`` block catches only this type. That effectively disables
    per-argument error handling: any real error aborts the run.
    """

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # database schema defined in db.sql
    # no host is given, so libpq uses its default connection (typically a
    # local Unix-domain socket)
    conn = psycopg2.connect(database='reece')

    try:
        for arg in sys.argv[1:]:
            # NM_ accessions load a transcript (plus its gene); anything else
            # is treated as an HGNC gene symbol
            if arg.startswith('NM_'):
                loader = load_transcript_and_gene
            else:
                loader = load_gene_and_transcripts
            try:
                loader(conn,arg)
            except DummyException as e:
                # NOTE(review): DummyException is never raised, so this handler
                # is inert and any loader error aborts the whole run; presumably
                # a placeholder for per-argument error handling -- confirm.
                logging.error('Transcript/Gene '+arg+': '+str(e))
    finally:
        # release the server connection even if a loader raises
        conn.close()

## <LICENSE>
## Copyright 2014 UTA Contributors (https://bitbucket.org/invitae/uta)
## 
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
## 
##     http://www.apache.org/licenses/LICENSE-2.0
## 
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
## </LICENSE>
