"""Example using the gui to run pydtndim"""

__author__ = "Hunter McConnell <hunter.mcconnell@usask.ca>"

from argparse import ArgumentParser
from collections import namedtuple
from os import path
from multiprocessing import Pool
from typing import ChainMap
from numpy import mean

import yaml
import sys
import csv
import matplotlib.pyplot as plt
import time

from pydtnsim import Network, RandomTraffic, Node, EpidemicNode, CSVTrace
from pydtnsim.community import BubbleKCliqueNode, BubbleLouvainNode
from pydtnsim.community import HCBFKCliqueNode, HCBFLouvainNode
from gui2 import Gui

# Description of one simulation run (must stay picklable: it is sent to pool workers).
#   tag             -- index of the independent-variable value this run belongs to
#   trace           -- directory containing contact.csv and metadata.json
#   node_type       -- node class instantiated once per trace node
#   seed            -- seed label, copied into the output stats
#   node_options    -- keyword arguments for node_type(...)
#   traffic_options -- keyword arguments for RandomTraffic(...)
Simulation = namedtuple(
    "Simulation",
    "tag trace node_type seed node_options traffic_options",
)

def run_simulation(simulation):
    """Run one simulation and return its summary statistics.

    The contact trace is read from ``<simulation.trace>/contact.csv`` with
    ``metadata.json`` from the same directory. One ``simulation.node_type``
    node is created per trace node, built from ``simulation.node_options``.
    The nodes are driven by ``RandomTraffic`` using
    ``simulation.traffic_options``, and the network is run to completion.

    Returns a plain dict holding the identifying fields ``trace``,
    ``node_type`` (the class name), ``seed`` and ``tag``, merged with
    ``network.stats_summary``. Key order matters: it becomes the CSV column
    order.
    """
    # Local names deliberately avoid ``csv``, which would shadow the module.
    contact_path = path.join(simulation.trace, "contact.csv")
    metadata_path = path.join(simulation.trace, "metadata.json")
    trace = CSVTrace(contact_path, metadata=metadata_path)

    # NOTE(review): every node receives the same option values, so any
    # mutable value in node_options (e.g. the "context" dict main() adds) is
    # shared by all nodes of this simulation -- presumably intentional; confirm.
    nodes = {
        node_id: simulation.node_type(**simulation.node_options)
        for node_id in range(trace.nodes)
    }

    # NOTE(review): simulation.seed is only recorded in the stats below; it is
    # not passed to RandomTraffic or the nodes here. Confirm whether it should be.
    traffic = RandomTraffic(nodes, **simulation.traffic_options)

    network = Network(nodes, trace, traffic)
    network.run()

    stats = {
        "trace": simulation.trace,
        "node_type": simulation.node_type.__name__,
        "seed": simulation.seed,
        "tag": simulation.tag,
    }
    stats.update(network.stats_summary)

    # Return stats because we can't pickle the network as it is a generator.
    return stats

def _build_simulations(config):
    """Expand *config* into the list of simulations to run.

    One base simulation is built per (seed, node type), in that nesting
    order. How the base list is expanded depends on
    ``GraphingOptions.choice``:

    - It names a node option or a traffic option: the input values are
      sorted, and the base list is repeated once per value.
    - It is ``"contact_dir"``: the base list is repeated once per value, in
      the order given.
    - Otherwise: the base list is used as-is, and every simulation has tag 0.

    When the list is repeated, each copy is tagged with its value's index.

    Returns ``(simulations, x_values)`` where ``x_values[tag]`` is the
    independent-variable value used by simulations carrying that tag.
    """
    trace = config["SimOptions"]["contact_dir"]
    seeds = config["SimOptions"]["seeds"]
    node_options = ChainMap(config["NodeOptions"], {"context": {}})
    traffic_options = ChainMap(
        config["TrafficOptions"], {"start": config["NodeOptions"]["epoch"]}
    )

    base_sims = [
        Simulation(0, trace, node_type, seed, node_options, traffic_options)
        for seed in seeds
        for node_type in config["NodeChoices"]["Nodes"]
    ]

    choice = config["GraphingOptions"]["choice"]
    x_values = config["GraphingOptions"]["input"]

    # Each tag gets its own ChainMap layer via new_child(). Assigning into the
    # shared map instead (map[choice] = option) would leave every Simulation
    # referencing the same object, so all of them would see the *last* value
    # by the time the pool pickles them. ChainMap writes would also land in
    # config["NodeOptions"] / config["TrafficOptions"] via maps[0].
    simulations = []
    if choice in config["NodeOptions"]:
        x_values = sorted(x_values)
        for tag, option in enumerate(x_values):
            options = node_options.new_child({choice: option})
            simulations.extend(
                base._replace(tag=tag, node_options=options) for base in base_sims
            )
    elif choice in config["TrafficOptions"]:
        x_values = sorted(x_values)
        for tag, option in enumerate(x_values):
            options = traffic_options.new_child({choice: option})
            simulations.extend(
                base._replace(tag=tag, traffic_options=options) for base in base_sims
            )
    elif choice == "contact_dir":
        # Traces are not sortable in a meaningful way; trust the listed order.
        for tag, trace_dir in enumerate(x_values):
            simulations.extend(
                base._replace(tag=tag, trace=trace_dir) for base in base_sims
            )
    else:
        simulations = base_sims

    return simulations, x_values


def _run_simulations(simulations):
    """Run *simulations* on a process pool and group the stats by node type name."""
    results = {}

    print("sim running, please wait :)")
    start = time.time()
    # The context manager guarantees the worker processes are cleaned up; all
    # results are consumed before the block exits (``__exit__`` terminates).
    with Pool() as pool:
        for stats in pool.imap_unordered(run_simulation, simulations):
            results.setdefault(stats["node_type"], []).append(stats)
    end = time.time()
    print("sim runtime:", end-start)

    return results


def _write_results_csv(results, filename="testing.csv"):
    """Write all stats to *filename*: a header plus rows for each node type in turn."""
    with open(filename, "w", newline="") as results_file:
        for node_stats in results.values():
            # Columns come from the first stats dict of this node type.
            # extrasaction="ignore" drops keys that are not among those columns
            # instead of raising ValueError.
            writer = csv.DictWriter(
                results_file,
                fieldnames=node_stats[0].keys(),
                extrasaction="ignore",
            )
            writer.writeheader()
            writer.writerows(node_stats)


def _plot_results(results, x_values, choice, dependant):
    """Plot the per-tag mean of each dependent statistic, one subplot per statistic."""
    # squeeze=False always yields a 2-D array, so one subplot and several
    # subplots can share the same code path.
    _, axes = plt.subplots(len(dependant), squeeze=False)

    for node, stats in results.items():
        for ax, dep in zip(axes[:, 0], dependant):
            # Batch the values by tag (the index into x_values), then average.
            batches = [[] for _ in range(len(x_values))]
            for stat in stats:
                batches[stat["tag"]].append(stat[dep])

            y = [mean(batch) for batch in batches]
            ax.plot(x_values, y, "o-", label=node)
            ax.set_ylabel(dep)

    plt.xlabel(choice)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
    plt.show()


def main(args):
    """Run a simulation for each seed, for each independent variable, graph the results.

    *args* is the dict returned by ``parse_args``: ``no_gui`` (bool) and
    ``config`` (path or None). Without ``no_gui``, the GUI collects the
    configuration, pre-filled from ``config`` when one is given. Stats are
    written to ``testing.csv`` and then plotted.
    """
    if args["no_gui"]:
        with open(args["config"], "r", newline="") as f:
            # NOTE(review): node types from the config are called and have
            # ``__name__`` read, which plain YAML (safe_load) cannot produce.
            # That is presumably why FullLoader is used; only load trusted files.
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        # Put up the GUI for user input.
        gui = Gui(None)
        gui.title("Pydtnsim")

        if args["config"]:
            gui.load_config(args["config"])

        gui.mainloop()
        config = gui.param

    simulations, x_values = _build_simulations(config)
    results = _run_simulations(simulations)
    _write_results_csv(results)

    graphing = config["GraphingOptions"]
    _plot_results(results, x_values, graphing["choice"], graphing["dependant"])





    
def parse_args(args):
    """Parse command-line arguments into a plain dict.

    Returns a dict with keys ``no_gui`` (bool) and ``config`` (str or None).
    ``--no_gui`` without ``--config`` is rejected with a usage error (exit
    status 2), because there would be no configuration to read.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--no_gui",
        "-n",
        action="store_true",
        help="skip gui and read from config file provided",
    )
    parser.add_argument(
        "--config",
        "-c",
        help="config file to read from",
    )
    parsed = parser.parse_args(args)
    if parsed.no_gui and not parsed.config:
        parser.error("--no_gui requires --config")
    return vars(parsed)

if __name__ == "__main__":
    sys.exit(main(parse_args(sys.argv[1:])))