from argparse import ArgumentParser
from datetime import datetime
import random
import sys

import simpy
import yaml

from pydtn import Network, NodeFactory
from pydtn.traces import types as traces
from pydtn.routers import types as routers
from pydtn.communities import types as communities

def parse_args(args):
    """Parse the simulator's command line.

    Each ``--*-args`` option takes one or more ``key=value`` items. Only the
    first ``=`` separates key from value. The value is parsed with
    ``yaml.safe_load``, so ``n=3`` gives an int, ``ids=[1, 2]`` a list and
    ``path=a.csv`` a str, while ``url=http://h?a=b`` keeps its inner ``=``.
    The raw item list is then replaced by a ``dict``, which is empty when the
    option was not given.

    Args:
        args: Argument strings, without the program name.

    Returns:
        The ``argparse.Namespace``, with ``trace_args``, ``router_args``,
        ``community_args``, ``node_args`` and ``packet_args`` turned into
        dicts.

    Exits via ``parser.error`` (status 2) when an item is not ``key=value``
    or its value is not valid YAML.
    """
    parser = ArgumentParser()

    parser.add_argument('--seed', '-s', metavar='seed', type=int,
                        default=42)

    parser.add_argument('--trace', '-t', metavar='type', default='random',
                        choices=list(traces))
    parser.add_argument('--trace-args', '-ta', metavar='arg', nargs='+')

    parser.add_argument('--router', '-r', metavar='type', default='direct',
                        choices=list(routers))
    parser.add_argument('--router-args', '-ra', metavar='arg', nargs='+')

    parser.add_argument('--community', '-c', metavar='type', default='none',
                        choices=list(communities))
    parser.add_argument('--community-args', '-ca', metavar='arg', nargs='+')

    parser.add_argument('--node-args', '-na', metavar='arg', nargs='+')

    parser.add_argument('--packet-args', '-pa', metavar='arg', nargs='+')

    def list_to_args(items, option):
        """Turn ``['key=value', ...]`` into ``{key: parsed value}``."""
        if items is None:
            return {}

        parsed = {}
        for item in items:
            # Split on the FIRST '=' only so values may themselves contain '='.
            key, sep, value = item.partition('=')
            key = key.strip()
            if not sep or not key:
                parser.error(f'{option}: expected key=value, got {item!r}')
            try:
                parsed[key] = yaml.safe_load(value)
            except yaml.YAMLError as err:
                parser.error(f'{option}: cannot parse value of {key!r}: {err}')
        return parsed

    args = parser.parse_args(args)

    args.trace_args = list_to_args(args.trace_args, '--trace-args')
    args.router_args = list_to_args(args.router_args, '--router-args')
    args.community_args = list_to_args(args.community_args,
                                       '--community-args')
    args.node_args = list_to_args(args.node_args, '--node-args')
    args.packet_args = list_to_args(args.packet_args, '--packet-args')

    return args

class Progress:
    """Console progress bar with a smoothed time-to-go estimate.

    Iterating over an instance yields the tick numbers ``1 .. ticks``. Each
    step redraws the bar in place on stdout. The time between consecutive
    ``next()`` calls is taken as the duration of one step. It is folded into
    an exponential moving average, which drives the estimate.

    Args:
        length: Width of the bar in characters, excluding the brackets.
        ticks: Number of steps to iterate over.
        alpha: Smoothing factor of the moving average. Larger values weight
            the most recent step more heavily.
    """

    def __init__(self, length, ticks, alpha):
        self.length = length
        self.alpha = alpha

        self.tick = 0
        self.ticks = ticks

        # Smoothed seconds per step; stays 0 until a step has been timed.
        self.average = 0
        self.__start = None

        # Length of the last status line, so the redraw can blank it out.
        self.__last_print_len = 0

    @property
    def bar(self):
        """The bar as ``[===   ]``, filled in proportion to ``tick / ticks``."""
        if self.ticks > 0:
            fraction = min(self.tick / self.ticks, 1)
        else:
            # Nothing to iterate: avoid ZeroDivisionError and show a full bar.
            fraction = 1
        progress = round(fraction * self.length)
        return '[' + '=' * progress + ' ' * (self.length - progress) + ']'

    @property
    def time(self):
        """Human-readable estimate of the time left, e.g. ``'3 minutes'``."""
        # Tick ``k`` is displayed right after being handed out, before its
        # step has run, so step ``k`` still counts as remaining work.
        done = max(self.tick - 1, 0)
        remains = max(self.ticks - done, 0) * self.average

        ranges = [
            ('day', 24*60*60),
            ('hour', 60*60),
            ('minute', 60),
            ('second', 1),
        ]

        # Report in the largest unit that the remaining time reaches.
        for unit, weight in ranges:
            if remains >= weight:
                count = round(remains / weight)
                plural = 's' if count != 1 else ''
                return f'{count} {unit}{plural}'

        return '0 seconds'

    def __next__(self):
        """Advance one step, redraw the bar and return the new tick number.

        Raises:
            StopIteration: once all ``ticks`` steps have been handed out.
        """
        # Use a monotonic clock: unlike wall-clock time it never jumps
        # (NTP or manual clock changes), which would corrupt the average.
        from time import monotonic

        now = monotonic()
        if self.__start is not None:
            dt = now - self.__start
            self.average = self.average + self.alpha * (dt - self.average)
        self.__start = now

        self.tick += 1
        if self.tick > self.ticks:
            raise StopIteration

        # Blank the previous status line, then draw the new one in place.
        print('\r' + ' '*self.__last_print_len, end='\r')
        line = f'{self.bar} {self.time} to go...'
        print(line, end=' ', flush=True)
        self.__last_print_len = len(line)

        return self.tick

    def __iter__(self):
        return self

def main(args):
    """Run one simulation configured from the command line.

    Args:
        args: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Returns:
        The process exit status: 0, both when the run completes and when it
        is interrupted with Ctrl-C.
    """
    args = parse_args(args)
    print(args)

    # Short summary of the run's key settings.
    out = {
        'seed': args.seed,
        'router': args.router,
        'community': args.community,
    }
    if args.trace == 'csv':
        out['trace'] = args.trace_args['path']

    for k, v in out.items():
        print(f'{k}: {v}')

    # Seed before constructing the trace/community/network, in case their
    # set-up draws from `random`.
    random.seed(args.seed)
    env = simpy.Environment()

    trace = traces[args.trace](**args.trace_args)
    router = routers[args.router]
    community = communities[args.community](**args.community_args)

    node_factory = NodeFactory(router, **args.node_args)

    network = Network(env,
                      packets=args.packet_args,
                      node_factory=node_factory,
                      community=community,
                      trace=trace)

    progress = Progress(50, 500, 0.75)
    try:
        # Advance the simulation in equal slices of the trace's duration so
        # the progress bar is redrawn between slices.
        for tick in progress:
            until = tick / progress.ticks * trace.duration
            env.run(until=until)
        print(' Done!')
        print(str(network.packets))
    except KeyboardInterrupt:
        print('\nBye!')
    return 0

if __name__ == '__main__':
    # sys.exit rather than the site-injected `exit`, which is meant for the
    # interactive shell and is absent when Python runs without `site`.
    sys.exit(main(sys.argv[1:]))