Commit d3ba6eea authored by Jarrod Pas's avatar Jarrod Pas
Browse files

Cleans up csv trace

parent fc60bdfb
Pipeline #1663 passed with stage
in 1 minute and 33 seconds
...@@ -551,6 +551,8 @@ class Trace: ...@@ -551,6 +551,8 @@ class Trace:
def __init__(self, nodes): def __init__(self, nodes):
"""Create a contact trace.""" """Create a contact trace."""
if nodes < 2:
raise ValueError('A trace requires at least 2 nodes')
self._nodes = nodes self._nodes = nodes
@property @property
...@@ -580,43 +582,38 @@ class CSVTrace(Trace): ...@@ -580,43 +582,38 @@ class CSVTrace(Trace):
node_{a,b} -- nodes involved in contact node_{a,b} -- nodes involved in contact
join -- whether the contact is going up (1) or going down (0) join -- whether the contact is going up (1) or going down (0)
Optionally accepts a metadata file so it does not have to read and store Optionally accepts a metadata file so it does not have to read the trace
the entire trace at once the beginning. twice.
""" """
def __init__(self, path, metadata=None): def __init__(self, path, nodes=0, metadata=None):
"""Create a csv trace generator.""" """Create a csv trace generator."""
self.path = path self.path = path
self._contacts = None self._contacts = None
if metadata is None: if metadata is not None:
nodes = len(list(self.contacts))
else:
with open(metadata) as metadata_file: with open(metadata) as metadata_file:
nodes = json.load(metadata_file)['nodes'] nodes = json.load(metadata_file)['nodes']
else:
with open(self.path) as csv_file:
csv_file = csv.reader(csv_file)
next(csv_file)
nodes = set()
for _, node_a, node_b, _ in csv_file:
nodes.add(node_a)
nodes.add(node_b)
nodes = len(nodes)
super().__init__(nodes) super().__init__(nodes)
@property def __iter__(self):
def contacts(self): """Yield contacts from csv file."""
if self._contacts is not None:
yield from self._contacts
contacts = []
with open(self.path) as csv_file: with open(self.path) as csv_file:
csv_file = csv.reader(csv_file) csv_file = csv.reader(csv_file)
# skip header # skip header
next(csv_file) next(csv_file)
for row in csv_file: for row in csv_file:
contact = self.create_contact(*map(int, row)) yield self.create_contact(*map(int, row))
contacts.append(contact)
yield contact
self._contacts = contacts
def __iter__(self):
"""Yield contacts from csv file."""
yield from self.contacts
class RandomTrace(Trace): class RandomTrace(Trace):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment