from collections import defaultdict
from itertools import product

import networkx as nx
from community import best_partition as louvain_partition

from .core import TickProcess

class EpochCommunity(TickProcess):
    '''Contact graph collected per epoch, plus community-based metrics.

    Link up/down events are recorded on ``graph`` (the epoch in progress).
    ``next_epoch`` freezes that graph into ``old_graph`` and starts a new
    one; every ``get_*`` metric is computed on ``old_graph``.

    ``community`` maps a node to the frozenset of nodes in its community and
    yields an empty frozenset for unknown nodes. This base class only resets
    it; filling it in is left to subclasses.

    NOTE(review): the tick/environment machinery comes from ``TickProcess``
    in ``.core``, which is not visible from this file.
    '''

    def __init__(self, epoch=1, **kwargs):
        '''Create an empty tracker.

        ``epoch`` is forwarded to ``TickProcess``; any extra keyword
        arguments are accepted but not used here.
        '''
        super().__init__(epoch)

        self.graph = nx.Graph()
        self.old_graph = None
        self.community = defaultdict(frozenset)
        # Cache for get_cbc, keyed by (community, community) pairs.
        self._cbc_memo = {}

    def set_link(self, a, b, state, now):
        '''Record that the link between ``a`` and ``b`` changed at ``now``.

        A truthy ``state`` marks the link as up since ``now``. A falsy
        ``state`` closes the open interval and adds its length to the edge's
        ``duration`` for the current epoch. A "down" event for a link that is
        not currently up is ignored.
        '''
        graph = self.graph
        if not graph.has_edge(a, b):
            # add_edge also creates missing endpoints. Attributes are passed
            # as keywords because networkx >= 2.0 no longer accepts a
            # positional attribute dict.
            graph.add_edge(a, b, start=-1, duration=0)

        edge = graph[a][b]
        if state:
            edge['start'] = now
        elif edge['start'] > -1:
            # A link may flap several times within one epoch, so accumulate
            # instead of keeping only the most recent interval.
            edge['duration'] += now - edge['start']
            edge['start'] = -1

    def next_epoch(self, now):
        '''Close the current epoch at ``now`` and start a new one.

        Clears the community assignment and the cbc cache, closes every link
        that is still up, moves the current graph to ``old_graph`` and
        re-opens the still-up links in a fresh graph.

        Returns the graph of the epoch that just ended.
        '''
        self.community = defaultdict(frozenset)
        self._cbc_memo = {}

        # Collect first, then mutate: set_link edits the edge attributes.
        still_up = [
            (a, b)
            for a, b, start in self.graph.edges(data='start', default=-1)
            if start > -1
        ]
        for a, b in still_up:
            self.set_link(a, b, False, now)

        self.old_graph = self.graph
        self.graph = nx.Graph()
        for a, b in still_up:
            self.set_link(a, b, True, now)

        return self.old_graph

    def __getitem__(self, node):
        '''Return the community (a frozenset of nodes) recorded for ``node``.'''
        return self.community[node]

    def _in_old_graph(self, *nodes):
        '''True when an epoch has been completed and it contains all ``nodes``.'''
        g = self.old_graph
        return g is not None and all(node in g for node in nodes)

    def get_lp(self, node):
        '''local popularity of a node

        Sum of the durations of the node's links, in the last completed
        epoch, to members of its own community. 0 when the node is unknown
        or no epoch has completed yet.
        '''
        if not self._in_old_graph(node):
            return 0

        community = self[node]
        return sum(
            edge['duration']
            for other, edge in self.old_graph[node].items()
            if other in community
        )

    def get_gp(self, node):
        '''global popularity of a node

        Sum of the durations of the node's links, in the last completed
        epoch, to nodes outside its own community. 0 when the node is
        unknown or no epoch has completed yet.
        '''
        if not self._in_old_graph(node):
            return 0

        community = self[node]
        return sum(
            edge['duration']
            for other, edge in self.old_graph[node].items()
            if other not in community
        )

    def get_ui(self, node):
        '''unique interactions with a node

        Number of the node's neighbours, in the last completed epoch, that
        belong to its own community. 0 when the node is unknown or no epoch
        has completed yet.
        '''
        if not self._in_old_graph(node):
            return 0

        community = self[node]
        return sum(
            1
            for other in self.old_graph[node]
            if other in community
        )

    def get_cbc(self, a, b):
        '''Total link duration between the communities of ``a`` and ``b``.

        Sums ``duration`` over every edge of the last completed epoch that
        joins a member of ``a``'s community to a member of ``b``'s. Returns 0
        when either node is unknown, no epoch has completed yet, or both
        nodes share the same community. Results are cached per community
        pair (in both orders) until the next epoch.
        '''
        c_x = self[a]
        c_y = self[b]

        if not self._in_old_graph(a, b) or c_x == c_y:
            return 0

        key = (c_x, c_y)
        if key in self._cbc_memo:
            return self._cbc_memo[key]

        g = self.old_graph
        cbc = sum(
            g[x][y]['duration']
            for x, y in product(c_x, c_y)
            if y in g[x]
        )

        # The value is symmetric, so store it for both orderings.
        self._cbc_memo[key] = cbc
        self._cbc_memo[(c_y, c_x)] = cbc
        return cbc

    def get_ncf(self, x, b):
        '''Total link duration between node ``x`` and ``b``'s community.

        Sums ``duration`` over the edges of the last completed epoch that
        join ``x`` to a member of ``b``'s community. Returns 0 when either
        node is unknown or no epoch has completed yet.
        '''
        c_y = self[b]

        if not self._in_old_graph(x, b):
            return 0

        neighbours = self.old_graph[x]
        return sum(
            neighbours[y]['duration']
            for y in c_y
            if y in neighbours
        )


class KCliqueCommunity(EpochCommunity):
    '''Communities found by k-clique percolation on each finished epoch.

    After every epoch, the links whose ``duration`` exceeds ``threshold``
    form an unweighted graph on which ``nx.k_clique_communities`` is run
    with clique size ``k``.
    '''

    def __init__(self, k=3, threshold=300, epoch=604800, **kwargs):
        '''Store the clique size ``k`` and the duration ``threshold``.

        ``epoch`` and any extra keyword arguments are passed on to
        ``EpochCommunity``.
        NOTE(review): the default epoch of 604800 equals one week if time is
        measured in seconds -- confirm the simulator's time unit.
        '''
        super().__init__(epoch=epoch, **kwargs)
        self.k = k
        self.threshold = threshold

    def process(self, network):
        '''Generator that recomputes the communities once per tick, forever.

        ``network`` is not used by this method.
        NOTE(review): ``self.env`` and ``self.tick`` are not defined in this
        file; presumably ``TickProcess`` provides them -- confirm.
        '''
        while True:
            yield self.env.timeout(self.tick)
            # The clock lives on self.env; a bare ``env`` is not defined in
            # this module.
            finished = self.next_epoch(self.env.now)

            # Keep every node, but only the links that lasted long enough.
            strong = nx.Graph()
            strong.add_nodes_from(finished.nodes())
            strong.add_edges_from(
                (a, b)
                for a, b, duration in finished.edges(data='duration', default=0)
                if duration > self.threshold
            )

            for community in nx.k_clique_communities(strong, self.k):
                # A node that appears in several communities keeps the one
                # produced last.
                for node in community:
                    self.community[node] = community


class LouvainCommunity(EpochCommunity):
    '''Communities found by Louvain partitioning of each finished epoch.

    The partition is computed on the finished epoch's graph with the link
    ``duration`` as edge weight.
    '''

    def __init__(self, epoch=604800, **kwargs):
        '''Pass ``epoch`` and any extra keyword arguments to ``EpochCommunity``.

        NOTE(review): the default epoch of 604800 equals one week if time is
        measured in seconds -- confirm the simulator's time unit.
        '''
        super().__init__(epoch=epoch, **kwargs)

    def process(self, network):
        '''Generator that recomputes the communities once per tick, forever.

        ``network`` is not used by this method.
        NOTE(review): ``self.env`` and ``self.tick`` are not defined in this
        file; presumably ``TickProcess`` provides them -- confirm.
        '''
        while True:
            yield self.env.timeout(self.tick)
            # The clock lives on self.env; a bare ``env`` is not defined in
            # this module.
            finished = self.next_epoch(self.env.now)

            # Partition once and invert the node -> label mapping.
            partition = louvain_partition(finished, weight='duration')
            members = defaultdict(set)
            for node, label in partition.items():
                members[label].add(node)

            for group in members.values():
                group = frozenset(group)
                for node in group:
                    self.community[node] = group


def none(**kwargs):
    '''Placeholder for the 'none' entry of ``types``: builds nothing.

    Accepts arbitrary keyword arguments, ignores them all, and always
    returns None.
    '''
    return None


# Lookup table from a community type name to the callable stored for it.
types = dict(
    kclique=KCliqueCommunity,
    louvain=LouvainCommunity,
    none=none,
)