Skip to content
Snippets Groups Projects
main.py 4.27 KiB
from argparse import ArgumentParser
from datetime import datetime
import random
import sys

import simpy
import yaml

from pydyton import Network, NodeFactory
from pydyton.traces import types as traces
from pydyton.routers import types as routers
from pydyton.communities import types as communities

def parse_args(args):
    """Parse the command line into an ``argparse.Namespace``.

    The ``--*-args`` options each take one or more ``key=value`` tokens that
    are turned into a keyword-argument dict for the matching component.
    Keys are kept as strings; values are parsed with ``yaml.safe_load`` so
    ``n=10`` gives ``{'n': 10}`` and ``p=0.5`` gives ``{'p': 0.5}``.

    :param args: argument list, typically ``sys.argv[1:]``.
    :returns: the parsed namespace; every ``*_args`` attribute is a dict
        (empty when the option was not given).
    """
    parser = ArgumentParser()

    parser.add_argument('--seed', '-s', metavar='seed', type=int,
                        default=42)

    parser.add_argument('--trace', '-t', metavar='type', default='random',
                        choices=list(traces))
    parser.add_argument('--trace-args', '-ta', metavar='arg', nargs='+')

    parser.add_argument('--router', '-r', metavar='type', default='direct',
                        choices=list(routers))
    parser.add_argument('--router-args', '-ra', metavar='arg', nargs='+')

    parser.add_argument('--community', '-c', metavar='type', default='none',
                        choices=list(communities))
    parser.add_argument('--community-args', '-ca', metavar='arg', nargs='+')

    parser.add_argument('--node-args', '-na', metavar='arg', nargs='+')

    parser.add_argument('--packet-args', '-pa', metavar='arg', nargs='+')

    def list_to_args(pairs):
        """Convert ``['k1=v1', ...]`` into ``{'k1': v1, ...}``; ``None`` -> ``{}``."""
        if pairs is None:
            return {}
        result = {}
        for pair in pairs:
            # Split on the FIRST '=' only so values may themselves contain
            # '=' (a global replace would corrupt e.g. ``expr=a=b``).
            key, sep, value = pair.partition('=')
            key = key.strip()
            if not sep or not key:
                # parser.error prints usage and exits with status 2.
                parser.error(f'expected key=value, got {pair!r}')
            try:
                # An empty value parses to None, matching the old behavior.
                result[key] = yaml.safe_load(value)
            except yaml.YAMLError as err:
                parser.error(f'cannot parse value in {pair!r}: {err}')
        return result

    parsed = parser.parse_args(args)

    for name in ('trace_args', 'router_args', 'community_args',
                 'node_args', 'packet_args'):
        setattr(parsed, name, list_to_args(getattr(parsed, name)))

    return parsed

class Progress:
    """Single-line terminal progress bar with an estimated time remaining.

    Iterating over an instance yields the tick numbers ``1..ticks`` and
    redraws the bar on each step. The remaining-time estimate is a weighted
    average of the wall-clock durations measured between successive steps,
    where a duration that is ``k`` steps old gets weight ``alpha**k``
    (smaller ``alpha`` favours recent steps more strongly).

    :param length: width of the bar in characters, excluding the brackets.
    :param ticks: total number of steps.
    :param alpha: decay factor for the duration average.
    """

    def __init__(self, length, ticks, alpha):
        self.length = length
        self.alpha = alpha

        self.tick = 0
        self.ticks = ticks

        self.__start = None  # timestamp of the previous __next__ call
        self.__times = []    # measured step durations, oldest first

        self.__last_print_len = 0  # length of the last line, to blank it

    def print(self):
        """Redraw the progress line in place, overwriting the previous one."""
        print('\r' + ' '*self.__last_print_len, end='\r')
        line = f'{self.bar} {self.time} to go...'
        print(line, end=' ', flush=True)
        self.__last_print_len = len(line)

    @property
    def bar(self):
        """Return the bar, e.g. ``'[=====     ]'`` at 50% with length 10."""
        # Avoid ZeroDivisionError for an empty run: treat it as complete.
        fraction = self.tick / self.ticks if self.ticks > 0 else 1.0
        # Clamp so an out-of-range tick never yields a malformed bar.
        progress = min(max(round(fraction * self.length), 0), self.length)
        return '[' + '='*progress + ' '*(self.length - progress) + ']'

    @property
    def time(self):
        """Return a human-readable estimate of the remaining time."""
        count = len(self.__times)
        # Newest duration has age 0 (weight 1); older ones decay by alpha.
        weights = [self.alpha**(count - 1 - i) for i in range(count)]
        total_weight = sum(weights)
        # Normalize: without dividing by the weight sum the "average" is
        # inflated by up to alpha**2 / (1 - alpha) and the ETA overshoots.
        if total_weight:
            ema = sum(w * v for w, v in zip(weights, self.__times)) / total_weight
        else:
            ema = 0.0
        remains = (self.ticks - self.tick) * ema

        ranges = [
            ('day', 24*60*60),
            ('hour', 60*60),
            ('minute', 60),
            ('second', 1),
        ]

        for unit, weight in ranges:
            if remains >= weight:
                remains = round(remains / weight)
                plural = 's' if remains != 1 else ''
                return f'{remains} {unit}{plural}'

        return '0 seconds'

    def __next__(self):
        # Check exhaustion first so ``tick`` never runs past ``ticks``.
        if self.tick >= self.ticks:
            raise StopIteration

        now = datetime.now().timestamp()
        if self.__start is not None:
            self.__times.append(now - self.__start)
        self.__start = now

        self.tick += 1
        self.print()
        return self.tick

    def __iter__(self):
        return self

def main(args):
    """Run the simulation selected by the command-line ``args``.

    Builds the trace, router, community and node factory chosen on the
    command line, wires them into a ``Network``, and then advances the simpy
    environment in ``progress.ticks`` equal slices of ``trace.duration``
    while drawing a progress bar. The network's packets are printed when
    the run finishes.

    :param args: argument list, typically ``sys.argv[1:]``.
    :returns: process exit status (0 on completion and on Ctrl-C).
    """
    args = parse_args(args)
    print(args)

    random.seed(args.seed)
    env = simpy.Environment()

    trace = traces[args.trace](**args.trace_args)
    # NOTE(review): args.router_args is parsed but never used here; the
    # selected router entry is handed to NodeFactory as-is. Confirm how
    # router options are meant to reach it.
    router = routers[args.router]
    community = communities[args.community](**args.community_args)

    node_factory = NodeFactory(router, **args.node_args)

    network = Network(env,
                      packets=args.packet_args,
                      node_factory=node_factory,
                      community=community,
                      trace=trace)

    progress = Progress(50, 500, 0.9)
    try:
        for tick in progress:
            # Advance the simulation to the end of this tick's time slice.
            env.run(until=tick / progress.ticks * trace.duration)
    except KeyboardInterrupt:
        print('\nBye!')
        return 0

    print(' Done!')
    print(str(network.packets))
    return 0

if __name__ == '__main__':
    # Use sys.exit: the bare exit() builtin is injected by the site module
    # and is missing under ``python -S`` or in frozen executables.
    sys.exit(main(sys.argv[1:]))