#!/usr/bin/env python

from __future__ import division, print_function, unicode_literals

__doc__ = """generates data for agreement/discordance between refseq
transcripts and reference genome

* NM_000022.2   ADA -20q13.12   12  True    True    1   1X/0I/0D
    20  43264927    43264927    NM_000022.2:c.36    3   60  60  X   C   T   http://1.usa.gov/KXlHan
* NM_000035.3   ALDOB   -9q21.3-q22.2   9   True    True    0   0X/0I/0D
"""
__version__ = '0.0.0'

############################################################################

import collections
import logging, os, pprint, sys, time
import psycopg2, psycopg2.extras
import locus.tools.refagree as ltr
import bitlyapi

import IPython

############################################################################

_Mismatch = collections.namedtuple(
    'Mismatch',
    'chr strand g_start_i g_end_i g_aseq '
    'ac exon e_start_i e_end_i '
    't_start_i t_end_i t_aseq '
    'c_start c_end type')


class Mismatch(_Mismatch):
    """One contiguous run of disagreement between genome and transcript.

    Plain namedtuple fields plus two derived properties. ``c_start`` and
    ``c_end`` must be strings (they are concatenated into ``hgvsc_pos``);
    ``type`` holds a cigar-style code (X=sub, I=ins, D=del per the report
    header).
    """

    @property
    def seqviewer_url(self):
        """URL of a sequence-viewer slice around this mismatch, padded by 50 on each side."""
        # NOTE(review): ``lug`` is not imported in this file, so reading this
        # property raises NameError unless ``lug`` is injected elsewhere.
        padding = 50
        return lug.url_for_slice(self.chr,
                                 self.g_start_i - padding,
                                 self.g_end_i + padding)

    @property
    def hgvsc_pos(self):
        """HGVS-style position, e.g. 'NM_000022.2:c.36' or 'NM_000022.2:c.36_38'."""
        position = self.c_start
        if self.c_start != self.c_end:
            # a multi-base run is written as start_end
            position = position + '_' + self.c_end
        return self.ac + ':c.' + position


def fetch_acvs(cur):
    """Return the accession of every row in transcripts.transcript.

    :param cur: an open DB-API connection. Despite the historical name this
        is a connection, not a cursor (the ``__main__`` section passes
        ``conn``); a cursor is opened on it and closed before returning.
    :returns: list of ``ac`` values, in the order the database yields them
    """
    # Open the cursor on the connection we were given -- not on a
    # module-level ``conn``, which only exists when this file runs as a script.
    row_cur = cur.cursor()
    try:
        row_cur.execute('select ac from transcripts.transcript')
        return [ row[0] for row in row_cur ]
    finally:
        row_cur.close()

def fetch_gene_info_for_transcript(cur,acv):
    """Return the gene record for transcript ``acv`` as a plain dict.

    :param cur: an open psycopg2 connection (historical name; callers pass
        the connection, not a cursor)
    :param acv: transcript accession.version, e.g. 'NM_000022.2'
    :returns: dict with keys chr, strand, start_i, end_i, gene, maploc,
        descr, summary
    :raises AssertionError: unless exactly one gene row matches
    """
    sql = """
select distinct G.chr,G.strand,G.start_i,G.end_i,G.gene,G.maploc,G.descr,G.summary
from gene G
join transcript T on G.gene=T.gene
where T.ac = %(acv)s
"""
    # Use the connection passed in, not a module-level ``conn``.
    dict_cur = cur.cursor(cursor_factory=psycopg2.extras.DictCursor)
    try:
        dict_cur.execute(sql, {'acv': acv})
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged for callers.
        if dict_cur.rowcount != 1:
            raise AssertionError('fetched %d genes for %s' % (dict_cur.rowcount,acv))
        return dict(dict_cur.fetchone())
    finally:
        dict_cur.close()

def fetch_transcript_info(cur,acv):
    """Return transcript details for ``acv`` as a dict, or None if absent.

    :param cur: an open psycopg2 connection (historical name; callers pass
        the connection, not a cursor)
    :param acv: transcript accession.version, e.g. 'NM_000022.2'
    :returns: dict with keys ac, chr, cds_start_i, cds_end_i, mismatches_p,
        seq -- where mismatches_p is true when any of the transcript's exon
        alignments has a cigar containing X, D or I -- or None when no row
        matches
    :raises AssertionError: if more than one row matches
    """
    # Use the connection passed in, not a module-level ``conn``.
    dict_cur = cur.cursor(cursor_factory=psycopg2.extras.DictCursor)
    sql = """
SELECT ac, chr, cds_start_i, cds_end_i, exists (select * from transcript_exon TE join gtx_alignment GA on TE.transcript_exon_id=GA.transcript_exon_id and GA.cigar~'[XDI]' where TE.ac=T.ac) as mismatches_p, seq
FROM transcripts.transcript T
JOIN gene G on T.gene=G.gene
WHERE T.ac=%(acv)s
"""
    try:
        dict_cur.execute( sql, {'acv': acv} )
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if dict_cur.rowcount > 1:
            raise AssertionError('fetched %d transcripts for %s' % (dict_cur.rowcount,acv))
        return dict(dict_cur.fetchone()) if dict_cur.rowcount == 1 else None
    finally:
        dict_cur.close()

def fetch_transcript_exon_coords(conn,acv):
    """Fetch the transcript-relative exon bounds of ``acv``.

    :param conn: an open psycopg2 connection
    :param acv: transcript accession.version
    :returns: list of DictCursor rows with ``start_i`` and ``end_i``,
        sorted by ``start_i``
    """
    query = """SELECT E.start_i,E.end_i FROM transcript_exon E WHERE E.ac=%(acv)s ORDER BY start_i;"""
    exon_cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    exon_cursor.execute(query, {'acv': acv})
    rows = exon_cursor.fetchall()
    return rows

def fetch_genome_exon_coords(conn,acv,assy=None):
    """Return genome interbase coordinates of exons in transcript order.

    start < end regardless of strand. Rows are DictCursor rows with gene,
    ac, chr, strand, start_i and end_i, ordered by ``start_i * strand`` so
    that minus-strand exons come back in descending genomic order.

    :param conn: an open psycopg2 connection
    :param acv: transcript accession.version
    :param assy: accepted but currently unused
    """
    query = """
SELECT G.gene,T.ac,G.chr,G.strand,GE.start_i,GE.end_i
FROM transcript T
JOIN genomic_exon GE on T.ac=GE.ac
JOIN gene G ON T.gene=G.gene
WHERE T.ac=%(acv)s
ORDER BY T.ac,GE.start_i*G.strand
"""
    exon_cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    exon_cursor.execute(query, {'acv': acv})
    rows = exon_cursor.fetchall()
    return rows



def header(src=None):
    """Build the comment banner printed at the top of the report.

    :param src: free-text description of where the accessions came from;
        interpolated verbatim (None renders as 'None')
    :returns: multi-line string, each line starting with '#', stamped with
        the current UTC time
    """
    run_on = time.strftime('%F %TZ', time.gmtime())
    banner = """# transcript-genome agreement
# run on: {ts}
# accession source: {src}
# Records consist of a summary line followed by 0 or more mismatch lines:
#
# * <ac> \\t <locus> \\t <strand><maploc> \\t <#exons> \\t <n_ex_eq?> \\t <ex_len_eq?> \\t <n_mismatches> \\t <mismatch_summary>
# \\t <chr> \\t <gstart> \\t <gend> \\t <hgvsc_pos> \\t <exon#> \\t <exon_start> \\t <exon_end> \\t <type> \\t <g_allele> \\t <t_allele> \\t <seqviewer_url>
#
# \\t == tab
# Lines starting with # are comments. Empty lines are ignored.
# All coordinates and indexes are human (1-based, inclusive).
# Mismatch key (a la cigar): X=sub, I=ins, D=del
# This would have been XML if humans weren't involved. Darn humans. 
"""
    return banner.format(ts=run_on, src=src)


def build_exon_disrepancies( gene_info, transcript_info, exon_map ):
    """return a list of Mismatch records that reports contiguous regions
    of genomic mismatches. Mismatch records include genomic, cds, and exon
    coordinates.

    :param gene_info: dict; reads 'strand' (+1 or -1)
    :param transcript_info: dict; reads 'accession', 'chromosome',
        'sequence', 'cds_start_i' and 'cds_end_i'
    :param exon_map: iterable of 4-tuples (g_ex_start_i, g_ex_end_i,
        t_ex_start_i, t_ex_end_i), one per exon; exon numbers are assigned
        by position, so presumably in transcript order -- TODO confirm
    :returns: list of Mismatch, one per non-'M' run in each exon's cigar vector

    N.B. uses the names ``lug`` and ``Bio`` (Bio.Seq), which this file does
    not import; they must be in scope before calling. The ``__main__``
    section calls ltr.build_exon_disrepancies rather than this function.
    """
    def _cdot(cds_start_i,cds_end_i,r_i):
        """convert a transcript coordinate r_i (0-based, relative to the
        start of exon 1) to an HGVS c. position string: 1-based inside the
        CDS, '-N' before the CDS start, '*N' after the CDS end.

        Assumes cds_start_i/cds_end_i are interbase (end exclusive), so
        r_i == cds_end_i is the first base after the stop codon (*1)."""
        if r_i >= cds_end_i:
            # 3' UTR. Compare transcript coords with transcript coords; the
            # old test compared a CDS-relative offset against cds_end_i.
            return '*' + str(r_i - cds_end_i + 1)
        c_i = r_i - cds_start_i
        if c_i < 0:
            # 5' UTR: the base just before the start is c.-1 (there is no c.0)
            return str(c_i)
        return str(c_i + 1)

    acv = transcript_info['accession']
    mms = []
    t_ex_cum_len = 0    # transcript length consumed by the exons already processed
    for i,em in enumerate(exon_map):
        g_ex_start_i, g_ex_end_i, t_ex_start_i, t_ex_end_i = em
        g_ex_seq = lug.fetch_genomic_sequence_interval(
            transcript_info['chromosome'],g_ex_start_i,g_ex_end_i)
        t_ex_seq = transcript_info['sequence'][t_ex_start_i:t_ex_end_i]

        # Offsets within the alignment run along the genome's + strand;
        # e_off/e_dir turn them into transcript offsets, which run backwards
        # when the gene is on the - strand.
        e_off, e_dir = 0, 1
        if gene_info['strand'] == -1:
            # str() instead of the deprecated Seq.tostring()
            t_ex_seq = str(Bio.Seq.Seq(t_ex_seq).reverse_complement())
            e_off, e_dir = len(t_ex_seq), -1

        g_ex_aseq,t_ex_aseq = lug.align2(g_ex_seq.upper(),t_ex_seq.upper())
        assert len(g_ex_aseq)==len(t_ex_aseq), "global alignment should be same length"

        # NOTE(review): mv[0]/mv[2] are used below both as offsets into the
        # gapped alignment strings (g_aseq/t_aseq slices) and as ungapped
        # genome/transcript offsets; after an I/D run these differ by the gap
        # length -- confirm lug.alignment_cigar_vector's convention.
        for mv in lug.alignment_cigar_vector(g_ex_aseq,t_ex_aseq):
            if mv[1] == 'M':    # skip matches
                continue
            e_start_i = mv[0]
            e_end_i   = mv[0]+mv[2]
            g_start_i = g_ex_start_i + e_start_i
            g_end_i   = g_ex_start_i + e_end_i
            t_start_i = t_ex_cum_len + e_off + e_dir * e_start_i
            t_end_i   = t_ex_cum_len + e_off + e_dir * e_end_i
            if gene_info['strand'] == -1:
                # walking backwards produced end < start; restore start < end
                t_start_i,t_end_i = t_end_i,t_start_i
            m = Mismatch(
                type = mv[1],
                ac = acv,
                strand = gene_info['strand'],
                chr = transcript_info['chromosome'],
                exon = i + 1,
                e_start_i = e_start_i,
                e_end_i = e_end_i,
                g_start_i = g_start_i,
                g_end_i = g_end_i,
                g_aseq = g_ex_aseq[e_start_i:e_end_i],
                t_start_i = t_start_i,
                t_end_i = t_end_i,
                t_aseq = t_ex_aseq[e_start_i:e_end_i], # keep + stranded
                # N.B. c_start,c_end are 1-based, inclusive
                c_start = _cdot(transcript_info['cds_start_i'],transcript_info['cds_end_i'],t_start_i),
                c_end = _cdot(transcript_info['cds_start_i'],transcript_info['cds_end_i'],t_end_i-1)
                )
            mms += [m]
        t_ex_cum_len += len(t_ex_seq)
    return mms



############################################################################

def strand_sign(strand_val):
    """Map a numeric strand to its sign character.

    :param strand_val: 1, -1 or None
    :returns: '+' for 1, '-' for -1, None for None
    :raises ValueError: for any other value. ValueError subclasses Exception,
        so existing ``except Exception`` handlers still catch it.
    """
    if strand_val is None:
        return None
    if strand_val == 1:
        return '+'
    if strand_val == -1:
        return '-'
    raise ValueError('invalid strand value (expected None, +1, or -1)')



if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    # NOTE: ``conn`` is deliberately a module-level global rather than a
    # local of a main() function, because fetch_* helpers may refer to it directly.
    conn = psycopg2.connect(database='reece')

    try:
        # acv = ACessions.Version (e.g., NM_0123.4)
        if len(sys.argv) > 1:
            # from command line
            acvs = sys.argv[1:]
            src = 'command line (%d accessions)' % len(acvs)
        else:
            # from database
            acvs = fetch_acvs(conn)
            # NOTE(review): hard-coded upper bound on accessions; looks like a
            # leftover from a partial run -- confirm it is still wanted.
            acvs = [ acv for acv in acvs if acv <= 'NM_020986.3' ]
            src = 'database (%d accessions)' % len(acvs)

        acvs.sort()

        print( header(src) )

        for i,acv in enumerate(acvs):
            if i % 25 == 0:
                logging.info('%d/%d (%.1f%%)' % (i+1, len(acvs), (i+1)/len(acvs)*100))

            try:
                g_info = fetch_gene_info_for_transcript(conn,acv)
                t_info = fetch_transcript_info(conn,acv)
                # Check before t_info is subscripted below. Otherwise a missing
                # transcript surfaces as a TypeError that aborts the whole run
                # instead of the RuntimeError handled per accession.
                if t_info is None:
                    raise RuntimeError('transcript not found (not in cache?)')

                g_info['chromosome'] = g_info['chr']
                t_info['chromosome'] = t_info['chr']
                t_info['accession'] = t_info['ac']
                t_info['sequence'] = t_info['seq']
                if g_info['maploc'] is None:
                    g_info['maploc'] = g_info['chr']

                g_exons = fetch_genome_exon_coords(conn,acv)
                t_exons = fetch_transcript_exon_coords(conn,acv)

                g_ex_lens = [ e['end_i']-e['start_i'] for e in g_exons ]
                t_ex_lens = [ e['end_i']-e['start_i'] for e in t_exons ]

                exon_map = ltr.make_exon_mapping( g_exons, t_exons )

                mismatches = []
                if t_info['mismatches_p']:
                    # align only if transcripts database says there are mismatches
                    logging.info('aligning exons for '+acv)
                    mismatches = ltr.build_exon_disrepancies( gene_info=g_info,
                                                              transcript_info=t_info,
                                                              exon_map=exon_map )

                subs = [ mm for mm in mismatches if mm.type == 'X' ]
                ins =  [ mm for mm in mismatches if mm.type == 'I' ]
                dels = [ mm for mm in mismatches if mm.type == 'D' ]

                print('* ' + '\t'.join([acv, g_info['gene'],
                                        strand_sign(g_info['strand']) + g_info['maploc'],
                                        str(len(t_exons)),                 # num transcript exons
                                        str(len(g_exons) == len(t_exons)), # num exons match
                                        str(g_ex_lens == t_ex_lens),       # exon lengths match
                                        str(len(mismatches)),              # num contiguous mismatches
                                        '%dX/%dI/%dD' % (len(subs),len(ins),len(dels)),
                                        ]))
                for mm in mismatches:
                    try:
                        url = 'N/A'
                        #url = mm.seqviewer_url
                    except bitlyapi.bitly.APIError as e:
                        logging.warning(acv + ': ' + str(e))
                        url = 'N/A'

                    # NOTE(review): mm.exon+1 assumes ltr reports 0-based exon
                    # indexes; the local build_exon_disrepancies in this file
                    # stores exon=i+1 (already 1-based) -- confirm which applies.
                    print('\t'+'\t'.join([ mm.chr,
                                           str(mm.g_start_i+1), str(mm.g_end_i),
                                           mm.hgvsc_pos,
                                           str(mm.exon+1), str(mm.e_start_i+1), str(mm.e_end_i),
                                           mm.type, mm.g_aseq, mm.t_aseq,
                                           url
                                           ]))

            except RuntimeError as e:
                # str(e), not e.message: BaseException.message does not exist
                # on Python 3 (and was deprecated on Python 2).
                print('#' + acv + ': ' + str(e))
                logging.warning(acv + ': ' + str(e))
    finally:
        conn.close()


## <LICENSE>
## Copyright 2014 UTA Contributors (https://bitbucket.org/invitae/uta)
## 
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
## 
##     http://www.apache.org/licenses/LICENSE-2.0
## 
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
## </LICENSE>
