import math, time, os
from jarray import zeros, array
from Gnuplot import Gnuplot, Data

from java.lang import Comparable, String

from edu.mit.six825.bn.functiontable import *
from edu.mit.six825.bn.bayesnet import *
from edu.mit.six825.bn.inputs import *
from techniques.utils import *
import VESolver
import LWSolver
import GibbsSolver

# Some useful constants

# Boolean node values used as evidence in the queries below
# (ComparableBoolean comes from one of the wildcard imports above).
TRUE = ComparableBoolean.TRUE
FALSE = ComparableBoolean.FALSE
# The three example Bayes nets queried by this script; they are built once,
# when the module is loaded.
burglary = Nets.getBurglary()
insurance = Nets.getInsurance()
carpo = Nets.getCarpo()

# Global gnuplot instance shared by every plotting/export function below.
# A `global` statement at module level is a no-op; it only marks intent.
# debug = 1 turns on Gnuplot.py's debug mode (echoes the commands it sends).
global g
g = Gnuplot(debug = 1)

# Utility functions
def enumerate(lst):
    """Pair every element of the sequence lst with its index.

    Returns a list (not an iterator) of (index, element) tuples, e.g.
    enumerate(['a', 'b']) == [(0, 'a'), (1, 'b')].  Elements are fetched
    with lst[i], so lst must support len() and integer indexing.  Presumably
    defined here because older Python/Jython releases lack the builtin.
    """
    pairs = []
    count = len(lst)
    pos = 0
    while pos < count:
        pairs.append((pos, lst[pos]))
        pos = pos + 1
    return pairs


# mean/stddev, adapted from
# http://www.atnf.csiro.au/people/Enno.Middelberg/python/avg.p
def stat(numbers):
    """Compute summary statistics for a non-empty sequence of numbers.

    Returns a tuple (mean, sx, sigmax, total, median):
      mean   -- arithmetic mean, always computed in floating point
      sx     -- sample standard deviation (divides by n - 1)
      sigmax -- population standard deviation (divides by n)
      total  -- sum of all values
      median -- middle value of the sorted data; the average of the two
                middle values when n is even

    For a single value both deviations are 0.  The caller's sequence is not
    modified.  Raises ValueError if numbers is empty.
    """
    n = len(numbers)
    if n == 0:
        raise ValueError("stat() requires at least one number")

    # Explicit loop rather than sum(): this file targets old Jython releases
    # (it even supplies its own enumerate()).
    total = 0
    for value in numbers:
        total = total + value
    # float() prevents truncating integer division under Python 2.
    mean = float(total) / n

    if n > 1:
        sqdev = 0.0
        for value in numbers:
            sqdev = sqdev + (value - mean) ** 2
        sx = math.sqrt(sqdev / (n - 1))
        sigmax = math.sqrt(sqdev / n)
    else:
        sx = 0
        sigmax = 0

    # Sort a copy so the caller's list keeps its original order.
    ordered = list(numbers)
    ordered.sort()
    # divmod gives an integer index under both Python 2 and 3.
    half, odd = divmod(n, 2)
    if odd:
        median = ordered[half]
    else:
        median = (ordered[half - 1] + ordered[half]) / 2.0

    return mean, sx, sigmax, total, median

def avg(nums):
    """Return the arithmetic mean of nums, as reported by stat()."""
    mean, _sx, _sigmax, _total, _median = stat(nums)
    return mean
def stddev(nums):
    """Return the population standard deviation of nums (divides by n).

    This is the sigmax field of stat(); the sample deviation (n - 1) is
    the sx field.
    """
    _mean, _sx, population_sd, _total, _median = stat(nums)
    return population_sd


# Gnuplot exporters

def latex(f):
    """Export the current gnuplot plot as an epslatex figure.

    f is the .tex file to produce.  The graphic is written to the same path
    with ".tex" replaced by ".eps".  The active terminal is pushed first and
    popped afterwards.  The tick-label format is reset at the end, so later
    exports such as png() are unaffected.
    """
    g('set terminal push')
    #g('set terminal latex 10')
    g('set terminal epslatex color "default" 10')
    # Wrap tick labels in $...$ so LaTeX typesets them in math mode.
    g('set format xy "$%g$"')
    g.set_string('output', f.replace(".tex",".eps"))
    #g('set size 1,1')
    #g('set size 3/5, 3/5')
    g.refresh()
    g('set terminal pop')
    # `set terminal push/pop` saves only the terminal, not `set format`.
    # Restore gnuplot's default tick format; otherwise every later png()
    # export prints literal "$...$" around each tick label.
    g('set format xy')
    g.set_string('output')
    # Darnit, fix the path in the output file.
    # NOTE(review): the fixed sleep presumably gives gnuplot time to finish
    # writing the .tex file before ed edits it.  Nothing actually
    # synchronizes the two, so a slow gnuplot could still race.
    time.sleep(3)
    # Rewrite "{<base>}" to "{<base>.eps}" (first match on each line) in the
    # .tex file and save it.  f is interpolated into a shell command
    # unquoted, so paths containing spaces or quotes will break.
    os.system("echo ',s,{%(b)s},{%(b)s.eps},\nw' | ed %(e)s" %
              {'b': f.replace(".tex",""), 'e': f})

def png(f):
    """Export the current gnuplot plot to the PNG file f.

    The active terminal is pushed and switched to png, and the plot is
    redrawn into f.  Then the previous terminal is popped and the output
    setting is cleared (set_string with no value).
    """
    for command in ('set terminal push', 'set terminal png'):
        g(command)
    g.set_string('output', f)
    g.refresh()
    g('set terminal pop')
    g.set_string('output')

# Convert Python strings to Java strings
def stringify(x):
    """Wrap a plain Python string in java.lang.String; return anything else as is.

    The check compares the exact type with that of a string literal, so
    subclasses and other string-like types pass through unchanged.
    """
    if type(x) != type(""):
        return x
    return String(x)


# Represents all the parameters to a query
class Query:
    """All the parameters of one inference query against a Bayes net.

    Attributes:
      bn       -- the network passed in
      queryVar -- the node named queryName, from bn.nodes.getNode
      evidence -- an Assignment of the evidence values to the variables
                  (the .var of each named evidence node)
    """
    def __init__(self, bn, queryName, evidenceList):
        # evidenceList is a sequence whose items are indexed as
        # item[0] = node name and item[1] = value.  Plain Python strings
        # among the values are converted with stringify().
        self.bn = bn
        self.queryVar = bn.nodes.getNode(queryName)

        evidenceVars = []
        for item in evidenceList:
            evidenceVars.append(bn.nodes.getNode(item[0]).var)

        evidenceValues = []
        for item in evidenceList:
            evidenceValues.append(stringify(item[1]))

        varSet = FunctionVariableSet(array(evidenceVars, FunctionVariable))
        self.evidence = Assignment(varSet, array(evidenceValues, Comparable))

def queryResToDict(res):
    """Convert a single-variable query result into a plain dict.

    res must be defined over exactly one variable.  For every value x in
    that variable's domain, the returned dict maps str(x) to
    res.evaluate() of the assignment "variable = x".

    Raises ValueError if res is not over exactly one variable.
    """
    if res.variables.size() != 1:
        raise ValueError("expected a result over exactly one variable, got %d"
                         % res.variables.size())
    var = res.variables.getVariable(0)
    domain = var.domain
    d = {}
    for i in range(domain.size()):
        x = domain.getValue(i)
        # Previously named `as`, which is a reserved word from Python 2.6
        # (and Jython 2.7) on, so the block no longer parsed there.
        assignment = Assignment(
            FunctionVariableSet(array([var], FunctionVariable)),
            array([stringify(x)], Comparable))
        d[str(x)] = res.evaluate(assignment)

    return d

def runTest(solverType, query,
            orderer=None,
            weight=10.0,
            samples=10000,
            discard=5000,
            randomOrder=0):
    """Run one query with the given solver class and return its answer.

    solverType  -- VESolver, LWSolver or GibbsSolver.  It is constructed with
                   query.bn and given query.evidence.
    query       -- a Query instance
    orderer     -- elimination orderer for VESolver; None (the default)
                   means a fresh VESolver.GreedyOrder()
    weight      -- passed to LWSolver.setRequiredWeight
    samples     -- passed to GibbsSolver.setRequiredSamples
    discard     -- passed to GibbsSolver.setDiscardPrefixLen
    randomOrder -- passed to GibbsSolver.randomFlipChoice
    Options that do not apply to solverType are ignored.

    Returns the answer for query.queryVar as a dict keyed by str(value)
    (see queryResToDict).
    """
    solver = solverType(query.bn)
    solver.setEvidence(query.evidence)
    if solverType == VESolver:
        if orderer is None:
            # Built per call instead of once as a default argument value, so
            # separate runs never share one orderer instance.
            orderer = VESolver.GreedyOrder()
        solver.setOrderer(orderer)
    elif solverType == LWSolver:
        solver.setRequiredWeight(weight)
    elif solverType == GibbsSolver:
        solver.setRequiredSamples(samples)
        solver.setDiscardPrefixLen(discard)
        solver.randomFlipChoice(randomOrder)

    return queryResToDict(solver.query(query.queryVar))

def kullbackLeibler(exactDist, approxDist, epsilon=0.000001):
    """Return the Kullback-Leibler divergence D(exact || approx) in bits.

    Both arguments map outcome keys to probabilities, and the sum runs over
    the keys of exactDist.  Terms whose exact probability is 0 contribute
    nothing (the 0 * log 0 = 0 convention).  An approximate probability that
    is 0, or missing from approxDist, is replaced by epsilon so the result
    stays finite.
    """
    totalDiv = 0.0
    for x in exactDist.keys():
        p = float(exactDist[x])
        if p == 0:
            # p * log(p / q) -> 0 as p -> 0; math.log(0) would raise instead.
            continue
        q = approxDist.get(x, 0)
        if q == 0:          # Avoid division by zero / log of zero
            q = epsilon
        totalDiv = totalDiv + p * math.log(p / q) / math.log(2)
    return totalDiv


def plotGibbsBurnin(query, filename, titleinfo):
    """Plot Gibbs-sampling accuracy as a function of the burn-in fraction.

    For each sample count in sampleCounts and each burn-in fraction
    0.0, 0.1, ..., 0.8, runs GibbsSolver `runs` times with that fraction of
    the sample count discarded.  The mean KL divergence (bits) from the
    exact variable-elimination answer is plotted, with stddev() error bars.
    The figure is written to filename + ".png" and filename + ".tex".

    query     -- a Query instance
    filename  -- output path without extension
    titleinfo -- text appended to the plot title
    """
    burninFractions = [0.1 * x for x in range(0,9)]
    sampleCounts = [10000]
    runs = 5

    exactDist = runTest(VESolver, query)

    g('set data style errorlines')
    g('set key left Left reverse')
    g('set lmargin 5')
    g('set ylabel "KL divergence" 1.5, 0')
    g('set xlabel "Fraction of samples discarded" 0, .5')
    # Plain %s like the other plots; the former "\%s" left a stray backslash
    # in front of titleinfo.
    g('set title "Divergence vs burnin period --- %s"' % titleinfo)
    gpData = []
    for sampleCount in sampleCounts:
        series = []
        for burninFraction in burninFractions:
            # NOTE(review): samples stays fixed while discard grows.  An
            # earlier variant scaled samples by (1 + burninFraction); whether
            # discarded samples count toward setRequiredSamples depends on
            # GibbsSolver -- confirm.
            divergences = []
            for run in range(runs):
                approxDist = runTest(GibbsSolver, query,
                                     samples = int(sampleCount),
                                     discard = int(sampleCount *
                                                   burninFraction))
                divergences.append(kullbackLeibler(exactDist, approxDist))
            divMean = avg(divergences)
            divSD = stddev(divergences)
            series.append([burninFraction, divMean, divSD])
        gpData.append(Data(series, title=str(sampleCount) + " samples"))
    g.plot(*gpData)
    png(filename + ".png")
    latex(filename + ".tex")

def plotLWQuality(query, filename, titleinfo):
    """Plot likelihood-weighting accuracy against the required weight.

    For each of numRuns runs and each value in requiredWeights, runs
    LWSolver and records the KL divergence (bits) of its answer from the
    exact variable-elimination answer.  Writes two figures:
      filename + ".png"/".tex"         -- one line per run
      filename + "-agg.png"/"-agg.tex" -- per-weight avg() with stddev()
                                          error bars
    """
    requiredWeights = [10, 20, 40, 100, 200, 400, 1000, 2000, 4000, 10000]
    numRuns = 10

    exact = runTest(VESolver, query)

    for command in ['set size 0.75,0.75',
                    'set logscale x 10',
                    'set data style linespoints',
                    'set key left Left reverse',
                    'set lmargin 5',
                    'set ylabel "KL divergence" 1.5, 0',
                    'set xlabel "Samples performed" 0, .5',
                    'set title "LW quality vs samples --- %s"' % titleinfo,
                    'set yrange [0:*]',
                    'set xrange [10:10000]']:
        g(command)

    # One plotted line per run; divergences are also collected per weight
    # for the aggregate figure.
    runPlots = []
    divsPerWeight = [[] for w in requiredWeights]
    for r in range(numRuns):
        points = []
        for idx, w in enumerate(requiredWeights):
            kl = kullbackLeibler(exact, runTest(LWSolver, query, weight = w))
            points.append([w, kl])
            divsPerWeight[idx].append(kl)
        runPlots.append(Data(points, title="Run " + str(r)))
    g.plot(*runPlots)
    png(filename + ".png")
    latex(filename + ".tex")

    g('set data style errorlines')
    summary = [[w, avg(divs), stddev(divs)]
               for divs, w in zip(divsPerWeight, requiredWeights)]
    g.plot(Data(summary))
    png(filename + "-agg.png")
    latex(filename + "-agg.tex")

def plotGibbsQuality(query, filename, titleinfo):
    """Plot Gibbs-sampling accuracy against the number of samples.

    For each of numRuns runs and each count in sampleCounts, runs
    GibbsSolver with that many samples, discarding int(burnFraction * count)
    of them.  Records the KL divergence (bits) of its answer from the exact
    variable-elimination answer.  Writes two figures:
      filename + ".png"/".tex"         -- one line per run
      filename + "-agg.png"/"-agg.tex" -- per-count avg() with stddev()
                                          error bars
    """
    sampleCounts = [100, 200, 400, 1000, 2000, 4000, 10000, 20000, 40000,
                    100000, 200000, 400000, 1000000]
    numRuns = 10
    burnFraction = 0.4

    exact = runTest(VESolver, query)

    for command in ['set size 0.75,0.75',
                    'set logscale x 10',
                    'set data style linespoints',
                    'set key left Left reverse',
                    'set lmargin 5',
                    'set ylabel "KL divergence" 1.5, 0',
                    'set xlabel "Samples performed" 0, .5',
                    'set title "Gibbs quality vs samples --- %s"' % titleinfo,
                    'set yrange [0:*]',
                    'set xrange [100:1000000]']:
        g(command)

    # One plotted line per run; divergences are also collected per sample
    # count for the aggregate figure.
    runPlots = []
    divsPerCount = [[] for c in sampleCounts]
    for r in range(numRuns):
        points = []
        for idx, count in enumerate(sampleCounts):
            approx = runTest(GibbsSolver, query,
                             samples = count,
                             discard = int(burnFraction * count))
            kl = kullbackLeibler(exact, approx)
            points.append([count, kl])
            divsPerCount[idx].append(kl)
        runPlots.append(Data(points, title="Run " + str(r)))
    g.plot(*runPlots)
    png(filename + ".png")
    latex(filename + ".tex")

    g('set data style errorlines')
    summary = [[count, avg(divs), stddev(divs)]
               for divs, count in zip(divsPerCount, sampleCounts)]
    g.plot(Data(summary))
    png(filename + "-agg.png")
    latex(filename + "-agg.tex")


# Queries used by the experiments below.  Each one is
# Query(net, queryNodeName, [(evidenceNodeName, value), ...]).

# Burglary net: P(Burglary | JohnCalls, MaryCalls) and
# P(Earthquake | Burglary, JohnCalls).
query11 = Query(burglary, "Burglary",
                [("JohnCalls", TRUE), ("MaryCalls", TRUE)])
query12 = Query(burglary, "Earthquake",
                [("Burglary", TRUE), ("JohnCalls", TRUE)])
# Insurance net: P(PropCost | Age=Adolescent, Airbag=False,
# Mileage=TwentyThou).
query13 = Query(insurance, "PropCost",
                [("Age", "Adolescent"), ("Airbag", "False"),
                 ("Mileage", "TwentyThou")])
# query21-query24 are the queries plotted at the bottom of this file as
# "Query 1" through "Query 4".
query21 = Query(insurance, "PropCost",
                [("Age", "Adolescent"), ("Airbag", "False"),
                 ("MakeModel", "Luxury"), ("Mileage", "TwentyThou")])
query22 = Query(insurance, "PropCost",
                [("Age", "Adolescent"), ("Airbag", "False"),
                 ("GoodStudent", "True"), ("Mileage", "TwentyThou")])
query23 = Query(carpo, "N104",
                [("N116", "0"), ("N41", "2"), ("N84", "1")])
query24 = Query(carpo, "N73",
                [("N116", "0"), ("N152","1"), ("N43", "1")])


# ---------------------------------------------------------------------------
# Experiment driver.  These statements sit at module level (there is no
# `if __name__ == "__main__"` guard), so they run whenever this file is
# executed or imported.  They write figures under figures/.  The
# commented-out calls are earlier experiments kept for reference.

# exact = runTest(VESolver, query11, orderer=VESolver.GreedyOrder())
# lw = runTest(LWSolver, query11, weight=5)

# print "Exact:", exact
# print "LW:", lw
# print "KL divergence: ", kullbackLeibler(exact, lw)
#print runTest(GibbsSolver, query13)

# plotGibbsBurnin(query11, "figures/gibbs-burnin-query11",
#                  "P(Burglary | John,Mary)")
# plotGibbsBurnin(query12, "figures/gibbs-burnin-query12",
#                  "P(Earthquake | Burglary,John)")
# plotGibbsBurnin(query13, "figures/gibbs-burnin-query13",
#                 "P(PropCost | Age=Adolescent, Airbag=False, Mileage=TwentyThou)")
# plotGibbsBurnin(query21, "figures/gibbs-burnin-query21",
#                 "P(PropCost | Age=Adolescent, Airbag=False, MakeModel=Luxury, Mileage=TwentyThou)")
# plotGibbsBurnin(query22, "figures/gibbs-burnin-query22",
#                 "P(PropCost | Age=Adolescent, Airbag=False, GoodStudent=True, Mileage=TwentyThou)")
# plotGibbsBurnin(query23, "figures/gibbs-burnin-query23",
#                 "P(N104 | N116=0, N41=2, N84=1)")
# plotGibbsBurnin(query24, "figures/gibbs-burnin-query24",
#                 "P(N73 | N116=0, N152=1, N43=1)")


# Likelihood-weighting quality: per-run and aggregate figures for each query.
plotLWQuality(query21, "figures/lw-quality-query21",
              "Query 1")
plotLWQuality(query22, "figures/lw-quality-query22",
              "Query 2")
plotLWQuality(query23, "figures/lw-quality-query23",
              "Query 3")
plotLWQuality(query24, "figures/lw-quality-query24",
              "Query 4")

# Gibbs quality: each call does 10 runs at sample counts up to 1,000,000,
# so this is the long-running part of the script.
plotGibbsQuality(query21, "figures/gibbs-quality-query21",
              "Query 1")
plotGibbsQuality(query22, "figures/gibbs-quality-query22",
              "Query 2")
plotGibbsQuality(query23, "figures/gibbs-quality-query23",
              "Query 3")
plotGibbsQuality(query24, "figures/gibbs-quality-query24",
              "Query 4")

