From 7e0c0264a4ee345f4c9d7e6e5d890339f738cff2 Mon Sep 17 00:00:00 2001
From: Jarrod Pas <j.pas@usask.ca>
Date: Thu, 20 Jul 2017 16:03:03 -0600
Subject: [PATCH] Generifies processes

---
 pydyton/communities.py |  25 ++--
 pydyton/core.py        | 318 ++---------------------------------------
 pydyton/network.py     | 312 ++++++++++++++++++++++++++++++++++++++++
 pydyton/traces.py      |  48 ++++---
 4 files changed, 359 insertions(+), 344 deletions(-)
 create mode 100644 pydyton/network.py

diff --git a/pydyton/communities.py b/pydyton/communities.py
index da9c2b8..0d678be 100644
--- a/pydyton/communities.py
+++ b/pydyton/communities.py
@@ -4,12 +4,12 @@ from itertools import product
 import networkx as nx
 from community import best_partition as louvain_partition
 
-class EpochCommunity:
+from core import TickProcess
+
+class EpochCommunity(TickProcess):
     def __init__(self, epoch=1, **kwargs):
-        if epoch < 1:
-            raise ValueError('epoch < 1')
+        super().__init__(epoch)
 
-        self.epoch = epoch
         self.graph = nx.Graph()
         self.old_graph = None
         self.community = defaultdict(frozenset)
@@ -51,13 +51,6 @@ class EpochCommunity:
 
         return self.old_graph
 
-    def tick(self, env, network):
-        raise NotImplementedError
-        yield env.timeout(0)
-
-    def process(self, env, network):
-        return env.process(self.tick(env, network))
-
     def __getitem__(self, node):
         return self.community[node]
 
@@ -149,9 +142,9 @@ class KCliqueCommunity(EpochCommunity):
         self.k = k
         self.threshold = threshold
 
-    def tick(self, env, network):
+    def process(self, network):
         while True:
-            yield env.timeout(self.epoch)
+            yield self.env.timeout(self.tick)
             g = self.next_epoch(env.now)
 
             G = nx.Graph()
@@ -169,13 +162,11 @@ class LouvainCommunity(EpochCommunity):
     def __init__(self, epoch=604800, **kwargs):
         super().__init__(epoch=epoch, **kwargs)
 
-    def tick(self, env, network):
+    def process(self, network):
         while True:
-            yield env.timeout(self.epoch)
+            yield self.env.timeout(self.tick)
             g = self.next_epoch(env.now)
 
-            # I made a change in community package to get this to work
-            # change graph.copy() to nx.Graph(graph) in community_louvain.py
             p = louvain_partition(g, weight='duration')
             communities = defaultdict(set)
             for node, c in louvain_partition(g, weight='duration').items():
diff --git a/pydyton/core.py b/pydyton/core.py
index cacc49e..33d06bb 100644
--- a/pydyton/core.py
+++ b/pydyton/core.py
@@ -1,311 +1,21 @@
-from argparse import ArgumentParser
-from collections import defaultdict, OrderedDict
-from itertools import combinations
-import math
-import random
-import sys
+class Process:
+    def __init__(self):
+        self.env = None
+        self.__process = None
 
-import simpy
-import yaml
-import networkx as nx
-
-from communities import types as communities
-from routers import types as routers
-from traces import types as traces
-
-class Network:
-    ''''''
-    def __init__(self, env,
-                 node_factory=None,
-                 community=None,
-                 trace=None):
-        ''''''
-        self.env = env
-
-        # contact trace
-        if trace is None:
-            trace = traces['random']()
-        self.trace = trace
-        self.trace_proc = self.trace.process(self.env, self)
-
-        # community detection
-        self.community = community
-        if community is not None:
-            self.community_proc = self.community.process(self.env, self)
-
-        # packet generation
-        '''
-        if packets is None:
-            packets = packets['uniform']()
-        self.packets = packets
-        self.packets_proc = self.packets.process(self.env, self)
-        '''
-
-        # create node network
-        if node_factory is None:
-            node_factory = NodeFactory(tick_rate=1, router=routers['direct'])
-        self.nodes = [
-            node_factory(env, self, nid)
-            for nid in range(self.trace.nodes)
-        ]
-        self.links = [
-            (a, b)
-            for a, b in combinations(self.nodes, 2)
-        ]
-
-        # set up networkx graph
-        self.graph = nx.Graph()
-        self.graph.add_nodes_from(self.nodes)
-        self.graph.add_edges_from([
-            (a, b, { 'state': False })
-            for a, b in self.links
-        ])
-
-        # TODO: better packet generation
-        self.packets = []
-        for i in range(100):
-            source, dest = random.choice(self.links)
-            packet = Packet(i, source, dest, trace.duration, None)
-            self.packets.append(packet)
-            source.recv(packet)
-
-    def set_link(self, a, b, state):
-        if isinstance(a, int):
-            a = self.nodes[a]
-
-        if isinstance(b, int):
-            b = self.nodes[b]
-
-        edge = self[a][b]
-        if edge['state'] == state:
-            return
-
-        if state is None:
-            state = not edge['state']
-        edge['state'] = state
-
-        if self.community:
-            self.community.set_link(a, b, state, self.env.now)
-
-    def toggle_link(self, a, b):
-        self.set_link(a, b, None)
-
-    def send_link(self, a, b, packet):
-        ''''''
-        if self[a][b]['state']:
-            # TODO: transfer delay
-            b.recv(packet)
-        else:
-            raise Exception('Nodes {} and {} not connected'.format(a, b))
-
-    def __getitem__(self, node):
-        ''''''
-        return self.graph[node]
-
-
-def NodeFactory(router, **kwargs):
-    nid = -1
-    def factory(env, network, nid):
-        nid += 1
-        return Node(env, network, nid, router=router, **kwargs)
-    return factory
-
-
-class Node:
-    ''''''
-    def __init__(self, env, network, nid,
-                 tick_time=1, router=None):
-        ''''''
-        self.env = env
-
-        self.network = network
-        self.id = nid
-
-        self.tick_time = tick_time
-        self.ticker = env.process(self.tick())
-
-        self.buffer = Buffer(self.env)
-
-        # bind router as a class method
-        if router is None:
-            router = routers['direct']
-        self.router = router.__get__(self, Node)
-        self.router_state = {}
-
-    def tick(self):
-        ''''''
-        while True:
-            packets_to_delete = []
-
-            for packet in self.buffer:
-                if packet.ttl < self.env.now:
-                    packets_to_delete.append(packet)
-                    continue
-                if self.router(packet, self.router_state):
-                    packets_to_delete.append(packet)
-
-            for packet in packets_to_delete:
-                self.buffer.remove(packet)
-
-            yield self.env.timeout(self.tick_time)
-
-    def send(self, to, packet):
-        # TODO: transfer delay
-        self.network.send_link(self, to, packet)
-
-    def recv(self, packet):
-        if packet.destination == self:
-            packet.recv()
-        else:
-            self.buffer.add(packet)
-
-    @property
-    def community(self):
-        return self.network.community[self]
-
-    @property
-    def links(self):
-        '''
-        Returns a list of connected links.
-        '''
-        links = {
-            met: data
-            for met, data in self.network[self].items()
-            if data['state']
-        }
-        return links
-
-    def __repr__(self):
-        return 'Node(id={})'.format(self.id)
-
-
-class Buffer:
-    def __init__(self, env, capacity=0):
+    def start(self, env, *args):
         self.env = env
+        self.__process = env.process(self.process(*args))
 
-        if capacity <= 0:
-            self.capacity = float('inf')
-        else:
-            self.capacity = capacity
-
-        self.buffer = OrderedDict()
-        self.used = 0
-
-    def add(self, packet):
-        if self.used < self.capacity:
-            self.used += 1
-            self.buffer[packet] = None
-        else:
-            raise Exception('buffer full')
-
-    def remove(self, packet):
-        self.used -= 1
-        del self.buffer[packet]
-
-    def __contains__(self, packet):
-        return packet in self.buffer
-
-    def __iter__(self):
-        return iter(self.buffer)
-
-    def __len__(self):
-        return len(self.buffer)
-
-
-class Packet:
-    def __init__(self, id, source, destination, ttl, payload):
-        self.id = id
-        self.source = source
-        self.destination = destination
-        self.ttl = ttl
-        self.payload = payload
-
-        self.recieved = False
-        self.recieved_count = 0
-
-    def recv(self):
-        self.recieved = True
-        self.recieved_count += 1
-
-    def __str__(self):
-        return "Packet(id={}, source={}, destination={})".format(
-            self.id,
-            self.source,
-            self.destination
-        )
-
-
-def parse_args(args):
-    parser = ArgumentParser()
-
-    parser.add_argument('--seed', '-s', metavar='seed', type=int,
-                        default=42)
-
-    parser.add_argument('--trace', '-t', metavar='type', default='random',
-                        choices=[t for t in traces])
-    parser.add_argument('--trace-args', '-ta', metavar='arg', nargs='+')
-
-    parser.add_argument('--router', '-r', metavar='type', default='direct',
-                        choices=[t for t in routers])
-    parser.add_argument('--router-args', '-ra', metavar='arg', nargs='+')
-
-    parser.add_argument('--community', '-c', metavar='type', default='none',
-                        choices=[t for t in communities])
-    parser.add_argument('--community-args', '-ca', metavar='arg', nargs='+')
-
-    parser.add_argument('--node-args', '-na', metavar='arg', nargs='+')
-
-    def list_to_args(args):
-        if args is None:
-            return {}
-        else:
-            args = '\n'.join(args).replace('=',': ')
-            return yaml.safe_load(args)
-
-    args = parser.parse_args(args)
-
-    args.trace_args = list_to_args(args.trace_args)
-    args.router_args = list_to_args(args.router_args)
-    args.community_args = list_to_args(args.community_args)
-    args.node_args = list_to_args(args.node_args)
-
-    return args
-
-
-def main(args):
-    args = parse_args(args)
-    print(args)
-
-    random.seed(args.seed)
-    env = simpy.Environment()
-
-    trace = traces[args.trace](**args.trace_args)
-    router = routers[args.router]
-    community = communities[args.community](**args.community_args)
-
-    node_factory = NodeFactory(router, **args.node_args)
-
-    network = Network(env,
-                      node_factory=node_factory,
-                      community=community,
-                      trace=trace)
-
+    def process(self, *args):
+        raise NotImplementedError
 
-    while True:
-        try:
-            for tick in range(1, 51):
-                t = tick / 50 * trace.duration
-                env.run(until=t)
-                tick = '=' * tick + ' ' * (50 - tick)
-                print('\rRunning [{}] '.format(tick), end='')
-            print(' Done!')
-            print(len([p for p in network.packets if p.recieved]))
-            return 0
-        except KeyboardInterrupt:
-            print('\nBye!')
-            return 0
 
-if __name__ == '__main__':
-    exit(main(sys.argv[1:]))
+class TickProcess(Process):
+    def __init__(self, tick):
+        super().__init__()
+        if tick < 1:
+            raise ValueError('tick must be greatert than or equal to 1')
+        self.tick = tick
 
 
diff --git a/pydyton/network.py b/pydyton/network.py
new file mode 100644
index 0000000..7374c98
--- /dev/null
+++ b/pydyton/network.py
@@ -0,0 +1,312 @@
+from argparse import ArgumentParser
+from collections import defaultdict, OrderedDict
+from itertools import combinations
+import math
+import random
+import sys
+
+import simpy
+import yaml
+import networkx as nx
+
+from communities import types as communities
+from routers import types as routers
+from traces import types as traces
+
+class Network:
+    ''''''
+    def __init__(self, env,
+                 packets=None,
+                 node_factory=None,
+                 community=None,
+                 trace=None):
+        ''''''
+        self.env = env
+
+        # contact trace
+        if trace is None:
+            trace = traces['random']()
+        self.trace = trace
+        self.trace.start(env, self)
+
+        # community detection
+        self.community = community
+        if community is not None:
+            self.community.start(env, self)
+
+        # packet generation
+        '''
+        if packets is None:
+            packets = packets['uniform']()
+        self.packets = packets
+        self.packets_proc = self.packets.process(self.env, self)
+        '''
+
+        # create node network
+        if node_factory is None:
+            node_factory = NodeFactory(tick_rate=1, router=routers['direct'])
+        self.nodes = [
+            node_factory(env, self, nid)
+            for nid in range(self.trace.nodes)
+        ]
+        self.links = [
+            (a, b)
+            for a, b in combinations(self.nodes, 2)
+        ]
+
+        # set up networkx graph
+        self.graph = nx.Graph()
+        self.graph.add_nodes_from(self.nodes)
+        self.graph.add_edges_from([
+            (a, b, { 'state': False })
+            for a, b in self.links
+        ])
+
+        # TODO: better packet generation
+        self.packets = []
+        for i in range(100):
+            source, dest = random.choice(self.links)
+            packet = Packet(i, source, dest, trace.duration, None)
+            self.packets.append(packet)
+            source.recv(packet)
+
+    def set_link(self, a, b, state):
+        if isinstance(a, int):
+            a = self.nodes[a]
+
+        if isinstance(b, int):
+            b = self.nodes[b]
+
+        edge = self[a][b]
+        if edge['state'] == state:
+            return
+
+        if state is None:
+            state = not edge['state']
+        edge['state'] = state
+
+        if self.community:
+            self.community.set_link(a, b, state, self.env.now)
+
+    def toggle_link(self, a, b):
+        self.set_link(a, b, None)
+
+    def send_link(self, a, b, packet):
+        ''''''
+        if self[a][b]['state']:
+            # TODO: transfer delay
+            b.recv(packet)
+        else:
+            raise Exception('Nodes {} and {} not connected'.format(a, b))
+
+    def __getitem__(self, node):
+        ''''''
+        return self.graph[node]
+
+
+def NodeFactory(router, **kwargs):
+    nid = -1
+    def factory(env, network, nid):
+        nid += 1
+        return Node(env, network, nid, router=router, **kwargs)
+    return factory
+
+
+class Node:
+    ''''''
+    def __init__(self, env, network, nid,
+                 tick_time=1, router=None):
+        ''''''
+        self.env = env
+
+        self.network = network
+        self.id = nid
+
+        self.tick_time = tick_time
+        self.ticker = env.process(self.tick())
+
+        self.buffer = Buffer(self.env)
+
+        # bind router as a class method
+        if router is None:
+            router = routers['direct']
+        self.router = router.__get__(self, Node)
+        self.router_state = {}
+
+    def tick(self):
+        ''''''
+        while True:
+            packets_to_delete = []
+
+            for packet in self.buffer:
+                if packet.ttl < self.env.now:
+                    packets_to_delete.append(packet)
+                    continue
+                if self.router(packet, self.router_state):
+                    packets_to_delete.append(packet)
+
+            for packet in packets_to_delete:
+                self.buffer.remove(packet)
+
+            yield self.env.timeout(self.tick_time)
+
+    def send(self, to, packet):
+        # TODO: transfer delay
+        self.network.send_link(self, to, packet)
+
+    def recv(self, packet):
+        if packet.destination == self:
+            packet.recv()
+        else:
+            self.buffer.add(packet)
+
+    @property
+    def community(self):
+        return self.network.community[self]
+
+    @property
+    def links(self):
+        '''
+        Returns a list of connected links.
+        '''
+        links = {
+            met: data
+            for met, data in self.network[self].items()
+            if data['state']
+        }
+        return links
+
+    def __repr__(self):
+        return 'Node(id={})'.format(self.id)
+
+
+class Buffer:
+    def __init__(self, env, capacity=0):
+        self.env = env
+
+        if capacity <= 0:
+            self.capacity = float('inf')
+        else:
+            self.capacity = capacity
+
+        self.buffer = OrderedDict()
+        self.used = 0
+
+    def add(self, packet):
+        if self.used < self.capacity:
+            self.used += 1
+            self.buffer[packet] = None
+        else:
+            raise Exception('buffer full')
+
+    def remove(self, packet):
+        self.used -= 1
+        del self.buffer[packet]
+
+    def __contains__(self, packet):
+        return packet in self.buffer
+
+    def __iter__(self):
+        return iter(self.buffer)
+
+    def __len__(self):
+        return len(self.buffer)
+
+
+class Packet:
+    def __init__(self, id, source, destination, ttl, payload):
+        self.id = id
+        self.source = source
+        self.destination = destination
+        self.ttl = ttl
+        self.payload = payload
+
+        self.recieved = False
+        self.recieved_count = 0
+
+    def recv(self):
+        self.recieved = True
+        self.recieved_count += 1
+
+    def __str__(self):
+        return "Packet(id={}, source={}, destination={})".format(
+            self.id,
+            self.source,
+            self.destination
+        )
+
+
+def parse_args(args):
+    parser = ArgumentParser()
+
+    parser.add_argument('--seed', '-s', metavar='seed', type=int,
+                        default=42)
+
+    parser.add_argument('--trace', '-t', metavar='type', default='random',
+                        choices=[t for t in traces])
+    parser.add_argument('--trace-args', '-ta', metavar='arg', nargs='+')
+
+    parser.add_argument('--router', '-r', metavar='type', default='direct',
+                        choices=[t for t in routers])
+    parser.add_argument('--router-args', '-ra', metavar='arg', nargs='+')
+
+    parser.add_argument('--community', '-c', metavar='type', default='none',
+                        choices=[t for t in communities])
+    parser.add_argument('--community-args', '-ca', metavar='arg', nargs='+')
+
+    parser.add_argument('--node-args', '-na', metavar='arg', nargs='+')
+
+    def list_to_args(args):
+        if args is None:
+            return {}
+        else:
+            args = '\n'.join(args).replace('=',': ')
+            return yaml.safe_load(args)
+
+    args = parser.parse_args(args)
+
+    args.trace_args = list_to_args(args.trace_args)
+    args.router_args = list_to_args(args.router_args)
+    args.community_args = list_to_args(args.community_args)
+    args.node_args = list_to_args(args.node_args)
+
+    return args
+
+
+def main(args):
+    args = parse_args(args)
+    print(args)
+
+    random.seed(args.seed)
+    env = simpy.Environment()
+
+    trace = traces[args.trace](**args.trace_args)
+    router = routers[args.router]
+    community = communities[args.community](**args.community_args)
+
+    node_factory = NodeFactory(router, **args.node_args)
+
+    network = Network(env,
+                      node_factory=node_factory,
+                      community=community,
+                      trace=trace)
+
+
+    while True:
+        try:
+            for tick in range(1, 51):
+                t = tick / 50 * trace.duration
+                env.run(until=t)
+                tick = '=' * tick + ' ' * (50 - tick)
+                print('\rRunning [{}] '.format(tick), end='')
+            print(' Done!')
+            print(len([p for p in network.packets if p.recieved]))
+            return 0
+        except KeyboardInterrupt:
+            print('\nBye!')
+            return 0
+
+if __name__ == '__main__':
+    exit(main(sys.argv[1:]))
+
+
diff --git a/pydyton/traces.py b/pydyton/traces.py
index 61cdd7d..5faff02 100644
--- a/pydyton/traces.py
+++ b/pydyton/traces.py
@@ -1,30 +1,30 @@
 import csv
 import random
 
-class Trace:
-    def __init__(self, duration=1, nodes=2, **kwargs):
+from core import Process, TickProcess
+
+class Trace(Process):
+    def __init__(self, duration, nodes):
+        super().__init__()
+
         if duration < 1:
-            raise ValueError('duration < 1')
+            raise ValueError('duration must be greater than or equal to 1')
         if nodes < 2:
-            raise ValueError('nodes < 2')
+            raise ValueError('nodes must be greater than or equal to 2')
 
         self.duration = duration
         self.nodes = nodes
 
-    def trace(self, env, network):
-        yield env.timeout(self.duration)
-        env.exit()
+    def process(self, network):
+        raise NotImplementedError
+        yield None
 
-    def process(self, env, network):
-        return env.process(self.trace(env, network))
+class RandomTrace(Trace, TickProcess):
+    def __init__(self, duration=100, nodes=10,
+                 tick=1, min_toggles=1, max_toggles=2):
+        Trace.__init__(self, duration, nodes)
+        TickProcess.__init__(self, tick)
 
-class RandomTrace(Trace):
-    def __init__(self, duration=100, nodes=10, tick_time=1, min_toggles=1,
-                 max_toggles=2, **kwargs):
-        super().__init__(duration, nodes)
-
-        if tick_time < 0:
-            raise ValueError('tick_time < 0')
         if min_toggles < 0:
             raise ValueError('min_toggles < 0')
         if min_toggles > max_toggles:
@@ -32,42 +32,44 @@ class RandomTrace(Trace):
         if max_toggles > nodes:
             raise ValueError('max_toggles > nodes')
 
-        self.tick_time = tick_time
         self.min_toggles = min_toggles
         self.max_toggles = max_toggles
 
-    def trace(self, env, network):
+    def process(self, network):
         ''''''
+        env = self.env
         links = network.links
+
         while env.now < self.duration:
             to_toggle = random.randint(self.min_toggles, self.max_toggles)
             for a, b in random.sample(links, to_toggle):
                 network.toggle_link(a, b)
             yield env.timeout(self.tick_time)
-        env.exit()
 
 class CSVTrace(Trace):
-    def __init__(self, path=None, **kwargs):
+    def __init__(self, path=None):
         self.file = open(path)
         self.reader = csv.reader(self.file)
 
         time, a, b, state = next(self.reader)
         if not time == 'time' and a == 'a' and b == 'b' and state == 'state':
             raise ValueError('improperly formatted csv')
-        # first row of data is csv is [duration, -1, -1, nodes]
 
+        # first row of data is csv is [duration, -1, -1, nodes]
         duration, _, _, nodes = map(int, next(self.reader))
         super().__init__(duration, nodes)
 
-    def trace(self, env, network):
+    def process(self, network):
         ''''''
+        env = self.env
+
         for row in self.reader:
             time, a, b, state = map(int, row)
             if time > env.now:
                 yield env.timeout(time - env.now)
             network.set_link(a, b, bool(state))
+
         self.file.close()
-        env.exit()
 
 
 types = {
-- 
GitLab