from collections import defaultdict
from itertools import product

import networkx as nx
from kids.cache import cache

from pydyton.core import TickProcess

class EpochCommunity(TickProcess):
    '''Record link activity per epoch and derive community-based metrics.

    ``graph`` holds the links seen during the current epoch. Each edge
    carries two attributes:

    * ``start``: the time the link last came up, or ``-1`` while it is down;
    * ``duration``: the length of the last completed contact on that link.

    ``old_graph`` is the graph of the previous epoch (``None`` until
    :meth:`next_epoch` has been called once); every ``get_*`` metric is
    computed on it. ``community`` maps a node to the frozenset of nodes
    forming its community; this class only ever fills it with singleton
    defaults (see :meth:`__getitem__`), so real communities are presumably
    assigned by subclasses or callers.
    '''

    def __init__(self, epoch, **kwargs):
        '''Create an empty tracker.

        ``epoch`` is forwarded to ``TickProcess``. Extra keyword arguments
        are accepted and ignored.
        '''
        super().__init__(epoch)

        self.graph = nx.Graph()
        self.old_graph = None
        self.community = {}

    def set_link(self, a, b, state, now):
        '''Register that the link between ``a`` and ``b`` went up or down.

        :param a: first endpoint of the link
        :param b: second endpoint of the link
        :param state: truthy when the link comes up, falsy when it goes down
        :param now: current time, in the same unit as the stored ``start``
        '''
        if not self.graph.has_edge(a, b):
            # Attributes are passed as keywords: the positional attribute
            # dict is rejected by networkx >= 2.0. add_edge also creates
            # the endpoints when they are missing.
            self.graph.add_edge(a, b, start=-1)

        edge = self.graph[a][b]
        if state:
            edge['start'] = now
            edge.setdefault('duration', 0)
        elif edge['start'] > -1:
            # NOTE(review): this overwrites the duration of any earlier
            # contact of the same epoch instead of accumulating it --
            # confirm that this is the intended behaviour.
            edge['duration'] = now - edge['start']
            edge['start'] = -1
        else:
            # A "down" event on a link that is not up must not be measured
            # against the -1 sentinel (that would record now + 1); only
            # make sure the attribute read by the metrics exists.
            edge.setdefault('duration', 0)

    def next_epoch(self, now):
        '''Close the current epoch and start a new one.

        Communities and cached metrics are reset. Every link that is still
        up is closed at ``now`` in the finished graph and immediately
        re-opened at ``now`` in the fresh graph.

        :param now: time at which the epoch changes
        :returns: the graph of the epoch that just ended (also kept in
            ``old_graph``)
        '''
        self.community = {}
        self.cache_clear()

        still_up = [
            (a, b)
            for a, b, start in self.graph.edges(data='start')
            if start > -1
        ]
        for a, b in still_up:
            self.set_link(a, b, False, now)

        self.old_graph = self.graph
        self.graph = nx.Graph()
        for a, b in still_up:
            self.set_link(a, b, True, now)

        return self.old_graph

    def __getitem__(self, node):
        '''Return the community of ``node``.

        A node without a known community is given (and remembered with)
        a community that contains only itself.
        '''
        if node not in self.community:
            self.community[node] = frozenset([node])
        return self.community[node]

    def cache_clear(self):
        '''Drop the cached results of every ``get_*`` metric.'''
        self.get_ld.cache_clear()
        self.get_lp.cache_clear()
        self.get_ui.cache_clear()
        self.get_gp.cache_clear()
        self.get_ncf.cache_clear()
        self.get_cbc.cache_clear()

    @cache
    def get_ld(self, x):
        '''durations of the links of a node inside its community

        Returns the list of strictly positive durations, taken from the
        previous epoch, of the links between ``x`` and the members of its
        own community. The list is empty when ``x`` was not seen during
        the previous epoch or when no epoch has been completed yet.
        '''
        g = self.old_graph
        if g is None or x not in g:
            return []
        neighbours = g[x]
        return [
            neighbours[y]['duration']
            for y in self[x]
            if y in neighbours and neighbours[y]['duration'] > 0
        ]

    @cache
    def get_lp(self, x):
        '''local popularity of a node'''
        return sum(self.get_ld(x))

    @cache
    def get_gp(self, node):
        '''global popularity of a node

        Sum of the durations, in the previous epoch, of the links between
        ``node`` and nodes outside its community; 0 when ``node`` was not
        seen or when no epoch has been completed yet.
        '''
        g = self.old_graph
        if g is None or node not in g:
            return 0
        community = self[node]
        return sum(
            edge['duration']
            for other, edge in g[node].items()
            if other not in community
        )

    @cache
    def get_ui(self, x):
        '''unique interactions with a node'''
        return len(self.get_ld(x))

    @cache
    def get_cbc(self, c_x, c_y):
        '''total link duration between two communities

        ``c_x`` and ``c_y`` are collections of nodes (they must be hashable
        to be cached, e.g. frozensets). Returns infinity when both are
        equal, otherwise the sum of the previous-epoch durations of the
        links joining a member of ``c_x`` to a member of ``c_y``; 0 when no
        epoch has been completed yet.
        '''
        if c_x == c_y:
            return float('inf')
        g = self.old_graph
        if g is None:
            return 0
        return sum(
            g[x][y]['duration']
            for x, y in product(c_x, c_y)
            if x in g and y in g[x]
        )

    @cache
    def get_ncf(self, x, c_y):
        '''total link duration between a node and the members of ``c_y``

        Returns 0 unless ``x`` belongs to ``c_y`` and was seen during the
        previous epoch; otherwise the sum of the previous-epoch durations
        of the links between ``x`` and the members of ``c_y``.
        '''
        g = self.old_graph
        if g is None or x not in c_y or x not in g:
            return 0
        neighbours = g[x]
        return sum(
            neighbours[y]['duration']
            for y in c_y
            if y in neighbours
        )