import sys, os, time
import sha
from twisted.python import util
from twisted.internet import reactor, defer, task
from twisted.spread import pb
from index import *
import chord
from chord import *
import chord_types
from connectioncache import ConnectionCache
from utils import *
# Shorthand for Twisted's waitForDeferred, used inside the
# @defer.deferredGenerator methods below.
waitFor = defer.waitForDeferred

EXPIRE_CHECK_INTERVAL = 5               # Time in seconds to wait between
                                        # checking for expiring entries.
# Seconds between runs of each index gateway's periodic expiry check
# (used as a LoopingCall interval).
INDEX_GATEWAY_CHECK_INTERVAL = 60
# Enables debug messages (e.g. when an index gateway is expired).
DEBUG = True
# Enables additional, more detailed debug messages.
DEBUG_VERBOSE = False

def chordToArpdPort(chordPort):
    """Map a Chord node's port number to the port of its companion arpd.

    The arpd listens on the port immediately above the Chord port."""
    arpdOffset = 1
    return chordPort + arpdOffset

class Syncer:
    """Default replica synchronizer that does nothing.

    IndexServer calls initialize() once from its constructor and then
    forwards the 'offer blocks' and 'request blocks' RPCs to the two
    handle* methods. This base implementation accepts every call and
    performs no synchronization (presumably meant to be subclassed by a
    real synchronizer -- confirm with callers)."""

    def initialize(self, indexServer):
        """Remember the IndexServer this syncer is attached to.

        IndexServer.__init__ calls this unconditionally, so the base
        class must provide it or constructing an IndexServer with a
        plain Syncer fails with AttributeError."""
        self.indexServer = indexServer

    def handleOfferBlocks(self, hostport, blocks):
        """Handle a remote node at 'hostport' offering 'blocks'.

        No-op; returns None."""
        return None

    def handleRequestBlocks(self, blockList):
        """Handle a remote request for the blocks in 'blockList'.

        No-op; returns None."""
        return None


class IndexServer(pb.Root):
    @typechecked(object, float, int, int, int,
                 IndexCollection, Chord, ConnectionCache, Syncer)
    def __init__(self, expireTime, indexGatewayExpireTime, kssK, numReplicas,
                 indexes, chord, cache, syncer):
        self.expireTime = expireTime
        self.indexGatewayExpireTime = indexGatewayExpireTime
        self.kssK = kssK
        self.numReplicas = numReplicas
        self.indexes = indexes
        self.chord = chord
        self.connectionCache = cache
        self.indexGateways = {}
        self.expireTask = task.LoopingCall(self.expire)
        self.expireTask.start(EXPIRE_CHECK_INTERVAL)
        self.syncer = syncer
        self.syncer.initialize(self)
 
    def __hashIndexName(self, indexName):
        """Return a hash of indexName in Chord-key format."""
        s = sha.new()
        s.update(str(indexName))
        sInt = int(s.hexdigest(), 16)
        return chord_types.bigint(sInt)

    def __findResponsibleIndexes(self, indexName, single=False):
        """Perform a Chord lookup to find the node responsible for
        indexing a certain key. If 'single' is set, return only a
        single node, the 'primary' node in the index replica
        set. Otherwise, return all nodes in the replica set that Chord
        knows about, i.e. the numReplicas successors of the key."""
        key = self.__hashIndexName(indexName)
        dfd = self.chord.lookup(key)
        if single:
            dfd.addCallback(lambda x: x[0])
        else:
            dfd.addCallback(lambda x: x[0:self.numReplicas])
        return dfd

    def indexesBetween(self, a, b):
        """Get a list of all indexes that are mapped to Chord IDs in
        the interval (a, b]."""
        r = []
        for name, ind in self.indexes.indexes.items():
            k = self.__hashIndexName(name)
            if chord.betweenRightIncl(a, b, k): 
#                print "Index %s (%s) between" % (name, k)
                r.append(ind)
            else:
                pass
#                print "Index %s (%s) not between" % (name, k)
        return r

    @defer.deferredGenerator
    def replicatedKeyspace(self, k, numReplicas=None):
        """Return the keyspace range (a,b] that should be replicated
        on the index server with the specified key, assuming that each
        index is replicated numReplicas times. This requires that the
        server in question is in our predecessor or successor list --
        usually this shouldn't be a problem because this is where we
        find servers to synchronize with anyway."""

        if numReplicas == None:
            numReplicas = self.numReplicas

        if DEBUG_VERBOSE:
            print "Searching for keyspace to be replicated on", k, \
                  " n =", numReplicas
            
        # Build a list of nodes we know about.
        d = waitFor(self.chord.getPredList(self.chord.myVnode()))
        yield d
        nodes = list(reversed(d.getResult()))
        d = waitFor(self.chord.getSuccList(self.chord.myVnode()))
        yield d
        nodes += d.getResult()
        nodes = uniquify(nodes)
        
        # Figure out the range of the keyspace that node is
        # responsible for replicating.
        cn = [x for x in nodes if x.chordID == k]
        if len(cn) == 0:
            raise Exception("Node %s is not in my pred/succ list." % k)
        ind = nodes.index(cn[0])
        
        if ind-numReplicas > 0:
            a = nodes[ind-numReplicas].chordID
#            print "CASE 1:", a
        else:
            a = nodes[(ind-numReplicas) % len(nodes)].chordID

        if numReplicas > len(nodes):
            a = k
            
        b = k
        
        if DEBUG_VERBOSE:
            print "Keyspace is: a =", a, "  b =", b
        yield a,b
        return


    @defer.deferredGenerator
    def __indexesForMetadata(self, metadata,
                             kssK=None, numReplicas=None):
        """Return a list of all index names for the specified metadata
        block on this node, assuming a certain replica count and KSS
        parameter K. This is the index list generated by KSS, but
        restricted to ones that should be placed on this node."""

        if kssK == None:
            kssK = self.kssK
        if numReplicas == None:
            numReplicas = self.numReplicas

        # What's our keyspace?
        d = waitFor(self.replicatedKeyspace(self.chord.myVnode(),
                                              numReplicas))
        yield d
        (a, b) = d.getResult()

        # Generate the full KSS list
        indexNames = [indexNameFromSet(kssSet)
                      for kssSet in metadata.getKSSSets(kssK)]

        ourIndexNames = [name for name in indexNames
                         if betweenRightIncl(a, b,
                                             self.__hashIndexName(name))]
        yield ourIndexNames
        return

    @defer.deferredGenerator
    def addBlockToIndexes(self, metadata, expireTime):
        """Add a Metadata block to all appropriate indexes that should
        be stored on this node. Returns the value of expireTime (the
        number of seconds from now until the index entry expires,
        *not* the absolute *time* at which it will expire."""
        if DEBUG_VERBOSE:
            print "Adding block", metadata, "to indexes"
        
        d = waitFor(self.__indexesForMetadata(metadata))
        yield d
        indexNames = d.getResult()

        for name in indexNames:
            self.indexes.addToIndex(name, metadata, expireTime)
        yield expireTime
        return
    

    @defer.deferredGenerator
    @typechecked(object, Metadata)
    def remote_add(self, metadata):
        """Add a metadata block to all appropriate indexes on this
        node. Note that this is not a *recursive* insert; it will not
        contact other nodes."""
        return self.addBlockToIndexes(metadata,
                                        time.time() + self.expireTime)

    @typechecked(object, str, Metadata)
    def remote_addToSpecificIndex(self, indexName, metadata):
        """Add a metadata block to the appropriate index on this
        node."""
        print "AddToSpecificIndex:", indexName
        self.indexes.addToIndex(indexName, metadata,
                                time.time() + self.expireTime)
        return self.expireTime

    @defer.deferredGenerator
    @typechecked(object, str, Metadata)
    def __callOnSomeReplica(self, indexName, func):
        """Find some replica node for the index with name 'indexName',
        and call 'func' on its root object. 'func' must be a function
        that returns a Deferred; probably it'll be something like
        'lambda (obj): obj.callRemote(...)'. If an error occurs during
        the connection or execution, then we'll fall back to some
        other node in the replica set; if we've exhaused them all,
        raise an exception."""

        d = waitFor(self.__findResponsibleIndexes(indexName))
        yield d
        servers = d.getResult()

        while True:
            if len(servers) == 0:
                print "callOnSomeReplica: all possible replicas failed!"
                raise Exception("All possible replicas failed!")
            server = servers.pop(0)

            try:
                d = waitFor(
                    self.connectionCache.connect(server.host,
                                                 chordToArpdPort(server.port)))
                yield d
                root = d.getResult()

                d = waitFor(func(root))
                yield d
                yield d.getResult()
                return
            except Exception, e:
                print ("callOnSomeReplica: Failed at contacting host %s " + \
                       " (error: %s), trying next.") % \
                      (str(server), e)
                continue

    @typechecked(object, str, Metadata)
    def findAndAddToIndex(self, indexName, metadata):
        """Add a metadata block to the specified index, even if it's
        on a different node."""

        return self.__callOnSomeReplica(
            indexName,
            lambda obj: obj.callRemote("addToSpecificIndex",
                                       indexName, metadata))

    @typechecked(object, Metadata)
    def remote_recursiveAdd(self, metadata):
        indexNames = [indexNameFromSet(kssSet)
                      for kssSet in metadata.getKSSSets(self.kssK)]
        print indexNames
        dfds = []
        
        for indexName in indexNames:
            if DEBUG_VERBOSE:
                print "Recursive add: adding to index", indexName
            dfd = self.findAndAddToIndex(indexName, metadata)
            dfds.append(dfd)

        return defer.DeferredList(dfds)

    @defer.deferredGenerator
    @typechecked(object, Metadata)
    def remote_recursiveAddWithIG(self, metadata):
        # First try an index gateway insert
        try:
            d = waitFor(self.__findResponsibleIndexes(hash(metadata), True))
            yield d
            gateway = d.getResult()

            d = waitFor(
                self.connectionCache.connect(gateway.host,
                                             chordToArpdPort(gateway.port)))
            yield d
            root = d.getResult()

            d = waitFor(root.callRemote("gatewayInsert", metadata))
            yield d
            yield d.getResult()
            return
        except Exception, e:
            print "Index gateway insert failed, falling back to standard add", e
            d = waitFor(self.remote_recursiveAdd(metadata))
            yield d
            expireTimes = d.getResult()
            yield min(expireTimes)
            return

    @typechecked(object, str, Metadata)
    def remote_recursiveAddToSpecificIndex(self, indexName, metadata):
        """Add a metadata block to the specified index, even if it's
        on a different node."""
        return self.findAndAddToIndex(indexName, metadata)

                                  
    @typechecked(object, str, Query)
    def remote_search(self, indexName, query):
        """Perform a query on a local index."""
        return self.indexes.search(indexName, query)

    @typechecked(object, Query)
    def remote_recursiveSearch(self, query):
        """Identify an appropriate index for a query, locate the
        responsible node for that index and send it the query."""
        indexName = indexNameFromSet(query.getIndexKeywordSet(self.kssK))
        return self.__callOnSomeReplica(
            indexName,
            lambda obj: obj.callRemote("search",
                                       indexName, query))

            
        
    @typechecked(object, tuple, list)
    def remote_replOfferBlocks(self, hostport, blocks):
        """Handle a 'offer blocks' request --- that is, receive a list
        of metadata records that some remote node thinks this node
        should have. This RPC returns nothing interesting, but this
        should trigger a request to obtain any of the missing blocks."""

        if self.syncer is not None:
            return self.syncer.handleOfferBlocks(hostport, blocks)
        else:
            return None

    @typechecked(object, list)
    def remote_replRequestBlocks(self, blockList):
        """Handle a 'request blocks' request --- that is, return a
        list of Metadata blocks and expire times corresponding to the
        IDs sent by the remote node."""

        if self.syncer is not None:
            return self.syncer.handleRequestBlocks(blockList)
        else:
            return None

    def expire(self):
        """Perform expiration run; called periodically."""
        self.indexes.expire(time.time())

        # Expire index gateways
        for blockHash, gateway in self.indexGateways.items():
            if time.time() > (gateway.lastClientRefresh +
                              self.indexGatewayExpireTime):
                if DEBUG:
                    print "Expiring index gateway for block", blockHash
                gateway.stop()
                del self.indexGateways[blockHash]


    @typechecked(object, Metadata)
    def remote_gatewayInsert(self, block):
        """Remote request to insert/refresh an index gateway. This may
        trigger an insertion request."""
        try:
            self.indexGateways[hash(block)].refresh()
        except KeyError:
            self.indexGateways[hash(block)] = IndexGateway(block,
                                                           self.kssK, self)
        return self.indexGatewayExpireTime


class IndexGateway:
    """An index gateway serves as an intermediary to improve
    performance of index insertion. Specifically, a metadata block
    needs to be registered with many indexes (corresponding to the KSS
    sets). Moreover, many nodes will have the same file (hence the
    same metadata block). Rather than have each of them contat each
    index node, they can simply contact the index gateway, which keeps
    track of the expiration times at each index and updates each index
    accordingly. Specificallyt, this is the mechanism described in
    section 3.2 of the Arpeggio design paper (CPK '05).

    Note that index gateways are not a critical component for
    correctness; if one fails, individual nodes can simply perform the
    insertions themselves. Hence, they are not replicated (though this
    might still be useful as a performance optimization later.)"""

    @typechecked(object, Metadata, int, IndexServer)
    def __init__(self, block, kssK, indexServer):
        """Create a gateway for the specified metadata block."""
        self.block = block
        self.kssK = kssK
        self.lastClientRefresh = time.time()
        self.indexExpires = {}
        self.indexServer = indexServer
        for x in block.getKSSSets(self.kssK):
            self.indexExpires[indexNameFromSet(x)] = 0
        self.checkTask = task.LoopingCall(self.checkExpire)
        self.checkTask.start(INDEX_GATEWAY_CHECK_INTERVAL)
        print "Creating index gateway for", id(block)

    @defer.deferredGenerator
    @typechecked(object)
    def checkExpire(self):
        """Check whether any index nodes for this block are due to
        expire, and if so refresh them."""
        
        for indexName, expireTime in self.indexExpires.items():
            if expireTime < time.time():
                print "IG", hash(self.block), " --  refreshing index", indexName
                d = waitFor(
                    self.indexServer.findAndAddToIndex(indexName, self.block))
                yield d
                self.indexExpires[indexName] = time.time() + d.getResult()
                print "IG", hash(self.block), " --  new expire time is", time.time() + d.getResult()

    def stop(self):
        """Shut down this index gateway by disabling the periodic
        check task. Note that the gateway also needs to be removed
        from whatever structure is keeping track of index gateways."""
        self.checkTask.stop()

    def refresh(self):
        print "IG", hash(self.block), " -- client refreshing"
        self.lastClientRefresh = time.time()


