"""Example to run a batch of simlations on SHED data."""

__authors__ = 'Jarrod Pas <j.pas@usask.ca>, Hunter McConnell <hunter.mcconnell@usask.ca>'

import os
from random import randint
import sys
import csv
from argparse import ArgumentParser
from collections import namedtuple
from multiprocessing import Pool
from os import path
from pprint import pprint

from pydtnsim import Network, RandomTraffic, Node, EpidemicNode, CSVTrace
from pydtnsim.community import BubbleKCliqueNode, BubbleLouvainNode
from pydtnsim.community import HCBFKCliqueNode, HCBFLouvainNode


# One unit of work for the process pool: the trace to replay, the node class
# to instantiate for every node, and the seed handed to the traffic generator.
Simulation = namedtuple('Simulation', 'trace node_type seed')


def run_simulation(simulation):
    """Run a single simulation and return its summary statistics.

    Args:
        simulation: A ``Simulation`` tuple. ``trace`` is a directory that
            holds ``contact.csv`` and ``metadata.json``, ``node_type`` is the
            node class instantiated for every node in the trace, and ``seed``
            is passed to the traffic generator.

    Returns:
        A dict with the keys ``trace``, ``node_type`` (the class name) and
        ``seed``, updated with everything in ``network.stats_summary``.
    """
    seed = simulation.seed
    node_type = simulation.node_type

    # Named contact_csv (not csv) so the imported csv module is not shadowed.
    contact_csv = path.join(simulation.trace, 'contact.csv')
    metadata = path.join(simulation.trace, 'metadata.json')
    trace = CSVTrace(contact_csv, metadata=metadata)

    epoch = 7*24*60*60  # 7 days

    # NOTE(review): node_options is built once, so every node receives the
    # very same 'context' dict object -- presumably intentional; confirm.
    node_options = {
        'tick_rate': 5 * 60,  # 5 mins
        'epoch': epoch,
        'k': 3,
        'context': {}
    }
    nodes = {
        node_id: node_type(**node_options)
        for node_id in range(trace.nodes)
    }

    traffic_options = {
        'seed': seed,
        'start': epoch,
        # 60 * 60 = 3600, i.e. one hour in the units used for tick_rate.
        # NOTE(review): the previous comment claimed "1 packet every 1 mins";
        # confirm which rate is actually intended.
        'step': 60 * 60,
    }
    traffic = RandomTraffic(nodes, **traffic_options)

    network = Network(nodes, traffic=traffic, trace=trace)
    network.run()

    stats = {
        'trace': simulation.trace,
        'node_type': node_type.__name__,
        'seed': seed,
    }
    stats.update(network.stats_summary)

    # return stats because we can't pickle the network as it is a generator.
    return stats


def main(args):
    """Run every node type for each seed and dump the statistics to a CSV.

    Args:
        args: Mapping of parsed command-line options with the keys
            ``shed`` (trace directory), ``pretty`` (log with ``pprint``),
            ``quiet`` (suppress per-simulation logging), ``batch`` (number of
            random seeds to draw when greater than 1) and ``seeds`` (seeds
            used when not in batch mode).

    Returns:
        None. Results are written to the first ``results<i>.csv`` in the
        current directory that does not exist yet.
    """
    log = pprint if args['pretty'] else print

    trace = args['shed']
    node_types = [
        Node,               # direct delivery
        EpidemicNode,
        BubbleKCliqueNode,
        HCBFKCliqueNode,
        BubbleLouvainNode,
        HCBFLouvainNode,
    ]

    if args['batch'] > 1:
        # batch mode with random seeds
        seeds = [randint(0, 500) for _ in range(args['batch'])]
    else:
        # seed mode with inputted seeds
        seeds = args['seeds']

    # Every seed is simulated once per node type.
    simulations = [
        Simulation(trace=trace, node_type=node_type, seed=seed)
        for seed in seeds
        for node_type in node_types
    ]

    # Group the statistics by node type name as they arrive.
    results = {}

    # The context manager shuts the worker processes down on exit; all
    # results have been consumed by the time the block is left.
    with Pool() as pool:
        for stats in pool.imap_unordered(run_simulation, simulations):
            if not args['quiet']:
                log(stats)
            results.setdefault(stats['node_type'], []).append(stats)

    # find unused filename
    i = 0
    while os.path.exists(f"results{i}.csv"):
        i += 1

    # dump sim stats in csv; each node type gets its own header row followed
    # by its result rows, all in the same file.
    with open(f"results{i}.csv", 'w', newline='') as results_file:
        for rows in results.values():
            writer = csv.DictWriter(results_file, fieldnames=list(rows[0]))
            writer.writeheader()
            writer.writerows(rows)


def parse_args(args):
    """Parse command-line arguments.

    Args:
        args: List of argument strings, without the program name.

    Returns:
        A dict with the keys ``shed``, ``pretty``, ``quiet``, ``batch`` and
        ``seeds``. ``seeds`` is the list of seeds given with ``--seeds``/``-s``
        or ``[None]`` when none were given.
    """
    parser = ArgumentParser()

    parser.add_argument('shed', help='directory containing the SHED trace')
    parser.add_argument('--pretty', action='store_true',
                        help='pretty-print the statistics of each simulation')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='do not print per-simulation statistics')
    parser.add_argument('--batch', '-b',
                        metavar='BATCH', type=int, default=1,
                        help='number of random seeds to run (if more than 1)')
    # default=None rather than [None]: with action='append' the given values
    # are appended to a list default, which would leave a stray None seed in
    # front of the seeds supplied on the command line.
    parser.add_argument('--seeds', '-s',
                        action='append', metavar='SEED', type=int,
                        default=None,
                        help='seed to simulate; may be given multiple times')

    parsed = vars(parser.parse_args(args))
    if parsed['seeds'] is None:
        # No seed given: fall back to a single run with seed None.
        parsed['seeds'] = [None]
    return parsed


if __name__ == '__main__':
    # Parse the command line (without the program name), run the batch and
    # hand main()'s return value to sys.exit as the exit status.
    cli_args = parse_args(sys.argv[1:])
    sys.exit(main(cli_args))