# Import Psyco if available
try:
    import psyco
    psyco.full()
except ImportError:
    pass

import sys, os, cPickle, codecs, sha, random, bisect, operator
from utils import *
import Gnuplot, Gnuplot.funcutils


# Not referenced anywhere in the visible code; presumably meant as the number
# of sample points for the CDF plots (TODO confirm or remove).
CDF_STEPS=100

# Single gnuplot session shared by latex(), png() and every plot* function.
# NOTE(review): debug=1 presumably makes Gnuplot.py echo each command it
# sends to gnuplot -- confirm against the Gnuplot.py documentation.
g = Gnuplot.Gnuplot(debug=1)
# Fixed seed so the random ring placement (random.uniform) and the per-query
# random.choice in the plotting functions are reproducible between runs.
random.seed(42)
def latex(f):
    """Save the current gnuplot plot as an epslatex figure.

    *f* is the target ``.tex`` filename; gnuplot's output is set to the same
    name with ".tex" replaced by ".eps".  Tick labels are wrapped in
    ``$...$`` so LaTeX typesets them, and the plot is scaled to .675 of the
    full size.  Afterwards the previous terminal is restored, output goes
    back to the default, and the tick format and size are reset so they do
    not leak into later plots.
    """
    g('set terminal push')
    g('set terminal epslatex color "default" 10')
    g('set format xy "$%g$"')
    g.set_string('output', f.replace(".tex", ".eps"))
    g('set size .675, .675')
    g.refresh()
    # "set terminal push/pop" only saves the terminal, not these settings.
    # Without resetting them, the PNG that latexandpng() writes right after
    # would get literal "$0.5$" tick labels and a shrunken plot, and so
    # would every later plot in this session.
    g('set format xy')
    g('set size 1, 1')
    g('set terminal pop')
    g.set_string('output')
    # NOTE: the generated .tex refers to the figure by bare file name; an
    # earlier version post-processed it with ed to prefix "figures/".

def png(f):
    """Render the current gnuplot plot to the PNG file *f*.

    The active terminal is pushed first and popped afterwards, and output
    is sent back to the default once the file has been written.
    """
    for command in ('set terminal push', 'set terminal png'):
        g(command)
    g.set_string('output', f)
    g.refresh()
    # Undo the temporary terminal/output switch.
    g('set terminal pop')
    g.set_string('output')

def latexandpng(f):
    """Save the current plot twice: as "<f>.tex" via latex(), then "<f>.png" via png()."""
    for suffix, writer in ((".tex", latex), (".png", png)):
        writer(f + suffix)


STOPWORDSFILE = "english.stop"

# Words removed from keyword sets by processQueriesFile/processResultsFile.
# The file is read line by line; surrounding whitespace and apostrophes are
# stripped so entries match the apostrophe-free tokens from getKeywords().
_stopwordsFile = open(STOPWORDSFILE)
try:
    stopwords = set([x.strip().replace("'", "") for x in _stopwordsFile])
finally:
    # Close deterministically instead of relying on refcount collection.
    _stopwordsFile.close()
del _stopwordsFile
# 256-entry byte translation table: ASCII letters and every byte >= 0x80 map
# to themselves, every other byte (digits included) maps to a space.
# NOTE(review): not referenced by the visible code, and unlike `translator`
# it blanks digits -- confirm whether it is still needed.
xchars = [" "] * 256
for c in range(256):
    if c >= 0x80 or ord("a") <= c <= ord("z") or ord("A") <= c <= ord("Z"):
        xchars[c] = chr(c)
xchars = "".join(xchars)
# Characters deleted outright rather than replaced by a space, so that
# "don't" collapses into the single token "dont".
rchars = "'`"


class translatexchars:
    """Ordinal -> replacement mapping for ``unicode.translate``.

    Characters listed in ``rchars`` are deleted (None); ASCII letters,
    ASCII digits and every code point >= 0x80 map to themselves; anything
    else (punctuation, whitespace, control characters) becomes a space.
    """
    def __getitem__(self, c):
        if c < 256:
            ch = chr(c)
            if ch in rchars:
                return None
            if "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z":
                return c
        if c >= 0x80:
            return c            # non-ASCII text is kept untouched
        return u" "
translator = translatexchars()


def getKeywords(s):
    """Return the distinct keywords of *s* that are longer than two characters.

    *s* is normalised with ``translator`` (apostrophes and backticks
    dropped, other ASCII non-alphanumerics turned into spaces), lower-cased
    and split on whitespace.  Stop words are NOT removed here; callers that
    want that subtract ``stopwords`` themselves.  The list comes out in set
    iteration order, i.e. no particular order.
    """
    normalised = s.translate(translator).lower()
    unique = set(normalised.split())
    return [word for word in unique if len(word) > 2]

def getKSSSets(s, K):
    """Yield every non-empty subset of the keywords of *s* with at most K members.

    Keywords come from getKeywords(s).  Each subset is yielded exactly once
    as a frozenset: first all subsets of size 1, then size 2, ... up to K,
    and within one size in lexicographic order of the keywords' positions in
    that list.  If K exceeds the number of keywords it is capped, so the
    whole power set (minus the empty set) is produced.  K <= 0 or a string
    without keywords yields nothing.
    """
    kmd = list(getKeywords(s))
    n = len(kmd)
    if K > n:
        K = n                       # full power set is smaller than K

    for k in range(1, K + 1):       # subsets of size k
        # Strictly increasing indices into kmd, starting at (0, 1, ..., k-1).
        ind = list(range(k))
        while True:
            yield frozenset([kmd[x] for x in ind])
            # Find the rightmost position that can still move right;
            # position i can hold at most n - k + i.
            i = k - 1
            while i >= 0 and ind[i] == n - k + i:
                i -= 1
            if i < 0:
                break               # reached (n-k, ..., n-1): size k done
            # Advance it and reset everything to its right to the smallest
            # strictly increasing continuation.
            ind[i] += 1
            for j in range(i + 1, k):
                ind[j] = ind[j - 1] + 1

def getIndexNames(s, K):
    """Yield one index name per keyword subset of *s* (subset sizes 1..K).

    A name is the subset's keywords in sorted order, each followed by "&"
    (e.g. {"foo", "bar"} -> "bar&foo&").  Sorting makes the name
    independent of set iteration order, so equal subsets always map to the
    same index.  Only the full name is yielded; the sorted prefixes are
    themselves smaller subsets and are yielded when getKSSSets reaches them.
    """
    for kssSet in getKSSSets(s, K):
        yield "".join([x + "&" for x in sorted(kssSet)])
        

def buildInvertedIndex(sourcefilename, outfilename, K):
    """Build a keyword-subset inverted index over a results file and pickle it.

    Each line of *sourcefilename* (UTF-8, tab-separated) describes one file:
    column 0 is the filename; column 3, when present and numeric, is its
    size (otherwise 0).  Lines repeating an earlier (lower-cased filename,
    size) pair are skipped.  Every remaining entry is appended as a
    ``(filename, size)`` tuple to the posting list of each name returned by
    ``getIndexNames(filename, K)``.

    The index (dict: index name -> list of (filename, size)) is pickled to
    *outfilename* every 10000 input lines as a checkpoint and once at the
    end.  Progress (line number, unique entries, index names) is printed at
    each checkpoint.
    """
    index = {}
    seen = set()

    def save():
        # Rewrite the whole checkpoint; close so the data is flushed.
        out = open(outfilename, "w")
        try:
            cPickle.dump(index, out)
        finally:
            out.close()

    src = codecs.open(sourcefilename, mode="r", encoding="utf-8")
    try:
        for i, l in enumerate(src):
            if i % 10000 == 0:
                print("%d %d %d" % (i, len(seen), len(index)))
                save()
            # Drop the line ending so it never ends up in the filename.
            spl = l.rstrip(u"\r\n").split("\t")
            filename = spl[0]
            try:
                size = int(spl[3])
            except (IndexError, ValueError):
                size = 0            # size column missing or not a number

            ident = (filename.lower(), size)
            if ident in seen:
                continue
            seen.add(ident)

            # Always store (filename, size) so every posting has one shape.
            entry = (filename, size)
            for x in getIndexNames(filename, K):
                index.setdefault(x, []).append(entry)
    finally:
        src.close()
    save()

def plotIndexSizeDistribution(invIndexes, numNodes, numVnodes):
    """Plot how inverted-index postings would spread over a DHT ring.

    numNodes * numVnodes ring points are placed at SHA-1 hashes of random
    numbers (each physical node owns numVnodes points).  Every index name is
    assigned to the owner of the first ring point at or after SHA-1(name),
    wrapping around, and that node's load grows by the length of the
    index's posting list.

    invIndexes -- sequence of (invIndex, seriesName) pairs, invIndex mapping
                  index name -> posting list.  The i-th pair is plotted as
                  K = i + 1 in the percentile plots.

    Writes index-load-distribution.{tex,png} (CDF of per-node load scaled so
    the mean is 1) and, for each percentile p in PERCENTILES,
    index-load-<p>-percentile.{tex,png} (that percentile versus K).
    """
    PERCENTILES = [50, 75, 80, 85, 90, 95, 96, 97, 98, 99]
    percentiles = [list() for x in PERCENTILES]

    # Random ring placement; keep the sorted ring points and their owners in
    # parallel lists so successor lookups can use bisect.
    nodes = [(sha.new(str(random.uniform(0,1))).digest(), i)
             for j in range(numVnodes)
             for i in range(numNodes)]
    nodes.sort()
    nodeIDs = [x[0] for x in nodes]
    nodeNums = [x[1] for x in nodes]

    series = []

    for K, (invIndex, seriesName) in enumerate(invIndexes):
        indexLengths = [0] * numNodes

        for indexName, indexContents in invIndex.iteritems():
            # Successor on the ring; past the last point wrap to the first.
            indexHash = sha.new(indexName.encode("utf-8")).digest()
            responsibleNodeIndex = bisect.bisect_left(nodeIDs, indexHash)
            if responsibleNodeIndex == len(nodeIDs):
                responsibleNodeIndex = 0
            indexLengths[nodeNums[responsibleNodeIndex]] += len(indexContents)

        totalLength = sum(indexLengths)
        print("Total length %s ; mean %s ; stddev %s"
              % (totalLength, mean(indexLengths), stddev(indexLengths)))

        # Scale so a perfectly even spread gives every node 1.0 (total is
        # computed once rather than once per node).
        scaledLengths = [float(x) / totalLength * numNodes
                         for x in indexLengths]
        print(scaledLengths)

        # Sample the empirical CDF at mean + x/100 standard deviations for
        # x in [-200, 200), skipping non-positive load values.
        sd = stddev(scaledLengths)
        mu = mean(scaledLengths)
        cdf = []
        for x in range(-200, 200):
            load = float(x) / 100 * sd + mu
            if load <= 0:
                continue
            nodesAtOrBelow = len([v for v in scaledLengths if v <= load])
            cdf.append((load, float(nodesAtOrBelow) / numNodes))
        for num, l in zip(PERCENTILES, percentiles):
            l.append((K + 1, percentile(scaledLengths, num)))
        series.append(Gnuplot.Data(cdf, title=seriesName))

    g('set data style linespoints')
    g('set nologscale')
    g('set key left Left reverse')
    g('set ylabel "Fraction of nodes" 1.5, 0')
    g('set xlabel "Relative query load" 0, .5')
    g.plot(*series)
    latexandpng("index-load-distribution")

    # NOTE(review): the labels say "query load" although the values are
    # posting-list sizes -- confirm the intended wording.
    for num, l in zip(PERCENTILES, percentiles):
        g('set data style linespoints')
        g('set nologscale')
        g('set key left Left reverse')
        # Label string first, then the "1.5, 0" offset (as for the CDF plot).
        g('set ylabel "%dth percentile query load" 1.5, 0' % num)
        g('set xlabel "K" 0, .5')
        g.plot(Gnuplot.Data(l))
        latexandpng("index-load-" + str(num) + "-percentile")

    
def plotQueryLoadDistribution(queryFilename, Kmax, numNodes, numVnodes):
    """Simulate per-node query load on a DHT ring for K = 1..Kmax and plot it.

    numNodes * numVnodes ring points are placed at SHA-1 hashes of random
    numbers (each physical node owns numVnodes points).  For each K, every
    line of *queryFilename* (UTF-8) is expanded with getIndexNames(query, K);
    lines without any index name are skipped.  Otherwise one name is chosen
    at random and the query is charged to the owner of the first ring point
    at or after SHA-1(name), wrapping around.

    Writes query-load-distribution.{tex,png} (one CDF per K of per-node load
    scaled so the mean is 1) and, for each percentile p in PERCENTILES,
    query-load-<p>-percentile.{tex,png} (that percentile versus K).
    """
    PERCENTILES = [50, 75, 80, 85, 90, 95, 96, 97, 98, 99]
    percentiles = [list() for x in PERCENTILES]
    series = []
    # Random ring placement, sorted by ring position.
    nodeIDs = [(sha.new(str(random.uniform(0,1))).digest(), i)
               for j in range(numVnodes)
               for i in range(numNodes)]
    nodeIDs.sort()
    # Parallel lists so the successor can be found with bisect instead of a
    # linear scan over every ring point for every query.
    ringKeys = [entry[0] for entry in nodeIDs]
    ringNodes = [entry[1] for entry in nodeIDs]

    for K in range(1, Kmax + 1):
        queryCounts = [0] * numNodes

        queryFile = codecs.open(queryFilename, encoding="utf-8")
        try:
            for query in queryFile:
                indexes = list(getIndexNames(query, K))
                if not indexes:
                    continue
                indexName = random.choice(indexes)
                indexHash = sha.new(indexName.encode("utf-8")).digest()
                # First ring point >= hash; wrap to the first point if none.
                pos = bisect.bisect_left(ringKeys, indexHash)
                if pos == len(ringKeys):
                    pos = 0
                queryCounts[ringNodes[pos]] += 1
        finally:
            queryFile.close()

        totalCount = sum(queryCounts)
        print("Total length %s ; mean %s ; stddev %s"
              % (totalCount, mean(queryCounts), stddev(queryCounts)))

        # Scale so a perfectly even spread gives every node 1.0.
        scaledCounts = [float(x) / totalCount * numNodes
                        for x in queryCounts]

        # Sample the empirical CDF at mean + x/100 standard deviations for
        # x in [-100, 100), skipping non-positive load values.
        sd = stddev(scaledCounts)
        mu = mean(scaledCounts)
        cdf = []
        for x in range(-100, 100):
            load = float(x) / 100 * sd + mu
            if load <= 0:
                continue
            nodesAtOrBelow = len([v for v in scaledCounts if v <= load])
            cdf.append((load, float(nodesAtOrBelow) / numNodes))
        for num, l in zip(PERCENTILES, percentiles):
            l.append((K, percentile(scaledCounts, num)))
        series.append(Gnuplot.Data(cdf, title="K=%d" % K))

    g('set data style linespoints')
    g('set nologscale')
    g('set key left Left reverse')
    g('set ylabel "Fraction of nodes" 1.5, 0')
    g('set xlabel "Relative query load" 0, .5')
    g.plot(*series)
    latexandpng("query-load-distribution")

    for num, l in zip(PERCENTILES, percentiles):
        g('set data style linespoints')
        g('set nologscale')
        g('set key left Left reverse')
        # Label string followed directly by the "1.5, 0" offset, no comma.
        g('set ylabel "%dth percentile query load" 1.5, 0' % num)
        g('set xlabel "K" 0, .5')
        g.plot(Gnuplot.Data(l))
        latexandpng("query-load-" + str(num) + "-percentile")
    
    

def plotQueryPopularity(queryFilename):
    """Plot query frequency against popularity rank on log-log axes.

    Counts identical (whitespace-stripped) lines of *queryFilename* (UTF-8),
    prints the ten most frequent (query, count) pairs and writes
    query-popularity-gtkg.{tex,png}.
    """
    popularity = {}

    queryFile = codecs.open(queryFilename, encoding="utf-8")
    try:
        for query in queryFile:
            query = query.strip()
            popularity[query] = popularity.get(query, 0) + 1
    finally:
        queryFile.close()

    # Most frequent first; the stable sort keeps dict order among ties.
    items = sorted(popularity.items(), key=operator.itemgetter(1),
                   reverse=True)

    for item in items[:10]:
        print(str(item))

    d = Gnuplot.Data([count for query, count in items])
    g('set data style linespoints')
    g('set logscale xy 10')
    g('set key left Left reverse')
    g('set ylabel "Frequency" 1.5, 0')
    g('set xlabel "Rank" 0, .5')
    g.plot(d)
    latexandpng("query-popularity-gtkg")


def processQueriesFile(inFilename, outFilename):
    """Convert a raw query log into "keywords<TAB>query" lines.

    Each line of *inFilename* (UTF-8) is stripped and reduced to its
    keywords (getKeywords) minus ``stopwords``; queries with no keyword
    left are dropped.  Each output line is UTF-8 encoded: the
    space-separated keywords, a tab, then the stripped original query.
    """
    inFile = codecs.open(inFilename, encoding="utf-8")
    try:
        outFile = open(outFilename, "w")
        try:
            for query in inFile:
                query = query.strip()
                keywords = set(getKeywords(query))
                keywords -= stopwords
                if not keywords:
                    continue
                outFile.write(" ".join(keywords).encode("utf-8") + "\t"
                              + query.encode("utf-8") + "\n")
        finally:
            # Close explicitly so the output is flushed even on error.
            outFile.close()
    finally:
        inFile.close()

def processResultsFile(inFilename, outFilename):
    """Write a shuffled, fixed-width sample of unique results for p2psim.

    Reads *inFilename* (UTF-8, tab-separated: filename in column 0, size in
    column 3 when present and numeric, otherwise 0) and de-duplicates the
    entries on (lower-cased filename, size).  The unique entries are
    shuffled.  Then, until more than about 1.9 GiB of records have been
    written, each entry whose keywords (minus ``stopwords``) are non-empty
    becomes one PADLEN + 1 byte record: "keywords<TAB>name<TAB>" UTF-8
    encoded, space-padded to PADLEN bytes, then "\n".  Entries whose text
    exceeds PADLEN bytes are reported and skipped.

    The name written is the lower-cased filename, because only the
    de-duplication key is kept.
    """
    PADLEN = 1023
    RECORDLEN = PADLEN + 1                  # padded text plus the newline
    MAXBYTES = 1.9 * 1024 * 1024 * 1024

    inFile = codecs.open(inFilename, encoding="utf-8")
    try:
        outFile = open(outFilename, "w")
        try:
            seen = set()
            for l in inFile:
                # Drop the line ending so it never ends up in the filename.
                spl = l.rstrip(u"\r\n").split("\t")
                try:
                    size = int(spl[3])
                except (IndexError, ValueError):
                    size = 0                # size column missing or not a number
                seen.add((spl[0].lower(), size))

            lst = list(seen)
            random.shuffle(lst)

            totalSize = 0
            for name, size in lst:
                if totalSize > MAXBYTES:
                    break

                keywords = set(getKeywords(name))
                keywords -= stopwords
                if not keywords:
                    continue

                s = (" ".join(keywords).encode("utf-8") + "\t"
                     + name.encode("utf-8") + "\t")
                if len(s) > PADLEN:
                    print("string too long")
                    continue
                totalSize += RECORDLEN
                outFile.write(s.ljust(PADLEN))
                outFile.write("\n")
        finally:
            outFile.close()
    finally:
        inFile.close()
    
#buildInvertedIndex("results.txt", "results-inv.data-1", 1)
#buildInvertedIndex("results.txt", "results-inv.data-2", 2)
#buildInvertedIndex("results-100000.txt", "results-100000-inv.data-3", 3)
#buildInvertedIndex("results-100000.txt", "results-100000-inv.data-4", 4)
#buildInvertedIndex("results-100000.txt", "results-100000-inv.data-5", 5)
#plotIndexSizeDistribution([
#    (cPickle.load(file("results-100000-inv.data-1")), "$K$ = 1"),
#    (cPickle.load(file("results-100000-inv.data-2")), "$K$ = 2"),
#    (cPickle.load(file("results-100000-inv.data-3")), "$K$ = 3"),
#    (cPickle.load(file("results-100000-inv.data-4")), "$K$ = 4")
    #(cPickle.load(file("results-100000-inv.data-3")), "$K$ = 5")
#    ], 100, 32)
#raw_input()
#plotIndexSizeDistribution(cPickle.load(file("results-100000-inv.data-3")), 100, 32)
#raw_input()

# Experiment currently enabled: query-load distribution for K = 1..5 over
# 1000 physical nodes with 8 virtual nodes each.
plotQueryLoadDistribution(queryFilename="queries.txt", Kmax=5,
                          numNodes=1000, numVnodes=8)
#plotQueryPopularity("../gtk-gnutella-trace/gtk-gnutella-0.96.3/query.log")
#raw_input()


#processQueriesFile("queries.txt", "../p2psim/queries.txt")
#processResultsFile("results.txt", "../p2psim/results.txt")
