Commit d3ba6eea authored by Jarrod Pas's avatar Jarrod Pas
Browse files

Cleans up csv trace

parent fc60bdfb
Pipeline #1663 passed with stage
in 1 minute and 33 seconds
......@@ -551,6 +551,8 @@ class Trace:
def __init__(self, nodes):
"""Create a contact trace."""
if nodes < 2:
raise ValueError('A trace requires at least 2 nodes')
self._nodes = nodes
@property
......@@ -580,43 +582,38 @@ class CSVTrace(Trace):
node_{a,b} -- nodes involved in contact
join -- whether the contact is going up (1) or going down (0)
Optionally accepts a metadata file so it does not have to read and store
the entire trace at once the beginning.
Optionally accepts a metadata file so it does not have to read the trace
twice.
"""
def __init__(self, path, metadata=None):
def __init__(self, path, nodes=0, metadata=None):
"""Create a csv trace generator."""
self.path = path
self._contacts = None
if metadata is None:
nodes = len(list(self.contacts))
else:
if metadata is not None:
with open(metadata) as metadata_file:
nodes = json.load(metadata_file)['nodes']
else:
with open(self.path) as csv_file:
csv_file = csv.reader(csv_file)
next(csv_file)
nodes = set()
for _, node_a, node_b, _ in csv_file:
nodes.add(node_a)
nodes.add(node_b)
nodes = len(nodes)
super().__init__(nodes)
@property
def contacts(self):
if self._contacts is not None:
yield from self._contacts
contacts = []
def __iter__(self):
"""Yield contacts from csv file."""
with open(self.path) as csv_file:
csv_file = csv.reader(csv_file)
# skip header
next(csv_file)
for row in csv_file:
contact = self.create_contact(*map(int, row))
contacts.append(contact)
yield contact
self._contacts = contacts
def __iter__(self):
"""Yield contacts from csv file."""
yield from self.contacts
yield self.create_contact(*map(int, row))
class RandomTrace(Trace):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment