"""Example script that runs a batch of simulations on SHED data."""

__author__ = 'Jarrod Pas <j.pas@usask.ca>'

import sys
from argparse import ArgumentParser
from collections import namedtuple
from multiprocessing import Pool

from pydtn import Network, RandomTraffic, Node, EpidemicNode
from pydtn.community import BubbleNode, HCBFNode
from pydtn.community import LouvainCommunity, KCliqueCommunity
from pydtn.shed import ShedTrace


# One unit of work for the process pool: the trace handed to ShedTrace, the
# node class instantiated for every node, and the seed for RandomTraffic.
Simulation = namedtuple('Simulation', 'trace node_type seed')


class LouvainNode(Node):
    """Node whose ``community`` option is a ``LouvainCommunity``.

    The community is built from the ``epoch`` option; any ``community``
    already present in the options is replaced. All options are then
    forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, **options):
        community = LouvainCommunity(options['epoch'])
        super().__init__(**dict(options, community=community))


class LouvainHCBFNode(LouvainNode, HCBFNode):
    """``HCBFNode`` combined with ``LouvainNode``'s community setup.

    ``LouvainNode`` is listed first, so its ``__init__`` runs first and
    installs the ``LouvainCommunity`` before delegating along the MRO.
    """


class LouvainBubbleNode(LouvainNode, BubbleNode):
    """``BubbleNode`` combined with ``LouvainNode``'s community setup.

    ``LouvainNode`` is listed first, so its ``__init__`` runs first and
    installs the ``LouvainCommunity`` before delegating along the MRO.
    """


class KCliqueNode(Node):
    """Node whose ``community`` option is a ``KCliqueCommunity``.

    The community is built from the ``epoch`` and ``k`` options; any
    ``community`` already present in the options is replaced. All options
    are then forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, **options):
        community = KCliqueCommunity(options['epoch'], options['k'])
        super().__init__(**dict(options, community=community))


class KCliqueHCBFNode(KCliqueNode, HCBFNode):
    """``HCBFNode`` combined with ``KCliqueNode``'s community setup.

    ``KCliqueNode`` is listed first, so its ``__init__`` runs first and
    installs the ``KCliqueCommunity`` before delegating along the MRO.
    """


class KCliqueBubbleNode(KCliqueNode, BubbleNode):
    """``BubbleNode`` combined with ``KCliqueNode``'s community setup.

    ``KCliqueNode`` is listed first, so its ``__init__`` runs first and
    installs the ``KCliqueCommunity`` before delegating along the MRO.
    """


def run_simulation(simulation):
    """Run a single simulation and return its summary statistics.

    Args:
        simulation: a ``Simulation`` tuple. ``trace`` is handed to
            ``ShedTrace``, ``node_type`` is the class instantiated for every
            node, and ``seed`` is passed to ``RandomTraffic``.

    Returns:
        dict with ``'trace'`` (the trace's ``path``), ``'node_type'`` (the
        class name) and ``'seed'``, updated with ``network.stats_summary``.
    """
    node_type = simulation.node_type
    seed = simulation.seed
    week = 7 * 24 * 60 * 60  # epoch length in seconds: 7 days

    trace = ShedTrace(simulation.trace)

    # Every node receives the same options: a 5-minute tick rate, the epoch,
    # and ``k`` (read by ``KCliqueNode``).
    nodes = {}
    for node_id in range(trace.nodes):
        nodes[node_id] = node_type(tick_rate=5 * 60, epoch=week, k=3)

    # Traffic starts one epoch in, at 1 packet every 1 min.
    traffic = RandomTraffic(nodes, seed=seed, start=week, step=1 * 60)

    network = Network(nodes, traffic=traffic, trace=trace)
    network.run()

    # Only a plain dict goes back to the parent process: the network itself
    # can't be pickled as it is a generator.
    stats = dict(trace=trace.path, node_type=node_type.__name__, seed=seed)
    stats.update(network.stats_summary)
    return stats


def main(args):
    """Run every node type on the SHED trace once for each seed in *args*.

    Args:
        args: dict as produced by ``parse_args``. ``'shed'`` is the trace
            stored in each ``Simulation``, and ``'seeds'`` is an iterable of
            seeds.

    One ``Simulation`` is queued per (seed, node type) pair and executed in
    a process pool. Each result dict is printed as soon as it completes, so
    the output order is not guaranteed.
    """
    trace = args['shed']
    node_types = [
        Node,
        EpidemicNode,
        KCliqueBubbleNode,
        KCliqueHCBFNode,
        LouvainBubbleNode,
        LouvainHCBFNode,
    ]

    # Seed-major order: all node types for the first seed, then the next.
    simulations = [
        Simulation(trace=trace, node_type=node_type, seed=seed)
        for seed in args['seeds']
        for node_type in node_types
    ]

    # Using the pool as a context manager guarantees the worker processes
    # are torn down once all results are consumed (or if printing raises),
    # instead of being left unclosed.
    with Pool() as pool:
        for stats in pool.imap_unordered(run_simulation, simulations):
            print(stats)


def parse_args(args):
    """Parse command-line *args* into a plain dict.

    Returns:
        dict with ``'shed'`` (the positional trace argument) and ``'seeds'``
        (list of ints given to ``--seeds``/``-s``, or ``[None]`` when the
        flag is omitted).
    """
    parser = ArgumentParser()
    parser.add_argument('shed')
    parser.add_argument(
        '--seeds', '-s',
        metavar='SEED',
        type=int,
        nargs='+',
        default=[None],
    )
    return vars(parser.parse_args(args))


if __name__ == '__main__':
    # sys.exit rather than the site-injected exit() builtin, which is meant
    # for the interactive shell and is missing when Python runs with -S.
    sys.exit(main(parse_args(sys.argv[1:])))