from argparse import ArgumentParser
import sys
import random

import simpy
import yaml

from pydyton import Network, NodeFactory
from pydyton.traces import types as traces
from pydyton.routers import types as routers
from pydyton.communities import types as communities

def parse_args(args):
    """Parse the simulation's command-line arguments.

    The ``--*-args`` options each take one or more ``key=value`` items.
    Every item is turned into a ``key: value`` line and the lines are
    parsed together as one YAML document, so values get YAML typing
    (``n=3`` gives an int, ``xs=[1, 2]`` gives a list).  Only the first
    ``=`` of an item separates key from value; any further ``=`` stays
    part of the value.

    Args:
        args: Command-line arguments, without the program name.

    Returns:
        The ``argparse.Namespace`` with ``trace_args``, ``router_args``,
        ``community_args`` and ``node_args`` replaced by dicts (empty when
        the option was not given).

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including
            ``--*-args`` items that do not form a mapping with string keys.
    """
    parser = ArgumentParser()

    parser.add_argument('--seed', '-s', metavar='seed', type=int,
                        default=42)

    # Iterating each registry supplies the accepted type names
    # (presumably mappings of name -> implementation; main() indexes them).
    parser.add_argument('--trace', '-t', metavar='type', default='random',
                        choices=list(traces))
    parser.add_argument('--trace-args', '-ta', metavar='arg', nargs='+')

    parser.add_argument('--router', '-r', metavar='type', default='direct',
                        choices=list(routers))
    parser.add_argument('--router-args', '-ra', metavar='arg', nargs='+')

    parser.add_argument('--community', '-c', metavar='type', default='none',
                        choices=list(communities))
    parser.add_argument('--community-args', '-ca', metavar='arg', nargs='+')

    parser.add_argument('--node-args', '-na', metavar='arg', nargs='+')

    def list_to_args(option, items):
        """Convert the ``key=value`` items given for *option* into a dict."""
        if items is None:
            return {}

        # Replace only the first '=' so values containing '=' survive intact.
        document = '\n'.join(item.replace('=', ': ', 1) for item in items)
        try:
            parsed = yaml.safe_load(document)
        except yaml.YAMLError as err:
            parser.error('invalid {} value: {}'.format(option, err))

        # The result is later expanded with ** as keyword arguments, so it
        # must be a mapping whose keys are all strings.
        if not isinstance(parsed, dict) or \
                not all(isinstance(key, str) for key in parsed):
            parser.error('{} expects key=value items, got: {}'.format(
                option, ' '.join(items)))
        return parsed

    args = parser.parse_args(args)

    args.trace_args = list_to_args('--trace-args', args.trace_args)
    args.router_args = list_to_args('--router-args', args.router_args)
    args.community_args = list_to_args('--community-args',
                                       args.community_args)
    args.node_args = list_to_args('--node-args', args.node_args)

    return args


def main(args):
    """Run one simulation configured from the command line.

    Builds the trace, router, community and network from the parsed
    arguments, advances the simpy environment to ``trace.duration`` in 50
    equal steps while drawing a progress bar, then prints how many of
    ``network.packets`` have a truthy ``recieved`` attribute.

    Args:
        args: Command-line arguments, without the program name.

    Returns:
        The process exit status: 0 both on completion and when the run is
        interrupted with Ctrl-C.
    """
    args = parse_args(args)
    print(args)

    random.seed(args.seed)
    env = simpy.Environment()

    trace = traces[args.trace](**args.trace_args)
    # NOTE(review): args.router_args is parsed but never passed anywhere;
    # the router is handed to NodeFactory un-instantiated. Confirm whether
    # NodeFactory should receive the router arguments.
    router = routers[args.router]
    community = communities[args.community](**args.community_args)

    node_factory = NodeFactory(router, **args.node_args)

    network = Network(env,
                      node_factory=node_factory,
                      community=community,
                      trace=trace)

    steps = 50  # number of progress-bar increments
    try:
        for step in range(1, steps + 1):
            env.run(until=step / steps * trace.duration)
            bar = '=' * step + ' ' * (steps - step)
            # No newline is written, so flush explicitly or the bar may sit
            # in the stdout buffer until the run is over.
            print('\rRunning [{}] '.format(bar), end='', flush=True)
        print(' Done!')
        # 'recieved' (sic) is the attribute name as spelled on the packets.
        print(sum(1 for p in network.packets if p.recieved))
        return 0
    except KeyboardInterrupt:
        print('\nBye!')
        return 0

if __name__ == '__main__':
    # Use sys.exit rather than the bare exit() helper: the latter is injected
    # by the site module and is missing under `python -S` or frozen builds.
    sys.exit(main(sys.argv[1:]))