import networkx as nx
import numpy as np
from pandas import DataFrame
import toolz as tz


def find_main_branch_nx(g: nx.Graph, weight='branch-distance', in_place=True):
    """Mark the longest shortest path of every connected component of ``g``.

    For each connected component, weighted shortest-path lengths are
    computed between all node pairs. The pair with the largest finite
    length is selected (ties go to the pair visited last). Every edge on a
    weighted shortest path between those two nodes then gets the attribute
    ``'main'`` set to ``True``. Edges off that path are left untouched.

    Parameters
    ----------
    g : nx.Graph
        Graph to analyze.
    weight : str
        Name of the edge attribute holding edge lengths, passed through to
        the networkx shortest-path routines.
    in_place : bool
        If True, annotate ``g`` itself; otherwise annotate a copy.

    Returns
    -------
    nx.Graph
        The annotated graph (``g`` itself when ``in_place`` is True).
    """
    graph = g if in_place else g.copy()
    for component in nx.connected_components(graph):
        sub = graph.subgraph(component)
        lengths = dict(nx.all_pairs_dijkstra_path_length(sub, weight=weight))
        best_length = 0
        endpoints = None
        # A later pair with an equal length replaces the current best, so
        # the last maximal pair in iteration order wins.
        for source, targets in lengths.items():
            for target, length in targets.items():
                if length is None or not np.isfinite(length):
                    continue
                if length < best_length:
                    continue
                best_length = length
                endpoints = (source, target)
        start, stop = endpoints
        path = nx.shortest_path(sub, source=start, target=stop, weight=weight)
        for u, v in tz.sliding_window(2, path):
            graph.edges[u, v]['main'] = True
    return graph


def find_main_branches(summary: DataFrame) -> np.ndarray:
    """Flag the branches that lie on the main (longest shortest) path.

    Builds an undirected graph whose nodes are the branch endpoint IDs and
    whose edges are the rows of ``summary``, weighted by branch distance.
    ``find_main_branch_nx`` then marks, for each connected component, the
    edges on the longest weighted shortest path. Those edges are mapped
    back to their rows.

    Parameters
    ----------
    summary : pd.DataFrame
        The summary table of the skeleton to analyze.
        This must contain: ['node_id_src', 'node_id_dst', 'branch_distance']

    Returns
    -------
    is_main: array of bool, shape (len(summary),)
       True if the index-matched branch lies on the longest shortest path of
       its connected component of the skeleton.
    """
    is_main = np.zeros(summary.shape[0], dtype=bool)
    us = summary['node_id_src']
    vs = summary['node_id_dst']
    ws = summary['branch_distance']

    # Map each endpoint pair, in both orientations, to its row position.
    # Assigning both orientations together keeps them pointing at the same
    # (last) row for a repeated pair. This matches nx.Graph, where the last
    # added edge's attributes win.
    edge2idx = {}
    for i, (u, v) in enumerate(zip(us, vs)):
        edge2idx[(u, v)] = i
        edge2idx[(v, u)] = i

    g = nx.Graph()
    # Store lengths under the same attribute name used for the path search.
    # networkx treats a missing weight attribute as 1, which would silently
    # turn the search into a hop count.
    g.add_weighted_edges_from(zip(us, vs, ws), weight='branch_distance')

    h = find_main_branch_nx(g, weight='branch_distance')
    # Only edges annotated as on the main path are flagged; the rest keep
    # their False entry.
    for u, v, on_main in h.edges(data='main', default=False):
        if on_main:
            is_main[edge2idx[(u, v)]] = True

    return is_main
