import sys, os, time
import Numeric
from Gnuplot import Gnuplot, Data, Func
from Scientific.Statistics import *
from glob import glob

# Single gnuplot session shared by every plotting/export function below;
# all of them mutate its settings, so state set by one plot persists into
# the next unless explicitly reset.
# NOTE(review): debug=1 presumably makes Gnuplot.py echo each command it
# sends -- confirm against the Gnuplot.py version in use.
# NOTE(review): `global` at module scope is a no-op.
global g
g = Gnuplot(debug=1)

class Point(object):
    """A single sample from a cwnd trace.

    Attributes:
      time    -- sample timestamp, as passed in.
      cwnd    -- congestion window, as passed in.
      packets -- packet count, as passed in.
      rate    -- packets / time as a float.  Constructing a Point with
                 time == 0 raises ZeroDivisionError.
    """
    def __init__(self, time, cwnd, packets):
        self.time = time
        self.cwnd = cwnd
        self.packets = packets
        self.rate = float(packets) / time

def read(filename, low = 0, hi = sys.maxint):
    """Yield a Point for every trace line whose time lies in [low, hi].

    filename -- whitespace-separated text file, one sample per line:
                "<time> <cwnd> <packets>"; extra columns are ignored.
    low, hi  -- inclusive bounds on the sample time.

    Blank lines are skipped.  Points are yielded in file order.
    """
    # Read eagerly and close right away instead of leaving the handle open
    # for the generator's lifetime.  (No `with` statement: this file uses
    # `with=` as a keyword argument, which rules out the with-statement.)
    f = open(filename, "r")
    try:
        lines = f.readlines()
    finally:
        f.close()
    for line in lines:
        s = line.split()
        if not s:
            # A blank line (e.g. a trailing newline at EOF) used to raise
            # IndexError and abort the whole load.
            continue
        p = Point(float(s[0]), float(s[1]), int(s[2]))
        if low <= p.time <= hi:
            yield p
def wait():
    """Pause hook between plots; currently a no-op that returns None.

    (An interactive raw_input() pause used to live here and was disabled.)
    """

#
# Exporters
#
def latex(f):
    """Export the current plot through gnuplot's epslatex terminal.

    f -- path of the LaTeX file to produce, normally ending in ".tex"
         (e.g. "figures/3.4.tex").

    The graphic is written to the same path with ".eps" in place of the
    ".tex" suffix.  After a fixed 3-second wait for gnuplot, the generated
    .tex file is patched so that on each line the first "{<base>}"
    reference (base = f without ".tex") becomes "{<base>.eps}".

    The tick-label format is reset to gnuplot's default afterwards, so later
    exports (e.g. png()) do not inherit the LaTeX "$%g$" labels.
    """
    # Strip only a trailing ".tex"; str.replace would rewrite every
    # occurrence anywhere in the path.
    if f.endswith(".tex"):
        base = f[:-len(".tex")]
    else:
        base = f
    g('set terminal push')
    g('set terminal epslatex color "default" 10')
    g('set format xy "$%g$"')
    g.set_string('output', base + ".eps")
    g.refresh()
    g('set terminal pop')
    g.set_string('output')
    # `set format` is not covered by terminal push/pop; without this reset
    # every subsequent png export would show literal "$...$" tick labels.
    g('set format xy')
    # gnuplot runs as a separate process; give it time to finish writing
    # the .tex file before we rewrite it.
    time.sleep(3)
    # Patch the graphic reference in-process.  The previous version piped an
    # ed script through os.system, which broke on paths containing quotes or
    # commas and required an external `ed` binary.
    old = "{%s}" % base
    new = "{%s.eps}" % base
    src = open(f, "r")
    try:
        lines = src.readlines()
    finally:
        src.close()
    dst = open(f, "w")
    try:
        for line in lines:
            dst.write(line.replace(old, new, 1))
    finally:
        dst.close()

def png(f):
    """Render the current plot of the shared session `g` into PNG file f.

    The active terminal is saved with 'set terminal push' and restored with
    'set terminal pop'.  The output setting is then cleared.
    """
    for command in ('set terminal push', 'set terminal png'):
        g(command)
    g.set_string('output', f)
    g.refresh()
    g('set terminal pop')
    g.set_string('output')


class Entry:
    """One parsed trace line; time, cwnd and packets are all coerced to float."""
    def __init__(self, time, cwnd, packets):
        self.time, self.cwnd, self.packets = [float(v) for v in (time, cwnd, packets)]

def parseFile(filename):
    """Parse a whitespace-separated trace file into a list of Entry objects.

    Each non-blank line must hold at least three numeric fields:
    time, cwnd, packets (extra fields are ignored).  Blank lines are
    skipped.  Entries are returned in file order.
    """
    f = open(filename, "r")
    try:
        lines = f.readlines()
    finally:
        # Close deterministically instead of relying on garbage collection.
        f.close()
    r = []
    for line in lines:
        words = line.split()
        if not words:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing with IndexError.
            continue
        r.append(Entry(float(words[0]), float(words[1]), float(words[2])))
    return r


def parseRTT():
    """Load every RTT-experiment trace.

    Returns a nested list indexed as result[run][conc - 2][port - 6829] for
    runs 0-4 and concurrency levels 2-4.  Each leaf is the list returned by
    parseFile() for results/rtt-run<run>-n<conc>cwnd<port>.
    """
    def trace(run, conc, port):
        return parseFile("results/rtt-run%d-n%dcwnd%d" % (run, conc, port))

    return [[[trace(run, conc, port) for port in range(6829, 6829 + conc)]
             for conc in range(2, 5)]
            for run in range(5)]


def plot11():
    rttData = parseRTT()

    # Plot cwnd vs time
    g('set data style lines')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "cwnd" 1.5, 0')
    g('set xlabel "Time (sec)" 0, .5')
    for conc in range(0,3):
        gpData = []
        for x in range(len(rttData[0][conc])):
            gpData.append(
                Data([[ent.time, ent.cwnd]
                              for ent in rttData[0][conc][x]],
                             title = "Conn " + str(x) + " (delay " + str(0.5*(x+1))+ " s)"))
        g('set title "Window size over time, ' + str(conc+2) + ' concurrent connections"')
        g.plot(*gpData)
        png("figures/1.1.2-n" + str(conc+2) + ".png")
        latex("figures/1.1.2-n" + str(conc+2) + ".tex")

    gpData = []
    for conc in range(0,3):
        concSeries = []
        for port in range(conc+2):
            rtt = 0.5 * (port+1)
            throughputs = [float(rttData[run][conc][port][-1].packets) /
                           rttData[run][conc][port][-1].time
                           for run in range(5)]
            thruMean = mean(throughputs)
            thruSD = standardDeviation(throughputs)
            concSeries.append([rtt, thruMean, thruSD])

        tput = [x[1] for x in concSeries]
        tput = [x/sum(tput) for x in tput]
        print "1.1.1: " + str(conc+2) + " connections: " + str(tput)
        gpData.append(Data(concSeries, title = str(conc+2) + " connections"))
    g('set data style errorlines')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "Throughput (packets/sec)"')
    g('set xlabel "RTT (sec)"')
    g('set title "Throughput vs RTT"')
    g.plot(*gpData)
    png("figures/1.1.3.png")
    latex("figures/1.1.3.tex")

def parseLoss():
    """Load every loss-experiment trace.

    Returns a nested list indexed as result[run][conc - 2][port - 6829] for
    runs 0-4 and concurrency levels 2-4.  Each leaf is the list returned by
    parseFile() for results/loss-run<run>-n<conc>cwnd<port>.
    """
    def trace(run, conc, port):
        return parseFile("results/loss-run%d-n%dcwnd%d" % (run, conc, port))

    return [[[trace(run, conc, port) for port in range(6829, 6829 + conc)]
             for conc in range(2, 5)]
            for run in range(5)]

def plot12():
    lossData = parseLoss()

    # Plot cwnd vs time
    g('set data style linespoints')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "cwnd"')
    g('set xlabel "Time (sec)"')
    for conc in range(0,3):
        gpData = []
        for x in range(len(lossData[0][conc])):
            gpData.append(
                Data([[ent.time, ent.cwnd]
                              for ent in lossData[0][conc][x]],
                             title = "Conn " + str(x) + " (loss " + str(0.05*(x+1))+ ")"))
        g('set title "Window size over time, ' + str(conc+2) + ' concurrent connections"')
        g.plot(*gpData)
        png("figures/1.2.2-n" + str(conc+2) + ".png")
        latex("figures/1.2.2-n" + str(conc+2) + ".tex")

    gpData = []
    for conc in range(0,3):
        concSeries = []
        for port in range(conc+2):
            loss = 0.05 * (port+1)
            throughputs = [float(lossData[run][conc][port][-1].packets) /
                           lossData[run][conc][port][-1].time
                           for run in range(5)]
            thruMean = mean(throughputs)
            thruSD = standardDeviation(throughputs)
            concSeries.append([loss, thruMean, thruSD])
        tput = [x[1] for x in concSeries]
        tput = [x/sum(tput) for x in tput]
        print "1.2.1: " + str(conc+2) + " connections: " + str(tput)
        gpData.append(Data(concSeries, title = str(conc+2) + " connections"))
    g('set data style errorlines')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "Throughput (packets/sec)"')
    g('set xlabel "Loss rate"')
    g('set title "Throughput vs loss rate"')
    g.plot(*gpData)
    png("figures/1.2.3.png")
    latex("figures/1.2.3.tex")

def plot34():
    """Plot the analytic utilization curve against r over [0, 1].

    Exports figures/3.4.png and figures/3.4.tex.  The x range is returned
    to autoscaling afterwards so that later plots on the shared session are
    not clipped to [0:1].
    """
    g('set data style lines')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "Utilization"')
    g('set xlabel "r"')
    g('set xrange [0:1]')
    g.plot(Func('((1 - (1+x)/2)*(2/(1+x) - 1))/2 + ((2/(1+x)) -1)*(1+x)/2 + (2 - (2/(1+x)))'))
    png("figures/3.4.png")
    latex("figures/3.4.tex")
    # `g` is shared and plot21/plot22 never set an x range themselves, so a
    # lingering [0:1] would clip their loss/delay axes.
    g('set autoscale x')


def match(x, pattern):
    """Return True if sequence x matches pattern position by position.

    A None entry in pattern is a wildcard; every other entry must compare
    equal to the corresponding element of x.

    Raises ValueError if x and pattern differ in length.
    """
    if len(x) != len(pattern):
        # Explicit check: an assert disappears under `python -O`, and zip()
        # would then silently ignore the extra positions.
        raise ValueError("key %r and pattern %r differ in length" % (x, pattern))
    for a, b in zip(x, pattern):
        if b is None:
            continue
        if a != b:
            return False
    return True

def extract(d, pattern):
    """Return a new dict holding the entries of d whose key matches pattern.

    Keys are tested with match(): None in pattern is a wildcard, and every
    other position must compare equal.
    """
    selected = [(key, value) for key, value in d.iteritems()
                if match(key, pattern)]
    return dict(selected)

def normalizetime(pts, limit=30):
    """Shift a time series to start at 0 and keep the first `limit` units.

    pts   -- list of [time, value] pairs in any order.  It is no longer
             modified (the old version sorted the caller's list in place).
    limit -- keep only points whose shifted time is <= limit (default 30).

    Returns a new list of [time - t0, value] pairs sorted by time, where t0
    is the smallest time in pts.  An empty input gives an empty list.
    """
    if not pts:
        return []
    # sorted() is stable, so equal timestamps keep their input order, as
    # with the former cmp-based list.sort.
    ordered = sorted(pts, key=lambda pt: pt[0])
    t0 = ordered[0][0]
    return [[t - t0, v] for t, v in ordered if t - t0 <= limit]

def p2readall():
    """Load every part-2 trace under results/ into one flat dict.

    Each file name after "results/p2-" is split on "-" into six fields.
    The code strips fixed-length prefixes from each field.  Presumably the
    names look like
        run<R>-<vtype>-<part>loss<L>-rate<B>-queue<Q>-delay<D>
    but only the prefix lengths are certain -- confirm against the data.

    Returns {key: Point} where key is
        (run, vtype, part, loss, rate, queue, delay10, time, i),
    with delay10 = int(delay * 10) and i = index of the point in its file.

    Raises ValueError if two points produce the same key.
    """
    points = {}  # renamed from `all`, which shadowed the builtin
    prefix = "results/p2-"
    for filename in glob(prefix + "*"):
        x = filename[len(prefix):].split("-")
        run = int(x[0][3:])
        vtype = x[1]
        part = x[2][0]
        loss = int(x[2][5:])
        rate = int(x[3][4:])
        queue = int(x[4][5:])
        # Delay is kept as an integer count of tenths so it can be matched
        # exactly in key patterns.
        delay = int(float(x[5][5:])*10)

        for i, p in enumerate(read(filename)):
            key = (run, vtype, part, loss, rate, queue,
                   delay, p.time, i)
            if key in points:
                # Explicit check: an assert would vanish under `python -O`
                # and silently overwrite data.
                raise ValueError("duplicate trace key %r from %s"
                                 % (key, filename))
            points[key] = p
    return points

def hardcopy(g, filename):
    """Ask gnuplot session g to write its current plot to filename.

    Calls g.hardcopy() with enhanced=1, color=1, eps=True and fontsize=32.
    (The parameter g shadows the module-level session.)
    """
    eps_options = {'enhanced': 1, 'color': 1, 'eps': True, 'fontsize': 32}
    g.hardcopy(filename, **eps_options)

def plot22():
    """Plot cwnd-vs-time figures for part 2.2, configurations a, b and c.

    Uses run 0 of the traces loaded by p2readall().  For each configuration
    it draws one figure per swept parameter value.  Each figure has one line
    per "aimd" path plus one "shared" line and is exported as
    figures/p22-<config>-<NN>.{png,tex} via png() and latex().

    Side effect: sets 'size 0.75,0.75' on the shared gnuplot session and
    does not restore it.
    """
    all = p2readall()  # NOTE(review): shadows the builtin all()
    g('set size 0.75,0.75')
    def pdicttodata(pdict):
        # Turn a {key: Point} dict into a gnuplot Data series of
        # [time, cwnd] pairs, normalized by normalizetime().
        # NOTE(review): this loop prints every key; it looks like leftover
        # debug output and can be very noisy.
        for p in pdict.keys():
            print p
        line = normalizetime([[p.time, p.cwnd] for p in pdict.values()])
        d = Data(line)
        return d
    def hc(x):
        # Export the current plot to figures/p22-<x>.png and .tex.
        fn = "figures/p22-"+x
        png(fn+".png")
        latex(fn+".tex")
        #print >> lout, r"\includegraphics[scale=0.3]{%s}" % fn

    # Configuration a
    # One figure per loss value; one aimd line per path delay (in tenths of a
    # second, shown as delay/10. in the title), plus the shared trace with
    # the same loss and any delay.
    for loss in [0, 5, 10, 15]:
        data = []
        for n, delay in enumerate([5, 10, 15, 20]):
            pdict = extract(all,
                            (0, "aimd", "a", loss, None, None,
                             delay, None, None))
            d = pdicttodata(pdict)
            d.set_option(with="lines", title = "path %d, delay %2.1f sec" % (n, delay/10.))
            data.append(d)
        pdict = extract(all,
                        (0, "shared", "a", loss, None, None,
                         None, None, None))
        d = pdicttodata(pdict)
        d.set_option(with="lines", title = "shared")
        data.append(d)
        # NOTE(review): the doubled backslashes in "\\\\%%" presumably
        # survive gnuplot's string escaping so LaTeX receives "\%" -- confirm.
        g.title("Window evolution over time; Configuration a, %d\\\\%% loss" % (loss))
        g.xlabel("Time (s)")
        g.ylabel("cwnd")
        g.plot(*data)
        hc("a-%02d" % loss)

    # Configuration b
    # One figure per delay value; one aimd line per path loss rate, plus the
    # shared trace with the same delay and any loss.
    for delay in [0, 5, 10, 15]:
        data = []
        for n, loss in enumerate([5, 10, 15, 20]):
            pdict = extract(all,
                            (0, "aimd", "b", loss, None, None,
                             delay, None, None))
            d = pdicttodata(pdict)
            d.set_option(with="lines", title = "path %d, %d\\\\%% loss" % (n, loss))
            data.append(d)
        pdict = extract(all,
                        (0, "shared", "b", None, None, None,
                         delay, None, None))
        d = pdicttodata(pdict)
        d.set_option(with="lines", title = "shared")
        data.append(d)
        g.title("Window evolution over time; Configuration b, delay %2.1f sec" % (delay/10.))
        g.xlabel("Time (s)")
        g.ylabel("cwnd")
        g.plot(*data)
        hc("b-%02d" % delay)

    # Configuration c
    # One figure per loss value; aimd path i (1-based here, unlike a and b)
    # has rate 10*i and queue 5*i; plus the shared trace with the same loss.
    for loss in [0, 5, 10, 15]:
        data = []
        for i in [1, 2, 3, 4]:
            pdict = extract(all,
                            (0, "aimd", "c", loss, 10*i, 5*i,
                             None, None, None))
            d = pdicttodata(pdict)
            d.set_option(with="lines", title = "path %d, bw %d, queue %d" % (i, 10*i, 5*i))
            data.append(d)
        pdict = extract(all,
                        (0, "shared", "c", loss, None, None,
                         None, None, None))
        d = pdicttodata(pdict)
        d.set_option(with="lines", title = "shared")
        data.append(d)
        g.title("Window evolution over time; Configuration c, %d\\\\%% loss" % (loss))
        g.xlabel("Time (s)")
        g.ylabel("cwnd")
        g.plot(*data)
        hc("c-%02d" % loss)

def plot21():
    """Plot mean-throughput figures for part 2.1, configurations a, b and c.

    Each figure has one errorlines series per "aimd" path plus one "shared"
    series.  Every point is [x, mean, standard deviation] of Point.rate,
    taken over all runs (run is a wildcard) and all samples of the matching
    traces from p2readall().  Exports figures/p21-{a,b,c}.{png,tex}.
    """
    all = p2readall()  # NOTE(review): shadows the builtin all()
    def hc(x):
        # Export the current plot to figures/p21-<x>.png and .tex.
        fn = "figures/p21-"+x
        png(fn+".png")
        latex(fn+".tex")
        #print >> lout, r"\includegraphics[scale=0.3]{%s}" % fn

    # A
    # x axis: loss.  One series per path delay (tenths of a second, shown
    # as delay/10.0 in the title); the trailing None selects the shared
    # series, whose pattern then also leaves delay as a wildcard.
    data = []
    for n, delay in enumerate([5, 10, 15, 20, None]):
        line = []
        for loss in [0, 5, 10, 15]:
            if delay is None:
                vtype = "shared"
            else:
                vtype = "aimd"

            pdict = extract(all,
                            (None, vtype, "a", loss, None, None,
                             delay, None, None))
            tput = [x.rate for x in pdict.values()]
            # NOTE(review): the shared rate is divided by 4, presumably to
            # compare it per path against the four aimd paths -- confirm.
            if delay == None:
                tput = [x/4. for x in tput]
            # NOTE(review): an empty pdict (no matching traces) gives an
            # empty tput; mean() behaviour on [] is not checked here.
            line.append([loss, mean(tput), standardDeviation(tput)])

        d = Data(line)
        d.set_option(with="errorlines")
        if delay is None:
            d.set_option(title="shared")
        else:
            d.set_option(title = "path %d, delay = %2.1fs" % (n, delay/10.0))
        data.append(d)

    g.title("Configuration a, bandwidth = 20 packets/s, queue = 10 packets")
    g.xlabel("loss rate (\\\\%)")
    g.ylabel("Throughput (packets/s)")
    g.plot(*data)
    hc("a")


    # B
    # x axis: delay (tenths of a second).  One series per path loss rate;
    # the trailing None selects the shared series with loss as a wildcard.
    data = []
    for n, loss in enumerate([5, 10, 15, 20, None]):
        line = []
        for delay in [0, 5, 10, 15]:
            if loss is None:
                vtype = "shared"
            else:
                vtype = "aimd"

            pdict = extract(all,
                            (None, vtype, "b", loss, None, None,
                             delay, None, None))
            tput = [x.rate for x in pdict.values()]
            if loss == None:
                tput = [x/4. for x in tput]
            line.append([delay, mean(tput), standardDeviation(tput)])

        d = Data(line)
        d.set_option(with="errorlines")
        if loss is None:
            d.set_option(title="shared")
        else:
            d.set_option(title = "path %d, loss = %d\\\\%%" % (n, loss))
        data.append(d)

    g.title("Configuration b, bandwidth = 20 packets/s, queue = 10 packets")
    g.xlabel("delay (ds)")
    g.ylabel("Throughput (packets/s)")
    g.plot(*data)
    hc("b")


    # C
    # x axis: loss.  Path n (1-4) has rate 10*n and queue 5*n; None selects
    # the shared series with rate and queue as wildcards.
    data = []
    for n in [1,2,3,4, None]:
        line = []
        for loss in [0, 5, 10, 15]:
            if n is None:
                vtype = "shared"
                bw = queue = None
            else:
                vtype = "aimd"
                bw = 10*n
                queue = 5*n

            pdict = extract(all,
                            (None, vtype, "c", loss, bw, queue,
                             None, None, None))
            tput = [x.rate for x in pdict.values()]
            if n == None:
                tput = [x/4. for x in tput]
            line.append([loss, mean(tput), standardDeviation(tput)])

        d = Data(line)
        d.set_option(with="errorlines")
        if n is None:
            d.set_option(title="shared")
        else:
            d.set_option(title = "path %d, bandwidth = %d packets/s, queue = %d packets" % (n, bw, queue))
        data.append(d)

    g.title("Configuration c, delay 5 ds")
    g.xlabel("loss rate (\\\\%)")
    g.ylabel("Throughput (packets/s)")
    g.plot(*data)
    hc("c")
        

def main():
    """Regenerate the currently selected figure set.

    Only plot21 is enabled; plot11, plot12, plot34 and plot22 also exist
    and can be added to `enabled`.
    """
    enabled = (plot21,)
    for plot in enabled:
        plot()

if __name__ == "__main__":
    main()
