import sys, os, socket

SFSPATH = "/usr/home/dan/chord-build/sfsnet"
sys.path.append(os.path.join(SFSPATH, "devel/"))
sys.path.append(os.path.join(SFSPATH, "svc/"))
sys.path.append(os.path.join(SFSPATH, "arpeggio/"))

from twisted.python import util
from twisted.internet import reactor,defer,protocol

import RPC
import chord_types, cd_prot

from utils import *

# Path of the Chord daemon (cd) executable that Chord.startChord() spawns.
CD_PATH = "./cd"
# When True, lookup()/getSuccList()/getPredList() print verbose tracing.
DEBUG = False

def itos(val):
    """Pack the low 32 bits of an integer into a four-character string,
    most significant byte first.

    The Python sockets API is inconsistent about whether it wants an
    address as an integer or as a packed four-character string (e.g.
    socket.inet_ntoa takes the latter); this helper converts from the
    integer form to the string form."""
    return "".join([chr((val >> shift) & 0xFF) for shift in (24, 16, 8, 0)])

#
# Utility functions for Chord ring intervals
#

def between(a, b, n):
    """Return True if n lies strictly inside the arc (a, b) on the
    identifier circle, going from a up to b and wrapping past zero if
    needed.  When a == b the arc covers every point except a itself."""
    if a == b:
        return n != a
    if a < b:
        # Ordinary arc: no wrap-around.
        return n > a and n < b
    # The arc wraps past the top of the circle back to zero.
    return n > a or n < b

def betweenLeftIncl(a, b, n):
    """Return True if n lies in the half-open arc [a, b) on the
    identifier circle: a is included, b is not."""
    if n == a:
        return True
    return between(a, b, n)

def betweenRightIncl(a, b, n):
    """Return True if n lies in the half-open arc (a, b] on the
    identifier circle: b is included, a is not."""
    if n == b:
        return True
    return between(a, b, n)

def betweenBothIncl(a, b, n):
    """Return True if n lies in the closed arc [a, b] on the identifier
    circle: both endpoints are included."""
    if n == a or n == b:
        return True
    return between(a, b, n)


class ChordNode:
    """Representation of a Chord node's location: ID, host, port, and
    vnode number.

    Two ChordNodes are equal (and hash equally) when all four fields
    match, so they can be used in sets and as dict keys."""
    @typechecked(object, str, int, int, chord_types.bigint)
    def __init__(self, host, port, vnodeID, chordID):
        self.host = host
        self.port = port
        self.vnodeID = vnodeID
        self.chordID = chordID

    @staticmethod
    @typechecked(cd_prot.chord_node_wire_plus_id)
    def fromWirePlusID(wireID):
        """Build a ChordNode from a cd_prot.chord_node_wire_plus_id.

        The wire form packs the IPv4 address as an integer (converted to
        dotted form via itos + socket.inet_ntoa) and packs the port into
        the upper 16 bits and the vnode number into the lower 16 bits of
        machine_order_port_vnnum."""
        wire = wireID.wire
        host = socket.inet_ntoa(itos(wire.machine_order_ipv4_addr))
        port = wire.machine_order_port_vnnum >> 16
        vnodeID = wire.machine_order_port_vnnum & 0xFFFF
        return ChordNode(host, port, vnodeID, wireID.id)

    def __repr__(self):
        return "%s @ %s:%d #%d" % (self.chordID, self.host, self.port,
                                   self.vnodeID)

    def __eq__(self, other):
        # Anything lacking the four location fields (None, strings, ...)
        # is simply "not comparable" rather than an AttributeError.
        try:
            return (self.host == other.host and
                    self.port == other.port and
                    self.vnodeID == other.vnodeID and
                    self.chordID == other.chordID)
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method
        # != falls back to identity, so two equal nodes would compare
        # as "not equal".
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Hashes the same four fields that __eq__ compares.
        return genericHash([self.host, self.port,
                            self.vnodeID, self.chordID])
        
#
# Note on client/server semantics:
#
#  The approach this takes is somewhat backwards.  Although cd is the
#  RPC server, it must be launched by cc (the side running this
#  module) and have a one-to-one connection with it.  So cc acts as a
#  one-shot transport-level server: it listens on a UNIX domain
#  socket, passes that socket's path to cd on the command line, and
#  accepts the single connection cd makes back to it on startup.
#    
#  It really does make sense if you think of the transport
#  server/client and the RPC server/client as separate things: at the
#  transport level cc is the server, while at the RPC level cd is.
#  Really.  (If the X Window System can abuse the terms server and
#  client, so can we!)
#

class RPCProtocol(protocol.Protocol):
    """Protocol for communicating via the SFS RPC protocol. This is
    largely hacked from SFS's RPC.AClient.

    Calls are made with sendRPC(), which returns a Deferred that fires
    with the unpacked result when the reply with the matching xid
    arrives (or with None if the call was not accepted).  Incoming bytes
    are framed as fragments, each preceded by a 4-byte header that
    RPC.parse_frag_len decodes into (body length, last-fragment flag)."""

    def connectionMade(self):
        self.xidcbmap = {}    # xid -> (Deferred, proc) for outstanding calls
        self.inbuffer = ''    # received bytes not yet consumed as fragments
        self.fragments = []   # fragment bodies of the reply being assembled
        self.bytesleft = 0    # NOTE(review): not read anywhere in this file
        self.factory.connected(self)

    def dataReceived(self, data):
        """Buffer incoming bytes and dispatch every complete reply.

        A fragment is consumed only once its 4-byte header and whole body
        have arrived; any trailing partial header or body is kept in
        self.inbuffer until the next call."""
        data = self.inbuffer + data
        # The transport may split the stream anywhere, including inside a
        # header, so never try to parse a header from fewer than 4 bytes.
        while len(data) >= 4:
            (fraglen, lastfrag) = RPC.parse_frag_len(data)
            if len(data) < 4 + fraglen:
                break  # body not fully received yet
            self.fragments.append(data[4:4 + fraglen])
            if lastfrag:
                self.handle_reply()
            data = data[4 + fraglen:]
        self.inbuffer = data

    def handle_reply(self):
        """Join the buffered fragments into one reply and fire the
        Deferred registered for its xid."""
        reply = ''.join(self.fragments)
        self.fragments = []
        u = RPC.unpack_reply(reply)
        xid = u[0]
        try:
            (cb, proc) = self.xidcbmap.pop(xid)
        except KeyError:
            sys.stderr.write("Reply for unknown xid %x received: %s\n" %
                             (xid, str(u[1:])))
            return

        if not cb:
            return
        # XXX should really return some useful info to cb if error case
        #     either if denied, or if some weird bug like PROG_UNAVAIL.
        if u[1] == RPC.RPCProto.MSG_ACCEPTED:
            res = proc.unpack_res(u[-1])
            cb.callback(res)
        else:
            cb.callback(None)

    def sendRPC(self, pnum, arg):
        """Send procedure pnum with argument arg; return a Deferred for
        the reply."""
        base = self.factory.clientBase
        proc = base.module.programs[base.PROG][base.VERS][pnum]
        xid = base.xidgen.next()
        p = RPC.pack_call(xid, base.PROG, base.VERS, pnum)
        proc.pack_arg(p, arg)
        val = RPC.strbuf()
        RPC.writefrags(p.get_buffer(), val.write)
        self.transport.write(val.s)
        cb = defer.Deferred()
        self.xidcbmap[xid] = (cb, proc)
        return cb
        
class RPCFactory(protocol.ServerFactory):
    """One-shot server factory for the connection cd makes back to us.

    connectDfd fires with the RPCProtocol of the first accepted
    connection.  Any later connection is closed immediately: a Deferred
    can only fire once, and the shared clientBase (used by the first
    connection's sendRPC) must not be replaced."""
    protocol = RPCProtocol

    def __init__(self, connectDfd, module, PROG, VERS):
        self.connectDfd = connectDfd  # fired once with the connected protocol
        self.module = module          # RPC stub module (e.g. cd_prot)
        self.PROG = PROG
        self.VERS = VERS

    def connected(self, proto):
        """Called by RPCProtocol.connectionMade for each new connection."""
        if self.connectDfd.called:
            proto.transport.loseConnection()
            return
        self.clientBase = RPC.ClientBase(self.module, self.PROG, self.VERS,
                                         None, None)
        self.connectDfd.callback(proto)

class LoggingProcessProtocol(protocol.ProcessProtocol):
    """Log all output and events from a process to stdout."""
    
    def __init__(self):
        pass
    def log(self, msg):
        print "CD process:", msg
    def connectionMade(self):
        self.log("connectionMade!")
    def outReceived(self, data):
        self.log("stdout: %s" % data)
    def errReceived(self, data):
        self.log("stderr: %s" % data)
    def inConnectionLost(self):
        self.log("inConnectionLost! stdin is closed! (we probably did it)")
    def outConnectionLost(self):
        self.log("outConnectionLost! The child closed their stdout.")
    def errConnectionLost(self):
        self.log("errConnectionLost! The child closed their stderr.")
    def processEnded(self, status_object):
        self.log("processEnded, status %d" % status_object.value.exitCode)


class Chord:
    STATUS_UNCONNECTED = 0
    STATUS_CONNECTING = 1
    STATUS_CONNECTED = 2

    def __init__(self):
        self.status = Chord.STATUS_UNCONNECTED
        self.vnodes = []

    @typechecked(object, str, int, str, int)
    def startChord(self, localHost, localPort, wellknownHost, wellknownPort):
        """Start the Chord daemon (cd), establish a RPC connection to
        it, and join a Chord ring. The Chord protocol connection is
        bound to localHost:localPort, and it is bootstrapped by
        connecting to the node at wellknownHost:wellknownPort. If the
        wellknown host/port is the same as the local host/port, then
        this will bootstrap a new Chord ring. Returns a Deferred that
        fires with an unspecified value once the connection is
        established; no queries can be executed until then."""
        
        if self.status != self.STATUS_UNCONNECTED:
            return defer.fail("Can't start Chord, it's not unconnected.")
        self.localHost = localHost
        self.localPort = localPort
        self.wellknownHost = wellknownHost
        self.wellknownPort = wellknownPort
        self.status = Chord.STATUS_CONNECTING

        dfd = defer.Deferred()
        dfd.addCallback(self.__connected)
        
        # Start the one-shot RPC server
        sockname = os.tempnam(None, "arpcc")
        reactor.listenUNIX(sockname,
                           RPCFactory(dfd, cd_prot, cd_prot.CD_PROGRAM, 1))

        # Start cd
        reactor.spawnProcess(LoggingProcessProtocol(),
                             CD_PATH, [CD_PATH, "-C", sockname])

        self.connectedDfd = defer.Deferred()
        return self.connectedDfd

    def __connected(self, proto):
        """Internal function called when RPC connection to cd is
        established. Instantiates and connects Chord."""
        
        self.proto = proto
        print "Established RPC connection to cd"
        
        # Start Chord
        arg = cd_prot.cd_newchord_arg()
        arg.wellknownhost = socket.gethostbyname(self.wellknownHost)
        arg.wellknownport = self.wellknownPort
        arg.myname = self.localHost
        arg.myport = self.localPort
        arg.maxcache = 64  # XXX Uh
        arg.nvnodes = 1    # XXX Uh again
        arg.routing_mode = cd_prot.MODE_CHORD
        
        print "Instantiating Chord"
        newchordRes = self.proto.sendRPC(cd_prot.CD_NEWCHORD, arg)
        newchordRes.addCallback(self.__instantiated)

        return proto

    def __instantiated(self, res):
        """Internal function called once Chord has been instantiated
        and connected. Gathers vnode information and calls
        callbacks."""
        if res.stat == chord_types.CHORD_NOTINRANGE:
            print "Chord object already exists, continuing anyways"
        else:
            print "Chord object created"
            print "Vnodes:"
            for x in res.resok.vnodes:
                print str(x)
            self.vnodes = res.resok.vnodes
            
            self.status = Chord.STATUS_CONNECTED
            self.connectedDfd.callback(self)

    def lookup(self, key):
        """Perform a lookup of a key, returning the list of successors
        and routing route."""

        if self.status != self.STATUS_CONNECTED:
            return defer.fail("Can't perform a lookup without"
                              " a connected Chord.")
        dfd = defer.Deferred()
        arg = cd_prot.cd_lookup_arg()
        arg.vnode = self.vnodes[0]
        arg.key = chord_types.bigint(key)
        if DEBUG:
            print "Looking up", arg.key
        rpcDfd = self.proto.sendRPC(cd_prot.CD_LOOKUP, arg)

        def lookupCB(res):
            if res.stat == chord_types.CHORD_OK:
                succs = [ChordNode.fromWirePlusID(wireID)
                         for wireID in res.resok.successors]
                route = [ChordNode.fromWirePlusID(wireID)
                         for wireID in res.resok.route]
                if DEBUG:
                    print "Found", arg.key, "at:"                     
                    for node in succs:
                        print " %s" % node
                    print "Route:"                     
                    for node in route:
                        print " %s" % node
                dfd.callback(succs)
            else:
                print "Lookup failed:", res.stat
                dfd.errback("Lookup failed: " + str(res.stat))

        rpcDfd.addCallback(lookupCB)
        return dfd

    def getSuccList(self, vnode):
        if self.status != self.STATUS_CONNECTED:
            return defer.fail("Can't get successor lists when Chord "
                              "isn't connected.")
        
        if vnode not in self.vnodes:
            return defer.fail("Invalid vnode")
        
        dfd = defer.Deferred()
        arg = cd_prot.cd_getsucclist_arg()
        arg.vnode = vnode
        if DEBUG:
            print "Getting successor list for ", arg.vnode
        rpcDfd = self.proto.sendRPC(cd_prot.CD_GETSUCCLIST, arg)

        def succlistCB(res):
            if res.stat == chord_types.CHORD_OK:
                nodes = [ChordNode.fromWirePlusID(wireID)
                         for wireID in res.resok.nodes]
                if DEBUG:
                    print "Successors:"
                    for node in nodes:
                        print " %s" % node
                dfd.callback(nodes)
            else:
                print "GetSuccList failed:", res.stat
                dfd.errback("GetSuccList failed: " + str(res.stat))

        rpcDfd.addCallback(succlistCB)
        return dfd

    def getPredList(self, vnode):
        if self.status != self.STATUS_CONNECTED:
            return defer.fail("Can't get predecessor lists when Chord "
                              "isn't connected.")
        
        if vnode not in self.vnodes:
            return defer.fail("Invalid vnode")
        
        dfd = defer.Deferred()
        arg = cd_prot.cd_getsucclist_arg()
        arg.vnode = vnode
        if DEBUG:
            print "Getting predecessor list for ", arg.vnode
        rpcDfd = self.proto.sendRPC(cd_prot.CD_GETPREDLIST, arg)

        def predlistCB(res):
            if res.stat == chord_types.CHORD_OK:
                nodes = [ChordNode.fromWirePlusID(wireID)
                         for wireID in res.resok.nodes]
                if DEBUG:
                    print "Predecessors:"
                    for node in nodes:
                        print " %s" % node
                dfd.callback(nodes)
            else:
                print "GetPredList failed:", res.stat
                dfd.errback("GetPredList failed: " + str(res.stat))

        rpcDfd.addCallback(predlistCB)
        return dfd

    def myVnode(self):
        return self.vnodes[0]

    def myChordNode(self):
        return ChordNode(self.localHost, self.localPort,
                         0, self.vnodes[0])
                         



def trace(x):
    print "----- TRACE ------", x
    return x
    
def testChord():
    c = Chord()
    chordDfd = c.startChord("18.141.0.190", 44267, "18.141.0.190", 44267)
    chordDfd.addCallback(lambda x: c.lookup(c.myVnode()))
    chordDfd.addCallback(trace)
    reactor.run()
        
if __name__ == "__main__":
    # Script entry point: run the single-node Chord smoke test.
    testChord()
