From b5256c4b428335dba3f814fdd18db2c0f71327eb Mon Sep 17 00:00:00 2001
From: Jarrod Pas <j.pas@usask.ca>
Date: Thu, 20 Jul 2017 16:30:35 -0600
Subject: [PATCH] Modularize traces

---
 pydyton/traces.py             | 79 -----------------------------------
 pydyton/traces/__init__.py    |  8 ++++
 pydyton/traces/csvtrace.py    | 30 +++++++++++++
 pydyton/traces/randomtrace.py | 34 +++++++++++++++
 pydyton/traces/trace.py       | 21 ++++++++++
 5 files changed, 93 insertions(+), 79 deletions(-)
 delete mode 100644 pydyton/traces.py
 create mode 100644 pydyton/traces/__init__.py
 create mode 100644 pydyton/traces/csvtrace.py
 create mode 100644 pydyton/traces/randomtrace.py
 create mode 100644 pydyton/traces/trace.py

diff --git a/pydyton/traces.py b/pydyton/traces.py
deleted file mode 100644
index 5faff02..0000000
--- a/pydyton/traces.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import csv
-import random
-
-from core import Process, TickProcess
-
-class Trace(Process):
-    def __init__(self, duration, nodes):
-        super().__init__()
-
-        if duration < 1:
-            raise ValueError('duration must be greater than or equal to 1')
-        if nodes < 2:
-            raise ValueError('nodes must be greater than or equal to 2')
-
-        self.duration = duration
-        self.nodes = nodes
-
-    def process(self, network):
-        raise NotImplementedError
-        yield None
-
-class RandomTrace(Trace, TickProcess):
-    def __init__(self, duration=100, nodes=10,
-                 tick=1, min_toggles=1, max_toggles=2):
-        Trace.__init__(self, duration, nodes)
-        TickProcess.__init__(self, tick)
-
-        if min_toggles < 0:
-            raise ValueError('min_toggles < 0')
-        if min_toggles > max_toggles:
-            raise ValueError('min_toggles > max_toggles')
-        if max_toggles > nodes:
-            raise ValueError('max_toggles > nodes')
-
-        self.min_toggles = min_toggles
-        self.max_toggles = max_toggles
-
-    def process(self, network):
-        ''''''
-        env = self.env
-        links = network.links
-
-        while env.now < self.duration:
-            to_toggle = random.randint(self.min_toggles, self.max_toggles)
-            for a, b in random.sample(links, to_toggle):
-                network.toggle_link(a, b)
-            yield env.timeout(self.tick_time)
-
-class CSVTrace(Trace):
-    def __init__(self, path=None):
-        self.file = open(path)
-        self.reader = csv.reader(self.file)
-
-        time, a, b, state = next(self.reader)
-        if not time == 'time' and a == 'a' and b == 'b' and state == 'state':
-            raise ValueError('improperly formatted csv')
-
-        # first row of data is csv is [duration, -1, -1, nodes]
-        duration, _, _, nodes = map(int, next(self.reader))
-        super().__init__(duration, nodes)
-
-    def process(self, network):
-        ''''''
-        env = self.env
-
-        for row in self.reader:
-            time, a, b, state = map(int, row)
-            if time > env.now:
-                yield env.timeout(time - env.now)
-            network.set_link(a, b, bool(state))
-
-        self.file.close()
-
-
-types = {
-    'random': RandomTrace,
-    'csv':    CSVTrace,
-}
-
diff --git a/pydyton/traces/__init__.py b/pydyton/traces/__init__.py
new file mode 100644
index 0000000..5e044b5
--- /dev/null
+++ b/pydyton/traces/__init__.py
@@ -0,0 +1,8 @@
+from .csvtrace import CSVTrace
+from .randomtrace import RandomTrace
+
+types = {
+    'csv': CSVTrace,
+    'random': RandomTrace
+}
+
diff --git a/pydyton/traces/csvtrace.py b/pydyton/traces/csvtrace.py
new file mode 100644
index 0000000..d5da192
--- /dev/null
+++ b/pydyton/traces/csvtrace.py
@@ -0,0 +1,30 @@
+import csv
+
+from .trace import Trace
+
+class CSVTrace(Trace):
+    def __init__(self, path=None):
+        self.file = open(path)
+        self.reader = csv.reader(self.file)
+
+        time, a, b, state = next(self.reader)
+        if not time == 'time' and a == 'a' and b == 'b' and state == 'state':
+            raise ValueError('improperly formatted csv')
+
+        # first row of data is csv is [duration, -1, -1, nodes]
+        duration, _, _, nodes = map(int, next(self.reader))
+        super().__init__(duration, nodes)
+
+    def process(self, network):
+        ''''''
+        env = self.env
+
+        for row in self.reader:
+            time, a, b, state = map(int, row)
+            if time > env.now:
+                yield env.timeout(time - env.now)
+            network.set_link(a, b, bool(state))
+
+        self.file.close()
+
+
diff --git a/pydyton/traces/randomtrace.py b/pydyton/traces/randomtrace.py
new file mode 100644
index 0000000..942e5f9
--- /dev/null
+++ b/pydyton/traces/randomtrace.py
@@ -0,0 +1,34 @@
+from random import (
+    randint,
+    sample,
+)
+
+from .trace import TickTrace
+
+class RandomTrace(TickTrace):
+    def __init__(self, duration=100, nodes=100, tick=1,
+                 min_toggles=1, max_toggles=10):
+        super().__init__(duration, nodes, tick)
+
+        if min_toggles < 0:
+            raise ValueError('min_toggles < 0')
+        if min_toggles > max_toggles:
+            raise ValueError('min_toggles > max_toggles')
+        if max_toggles > nodes:
+            raise ValueError('max_toggles > nodes')
+
+        self.min_toggles = min_toggles
+        self.max_toggles = max_toggles
+
+    def process(self, network):
+        ''''''
+        env = self.env
+        links = network.links
+
+        while env.now < self.duration:
+            to_toggle = randint(self.min_toggles, self.max_toggles)
+            for a, b in sample(links, to_toggle):
+                network.toggle_link(a, b)
+            yield self.tick()
+
+
diff --git a/pydyton/traces/trace.py b/pydyton/traces/trace.py
new file mode 100644
index 0000000..ca94258
--- /dev/null
+++ b/pydyton/traces/trace.py
@@ -0,0 +1,21 @@
+from pydyton.core import Process, TickProcess
+
+class Trace(Process):
+    def __init__(self, duration, nodes):
+        Process.__init__(self)
+
+        if duration < 1:
+            raise ValueError('duration must be greater than or equal to 1')
+        if nodes < 2:
+            raise ValueError('nodes must be greater than or equal to 2')
+
+        self.duration = duration
+        self.nodes = nodes
+
+
+class TickTrace(Trace, TickProcess):
+    def __init__(self, duration, nodes, tick):
+        Trace.__init__(self, duration, nodes)
+        TickProcess.__init__(self, tick)
+
+
-- 
GitLab