"""pydtn is a module for simulating delay tolerant networks."""

__all__ = [
    'Contact',
    'Traffic',

    'Network',
    'Buffer',

    'Node',
    'EpidemicNode',
]
__version__ = '0.1'
__author__ = 'Jarrod Pas <j.pas@usask.ca>'

from collections import ChainMap, defaultdict, namedtuple, OrderedDict
from collections.abc import Sequence
import csv
from itertools import count
from random import Random
from time import time

import simpy


class Network:
    """
    Network simulation.

    Wraps a simpy environment and runs two processes in it: one injects the
    packets described by ``traffic`` into their source nodes, the other
    replays the contact ``trace`` by making nodes join and leave each
    other's neighbourhood. Every node is also started with its own process.
    """

    def __init__(self, nodes=None, trace=None, traffic=None):
        """
        Create a network.

        Keyword Arguments:
        nodes -- mapping from node name to node (default empty); the names
                 are the values referenced by contacts and traffic
        trace -- iterable of Contact records (default empty)
        traffic -- iterable of Traffic records (default empty)
        """
        self.env = simpy.Environment()

        self._stats = defaultdict(int)
        self._stats['wall-time'] = 0

        if nodes is None:
            nodes = {}
        self.nodes = nodes

        if not trace:
            trace = []
        self.trace = trace
        # succeeds once the whole trace has been replayed; run() waits on it
        self.trace_done = simpy.Event(self.env)

        if not traffic:
            traffic = []
        self.traffic = traffic

        # every packet ever injected, kept for statistics
        self.packets = []

        # start components in order of priority
        self.env.process(self.traffic_loop())
        self.env.process(self.trace_loop())

        for node in self.nodes.values():
            node.start(self)

    def traffic_loop(self):
        """
        Trigger traffic for nodes (simpy process).

        Traffic records are handled in iteration order: wait until the
        record's creation time (records whose time has already passed are
        injected at once), wrap it in a Packet, remember it and hand it to
        its source node.
        """
        for traffic in self.traffic:
            if traffic.created > self.now:
                yield self.env.timeout(traffic.created - self.now)

            packet = Packet(self, traffic)
            self.packets.append(packet)

            packet.source.recv(packet)

    def trace_loop(self):
        """
        Trigger contacts for nodes (simpy process).

        Contacts are handled in iteration order: wait until the contact's
        time, then make both nodes join (or leave) each other's
        neighbourhood. ``trace_done`` succeeds once the trace is exhausted.
        """
        for contact in self.trace:
            if contact.time > self.now:
                yield self.env.timeout(contact.time - self.now)

            node_a = self.nodes[contact.a]
            node_b = self.nodes[contact.b]

            if contact.join:
                node_a.join(node_b)
                node_b.join(node_a)
            else:
                node_a.leave(node_b)
                node_b.leave(node_a)

        self.trace_done.succeed()

    @property
    def now(self):
        """Return current simulation time."""
        return self.env.now

    def run(self, until=None):
        """
        Run the simulation.

        Wall-clock time spent running is accumulated into the 'wall-time'
        statistic across calls.

        Keyword Arguments:
        until -- time or event to run the simulation to and stop at
                 (default: the end of the contact trace, which is never
                 reached for an endless trace)
        """
        if until is None:
            until = self.trace_done

        start = time()
        self.env.run(until=until)
        end = time()

        self._stats['wall-time'] += end - start

    @property
    def stats(self):
        """
        Return statistics for the simulation.

        Keys: 'sim-time', 'broadcasts', 'packets', 'delivered',
        'delivery-ratio' (delivered / packets), 'delivery-cost'
        (broadcasts / delivered), 'delay' (mean delay of delivered
        packets), plus the network's own statistics such as 'wall-time'.
        Ratios whose denominator is zero (no packets, or nothing delivered)
        are reported as ``float('nan')`` instead of raising.
        """
        def gather(group):
            """Gather statistics from a group into a defaultdict of lists."""
            # defaultdict so a statistic no member reported reads as []
            stats = defaultdict(list)
            for member in group:
                for stat, value in member.stats.items():
                    stats[stat].append(value)
            return stats

        def ratio(numerator, denominator):
            """Divide, returning NaN when the denominator is zero."""
            return numerator / denominator if denominator else float('nan')

        packet_stats = gather(self.packets)
        node_stats = gather(self.nodes.values())

        stats = {}
        stats['sim-time'] = self.env.now

        stats['broadcasts'] = sum(node_stats['broadcasts'])

        stats['packets'] = len(self.packets)

        stats['delivered'] = sum(
            1 for received in packet_stats['recieved'] if received > 0
        )

        stats['delivery-ratio'] = ratio(stats['delivered'], stats['packets'])

        stats['delivery-cost'] = ratio(stats['broadcasts'],
                                       stats['delivered'])

        # only delivered packets carry a 'delay' statistic
        stats['delay'] = ratio(sum(packet_stats['delay']), stats['delivered'])

        stats.update(self._stats)
        return stats


class Packet:
    """
    A unit of traffic being routed through the network.

    Wraps a Traffic record, resolving its node names against the owning
    network, and keeps per-packet statistics: hop counts, the delivery
    delay, and how many times the destination received it.
    """

    def __init__(self, network, traffic):
        """
        Create a packet for ``traffic`` inside ``network``.

        ``network`` supplies the node lookup (``nodes``) and the clock
        (``now``).
        """
        self.network = network
        self._traffic = traffic

        # None until the destination first receives the packet, afterwards
        # the number of times it has received it.
        self.recieved = None
        self._stats = defaultdict(int)

    @property
    def source(self):
        """Node the packet was created at."""
        nodes = self.network.nodes
        return nodes[self._traffic.source]

    @property
    def destination(self):
        """Node the packet is addressed to."""
        nodes = self.network.nodes
        return nodes[self._traffic.destination]

    @property
    def created(self):
        """Simulation time at which the packet was created."""
        traffic = self._traffic
        return traffic.created

    @property
    def time_to_live(self):
        """How long after creation the packet remains valid."""
        traffic = self._traffic
        return traffic.time_to_live

    @property
    def deadline(self):
        """Latest simulation time at which the packet is still valid."""
        born, lifetime = self.created, self.time_to_live
        return born + lifetime

    @property
    def expired(self):
        """Whether the current simulation time is past the deadline."""
        return self.deadline < self.network.now

    def sent(self, target, reason=None):
        """
        Record that the packet was transmitted to ``target``.

        Every call counts one hop, plus one hop under 'hops-<reason>' when
        ``reason`` is truthy. Arriving at the destination records the delay
        on the first arrival and bumps the arrival counter every time.
        """
        counters = self._stats
        counters['hops'] += 1
        if reason:
            counters['hops-%s' % reason] += 1

        if target is not self.destination:
            return

        if self.recieved is None:
            counters['delay'] = self.network.now - self.created
            self.recieved = 0
        self.recieved += 1

    @property
    def stats(self):
        """Return this packet's statistics as a plain dict."""
        report = {'recieved': self.recieved} if self.recieved else {}
        report.update(self._stats)
        return report

    def __len__(self):
        """Size of the packet, taken from the traffic payload."""
        size = self._traffic.payload
        return size


class Buffer:
    """
    A place for a node to hold packets.

    Remembers the order in which packets were added and tolerates packets
    being removed while the buffer is being iterated over.
    """

    def __init__(self, **options):
        """
        Create a buffer.

        Keyword arguments:
        capacity -- the maximum number of packets to hold; omitted or None
                    means unbounded (default infinity).
        """
        capacity = options.get('capacity')
        self.capacity = simpy.core.Infinity if capacity is None else capacity
        # used as an insertion-ordered set: packets are keys, values unused
        self.store = OrderedDict()

    @property
    def full(self):
        """Is the buffer full."""
        return len(self.store) >= self.capacity

    def add(self, packet):
        """
        Add a packet to the buffer.

        Returns True when the packet is held afterwards, including when it
        was already present (its original position is kept), and False
        when a new packet does not fit because the buffer is full.
        """
        if packet in self.store:
            # already held: succeed even if the buffer is full
            return True
        if self.full:
            return False
        self.store[packet] = None
        return True

    def remove(self, packet):
        """Remove packet from the buffer; absent packets are ignored."""
        self.store.pop(packet, None)

    def __contains__(self, packet):
        """Is the packet in the buffer."""
        return packet in self.store

    def __iter__(self):
        """
        Return an iterator for packets in the buffer, oldest first.

        Allows for removing items during iteration.
        """
        # iterate over a snapshot so the store may change mid-iteration
        return iter(list(self.store))


class Node:
    """
    Basic node that routes packets directly to their destination.

    Packets wait in the buffer until the node is in contact with the
    packet's destination. Subclasses change the routing policy by
    overriding ``forward`` and the ``send_success``, ``send_failure`` and
    ``packet_expiry`` hooks.
    """

    class SendFailed(Exception):
        """Raised when a send fails."""

    def __init__(self, name, **options):
        """
        Create a node.

        Arguments:
        name -- name of the node

        Keyword Arguments:
        bandwidth -- payload sent per unit of simulation time; a packet
                     takes len(packet) / bandwidth to send (default
                     infinity)
        tick_rate -- simulation time between forwarding rounds (default 1)
        """
        self.name = name

        # set by start(); None means the node is not running yet
        self.network = None

        self.buffer = Buffer()
        self._neighbours = set()

        # caller-supplied options take precedence over the defaults
        self.options = ChainMap(options, {
            'bandwidth': simpy.core.Infinity,
            'tick_rate': 1,
        })

        self.stats = defaultdict(int)
        self.stats['broadcasts'] = 0

    def start(self, network):
        """
        Start event loop.

        Attaches the node to ``network`` and registers ``tick`` as a simpy
        process. If it has already been started do nothing.
        """
        if self.network is not None:
            return

        self.network = network
        self.network.env.process(self.tick())

    def tick(self):
        """
        Thread of execution for a node (simpy process).

        Runs a forwarding round, then sleeps until the next multiple of
        ``tick_rate`` (counted from the round's start) that lies strictly
        after the round finished.
        """
        tick = self.options['tick_rate']
        while True:
            start = self.network.now
            yield from self.forward_all()
            delay = tick - (self.network.now - start) % tick
            yield self.network.env.timeout(delay)

    def join(self, node):
        """
        Add ``node`` to the neighbourhood of this node.

        This is an idempotent operation.
        """
        self._neighbours.add(node)

    def leave(self, node):
        """
        Remove ``node`` from the neighbourhood of this node.

        This is an idempotent operation.
        """
        self._neighbours.discard(node)

    @property
    def neighbours(self):
        """Return nodes within transmission range, as a frozen snapshot."""
        return frozenset(self._neighbours)

    def forward(self, packet):
        """
        Decide where to send ``packet``.

        Returns a dict mapping each target node to the reason for sending
        it there. The basic node only sends directly to the packet's
        destination, and only while the destination is a neighbour.
        """
        if packet.destination in self.neighbours:
            return {packet.destination: 'direct'}
        # forward to nobody
        return {}

    def forward_all(self):
        """
        Try to forward all packets (generator driven by ``tick``).

        Expired packets are handed to ``packet_expiry``. For every other
        packet with at least one target, the node spends
        len(packet) / bandwidth simulation time on a single broadcast, then
        attempts a send to each target and reports each outcome to
        ``send_success`` or ``send_failure``.
        """
        env = self.network.env

        for packet in self.buffer:
            if packet.expired:
                self.packet_expiry(packet)
                continue

            forwards = self.forward(packet)
            if forwards:
                delay = len(packet) / self.options['bandwidth']
                yield env.timeout(delay)
                self.stats['broadcasts'] += 1

            for target, reason in forwards.items():
                try:
                    # the neighbourhood may have changed during the
                    # transmission delay; send() re-checks it
                    self.send(packet, target, reason)
                    self.send_success(packet, target)
                except Node.SendFailed:
                    self.send_failure(packet, target)

    def send(self, packet, target, reason=None):
        """
        Send a packet to another node.

        The target receives the packet with this node as its source, then
        the hop is recorded on the packet.

        Keyword Arguments:
        reason -- why the packet was sent (optional)

        Raises:
        Node.SendFailed -- if target is not a neighbour, or its recv()
                           rejects the packet
        """
        if target not in self.neighbours:
            raise Node.SendFailed('nodes are not neighbours')
        # the receiver is told that this node is the sender
        target.recv(packet, source=self)
        packet.sent(target, reason=reason)

    def send_success(self, packet, target):
        """
        Call when a send succeeds.

        Removes the sent packet from the buffer.
        """
        self.buffer.remove(packet)

    def send_failure(self, packet, target):
        """Call when a send fails; the basic node does nothing."""
        pass

    def packet_expiry(self, packet):
        """Call when a packet expires; drops it from the buffer."""
        self.buffer.remove(packet)

    def recv(self, packet, source=None):
        """
        Receive a packet.

        A packet addressed to this node is consumed rather than buffered;
        any other packet is added to the buffer.

        Keyword Arguments:
        source -- node that sent the packet (None when the network injects
                  new traffic)

        Raises:
        Node.SendFailed -- if the packet must be buffered but the buffer
                           is full
        """
        if packet.destination is self:
            return

        if not self.buffer.add(packet):
            raise Node.SendFailed('buffer full')

    def __repr__(self):
        """Return representation, e.g. ``Node(name='a', **{'tick_rate': 2})``."""
        # show only the options given by the caller; defaults are implied
        return '%s(name=%r, **%r)' % (
            self.__class__.__name__,
            self.name,
            dict(self.options.maps[0]),
        )


class EpidemicNode(Node):
    """
    Node which forwards epidemically.

    Every buffered packet is copied to each neighbour that this node has
    not yet successfully sent it to. The node keeps its own copy after
    sending, so neighbours met later can still receive the packet; the copy
    only leaves the buffer when the packet expires (never, for an infinite
    time to live).
    """

    def __init__(self, name, **options):
        """
        Create an epidemic node.

        Accepts the same arguments as Node; ``name`` may be given
        positionally or as a keyword.
        """
        super().__init__(name, **options)
        # packet -> set of nodes this node has successfully sent it to
        self.__sent = defaultdict(set)

    def forward(self, packet):
        """
        Forward based on the epidemic heuristic.

        Epidemic Heuristic:
        - Forward the packet to all neighbours that I have not forwarded
          to so far.
        """
        # .get() avoids creating a tracking entry just by looking
        already_sent = self.__sent.get(packet, ())
        return {
            neighbour: 'epidemic'
            for neighbour in self.neighbours
            if neighbour not in already_sent
        }

    def send_success(self, packet, target):
        """
        Call when a send succeeds.

        Adds target to packet tracking. Unlike Node.send_success, the
        packet is deliberately kept in the buffer so it keeps spreading to
        new contacts; it is dropped on expiry instead.
        """
        self.__sent[packet].add(target)

    def packet_expiry(self, packet):
        """
        Call when a packet expires.

        Removes the packet from the buffer and from tracking.
        """
        super().packet_expiry(packet)
        self.__sent.pop(packet, None)


# A change in connectivity between the nodes named ``a`` and ``b`` at
# ``time``: ``join`` is truthy when they come into contact and falsy when
# they lose contact.
Contact = namedtuple('Contact', ['time', 'a', 'b', 'join'])


def random_trace(nodes, seed=None, step=1):
    """
    Generate an endless random contact trace.

    Every ``step`` time units, starting at 0, two nodes are sampled without
    replacement and either join or leave each other with equal probability.

    Arguments:
    nodes -- names of the nodes to pick from; a sequence is sampled as-is,
             any other iterable (set, dict keys, ...) is sorted first so a
             given seed stays reproducible

    Keyword Arguments:
    seed -- seed for the random number generator (default None)
    step -- time between consecutive contacts (default 1)
    """
    random = Random(seed)
    # random.sample() only accepts sequences on Python 3.11+
    population = nodes if isinstance(nodes, Sequence) else sorted(nodes)

    for now in count(step=step):
        join = bool(random.getrandbits(1))
        node_a, node_b = random.sample(population, 2)
        yield Contact(now, node_a, node_b, join)


def csv_trace(path):
    """
    Generate contact trace from csv file.

    The file starts with a header row followed by a row of additional
    stats about the trace (duration, placeholder, placeholder,
    node_count); both are skipped. Every remaining non-blank row is
    ``time, a, b, join`` as integers. A file with fewer than two rows
    yields no contacts.
    """
    # the csv module expects files it reads to be opened with newline=''
    with open(path, newline='') as trace_file:
        reader = csv.reader(trace_file)
        # next(reader, None): a bare StopIteration raised inside a
        # generator would surface as RuntimeError (PEP 479)
        next(reader, None)  # header
        next(reader, None)  # additional stats about the trace
        for row in reader:
            if not row:
                # blank line, e.g. a trailing newline at the end of the file
                continue
            now, source, destination, join = map(int, row)
            yield Contact(now, source, destination, join)


# A packet to inject: sent from ``source`` to ``destination`` at time
# ``created``, valid for ``time_to_live`` after creation, with ``payload``
# as its size.
Traffic = namedtuple('Traffic', [
    'source', 'destination', 'created', 'time_to_live', 'payload'
])


def random_traffic(nodes, step=1, seed=None, time_to_live=None, payload=None):
    """
    Generate traffic from random source to random destination every step.

    The generator is endless; creation times are 0, step, 2 * step, ...
    Source and destination are sampled without replacement.

    Arguments:
    nodes -- names of the nodes to pick from; a sequence is sampled as-is,
             any other iterable (set, dict keys, ...) is sorted first so a
             given seed stays reproducible

    Keyword Arguments:
    step -- time between consecutive packets (default 1)
    seed -- seed for the random number generator (default None)
    time_to_live -- lifetime of every packet (default infinity)
    payload -- size of every packet (default 0)
    """
    random = Random(seed)
    # random.sample() only accepts sequences on Python 3.11+
    population = nodes if isinstance(nodes, Sequence) else sorted(nodes)

    if time_to_live is None:
        time_to_live = simpy.core.Infinity

    if payload is None:
        payload = 0

    for created in count(step=step):
        source, destination = random.sample(population, 2)
        yield Traffic(source, destination, created, time_to_live, payload)