import csv
import sys
from argparse import ArgumentParser
from collections import defaultdict, namedtuple
from itertools import groupby, count
from multiprocessing import Pool

from pydtn import Network, random_traffic, Node, EpidemicNode, Contact
from pydtn.community import BubbleNode, HCBFNode, LouvainCommunity


class ShedTrace:
    """Contact trace built from a slot-based CSV file (SHED-style).

    The first row of the file is treated as a header and skipped. Every
    other row must have exactly six columns. Only the 2nd and 4th columns
    (the two device identifiers) and the 6th column (an integer slot index)
    are used.

    Device pairs are unordered: ``(a, b)`` and ``(b, a)`` are merged. Each
    run of consecutive slots in which a pair was seen becomes one contact.
    The contact starts at ``first_slot * slot_size`` and ends at
    ``(last_slot + 1) * slot_size``.

    Attributes:
        path: the CSV file the trace was read from.
        slot_size: duration of one slot, in the same unit as the contact
            timestamps (default ``300``; presumably seconds, confirm against
            the dataset).
        contacts: list of ``Contact`` events, sorted with ``list.sort()``.
            Each contact contributes two events: one with last field
            ``True`` at its start and one with last field ``False`` at its
            end.
        nodes: number of distinct devices. Devices are renumbered
            ``0 .. nodes - 1`` in order of first appearance.
    """

    def __init__(self, path, slot_size=300):
        self.path = path
        self.slot_size = slot_size

        pairs = self._read_pair_slots(path)

        ids = count()
        node_ids = {}
        self.contacts = []
        for (source, target), slots in pairs.items():
            # Assign dense integer ids: source before target, first come first served.
            if source not in node_ids:
                node_ids[source] = next(ids)
            if target not in node_ids:
                node_ids[target] = next(ids)
            a, b = node_ids[source], node_ids[target]

            for first, last in self._runs(sorted(slots)):
                start = first * slot_size
                end = (last + 1) * slot_size  # contact lasts to the end of its last slot
                self.contacts.append(Contact(start, a, b, True))
                self.contacts.append(Contact(end, a, b, False))

        self.contacts.sort()
        self.nodes = len(node_ids)

    @staticmethod
    def _read_pair_slots(path):
        """Map each unordered device pair ``(min, max)`` to the set of its slots."""
        pairs = defaultdict(set)
        # newline='' is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(path, newline='') as slots:
            reader = csv.reader(slots)
            next(reader, None)  # skip the header; tolerate a completely empty file
            for _, source, _, target, _, slot in reader:
                pair = min(source, target), max(source, target)
                pairs[pair].add(int(slot))
        return pairs

    @staticmethod
    def _runs(slots):
        """Yield ``(first, last)`` for each run of consecutive ints in sorted *slots*."""
        # index - value stays constant inside a run of consecutive integers:
        # [1, 2, 3, 6, 7, 9] -> [-1, -1, -1, -3, -3, -4]
        for _, group in groupby(enumerate(slots), lambda p: p[0] - p[1]):
            run = [slot for _, slot in group]
            yield run[0], run[-1]

    def __iter__(self):
        """Iterate over the sorted contact events."""
        return iter(self.contacts)


# One simulation job: which trace to replay, which node class to use and
# which random seed to generate traffic with.
Task = namedtuple('Task', ('trace', 'node_type', 'seed'))


def run_task(task):
    """Run one simulation described by *task* and return its summary stats.

    Args:
        task: a ``Task``. ``task.trace`` must provide ``.nodes`` (node count)
            and ``.path`` and must be accepted by ``Network`` as a trace.
            ``task.node_type`` is the node class to instantiate.
            ``task.seed`` is passed to the traffic generator.

    Returns:
        A dict with ``'trace'``, ``'node_type'`` (class name) and ``'seed'``,
        updated with ``network.stats_summary``.
    """
    trace = task.trace
    node_type = task.node_type
    seed = task.seed

    epoch = 7 * 24 * 60 * 60  # 7 days

    node_options = {
        'tick_rate': 5 * 60,  # 5 mins
        # NOTE(review): a single LouvainCommunity instance is shared by every
        # node in this run; presumably intentional (one community detector
        # per network). Confirm against pydtn.community.
        'community': LouvainCommunity(epoch),
    }

    traffic_speed = 30 * 60  # 1 packet every 30 mins

    nodes = {
        node_id: node_type(**node_options)
        for node_id in range(trace.nodes)
    }

    # Traffic only starts after the first epoch; presumably this gives the
    # community detection one epoch of contacts to learn from.
    traffic = random_traffic(nodes,
                             start=epoch,
                             speed=traffic_speed,
                             seed=seed)

    network = Network(nodes, traffic=traffic, trace=trace)
    network.run()

    stats = {
        'trace': trace.path,
        'node_type': node_type.__name__,
        'seed': seed,
    }
    stats.update(network.stats_summary)
    return stats


def main(args):
    """Simulate every (seed, node type) combination on one trace in parallel.

    Args:
        args: dict as returned by ``parse_args``. Reads ``'shed'`` (path of
            the trace CSV) and ``'seeds'`` (list of seeds, ``[None]`` by
            default).

    Each task's stats dict is printed as soon as that task finishes, so the
    output order is not deterministic.
    """
    trace = ShedTrace(args['shed'])
    node_types = (Node, EpidemicNode, BubbleNode, HCBFNode)
    tasks = [
        Task(trace=trace, node_type=node_type, seed=seed)
        for seed in args['seeds']
        for node_type in node_types
    ]

    # The context manager cleans up the worker processes (Pool.__exit__
    # calls terminate()). That is safe because every result is consumed
    # before the block exits.
    with Pool() as pool:
        for stats in pool.imap_unordered(run_task, tasks):
            print(stats)


def parse_args(args):
    """Parse command-line *args* and return them as a plain dict.

    Keys: ``'shed'`` (path of the trace CSV) and ``'seeds'`` (list of ints,
    or ``[None]`` when ``--seeds`` is not given).
    """
    cli = ArgumentParser()
    cli.add_argument('shed')
    cli.add_argument(
        '--seeds', '-s',
        metavar='SEED',
        type=int,
        nargs='+',
        default=[None],
    )
    return vars(cli.parse_args(args))


if __name__ == '__main__':
    # sys.exit rather than the builtin exit(): the latter is injected by the
    # site module for interactive use and is absent under `python -S`.
    sys.exit(main(parse_args(sys.argv[1:])))