Skip to content
Snippets Groups Projects
display.py 2.1 KiB
Newer Older
ArktikHunter's avatar
ArktikHunter committed
"""
Example script: run the DTN simulations on each taxi dataset and plot the results graphically.
"""

# Public API of this script: only the entry point. The previous value,
# 'display', named nothing defined in this module, so a star-import failed.
__all__ = [
    'main',
]

__author__ = "Hunter McConnell <hunter.mcconnell@usask.ca>"


import sys
import matplotlib.pyplot as plt
from multiprocessing import Pool

from shed import Simulation, run_simulation
from pydtnsim import Network, RandomTraffic, Node, EpidemicNode, CSVTrace
from pydtnsim.community import BubbleKCliqueNode, BubbleLouvainNode
from pydtnsim.community import HCBFKCliqueNode, HCBFLouvainNode


# Run one simulation per node type on each taxi dataset, then plot the results together.

def main():
    """Simulate every node type on every taxi trace and plot the results.

    One ``Simulation`` is built per (trace, node type) pair, and the batch is
    run in parallel on a process pool via ``run_simulation``. Each returned
    stats dict is printed. The delivery ratio and delivery cost of every node
    type are then plotted against the trace's distance threshold, on two
    stacked subplots that share the x-axis.

    NOTE(review): assumes each stats dict returned by ``run_simulation``
    carries the ``"node_type"``, ``"trace"``, ``"delivery-ratio"`` and
    ``"delivery-cost"`` keys, and that ``"trace"`` echoes the exact path the
    simulation was built with -- confirm against ``shed``.
    """
    # Each trace is paired with the distance threshold (m) it corresponds to.
    # The threshold is also the x-coordinate of that trace's plotted points,
    # so keeping both in one table stops the axis drifting out of sync with
    # the trace list when datasets are enabled or disabled.
    datasets = [
        (10, '../../processed/taxi10'),
        (20, '../../processed/taxi20'),
        (50, '../../processed/taxi50'),
        # (100, '../../processed/taxi100'),  # uncomment to add this dataset
    ]

    node_types = [
        Node,               # direct delivery
        EpidemicNode,
        BubbleKCliqueNode,
        HCBFKCliqueNode,
        BubbleLouvainNode,
        HCBFLouvainNode,
    ]

    simulations = [
        Simulation(trace=trace, node_type=node_type, seed=None)
        for _, trace in datasets
        for node_type in node_types
    ]

    threshold_of = {trace: threshold for threshold, trace in datasets}

    # node type -> {distance threshold: stats}
    results = {}

    # The context manager shuts the worker processes down once all
    # simulations have been consumed, instead of leaving them running.
    with Pool() as pool:
        # imap yields results in submission order while the workers still run
        # in parallel, so node types reach the legend in the same order on
        # every run.
        for stats in pool.imap(run_simulation, simulations):
            by_threshold = results.setdefault(stats["node_type"], {})
            by_threshold[threshold_of[stats["trace"]]] = stats
            print(stats)

    # graph the results

    _, (ax1, ax2) = plt.subplots(2, sharex=True)

    ax1.set_ylabel("Delivery Ratio")
    ax2.set_ylabel("Delivery Cost")
    ax2.set_xlabel("Distance threshold (m)")

    for node_type, by_threshold in results.items():
        # Sort so each line is drawn left to right. The x values come from
        # the data actually collected, so they always match the y values.
        xs = sorted(by_threshold)
        ax1.plot(xs, [by_threshold[x]["delivery-ratio"] for x in xs],
                 'o-', label=node_type)
        ax2.plot(xs, [by_threshold[x]["delivery-cost"] for x in xs],
                 'o-', label=node_type)

    # A single legend on the lower subplot covers both, since every node type
    # is drawn on each.
    ax2.legend()
    plt.show()



if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the process exit
    # status via SystemExit (a None return means success, status 0).
    raise SystemExit(main())