import heapq
import sys
import socket
import random

import xctl
from qemu import Qemu, IQemuResponseHandler, LoggingHandler
from synchronizer import Synchronizer, ISynchronizable
from network import NetworkSwitch, NetworkLogger, INetworkEndPoint
from codec import QemuEncoder, QemuDecoder
import utils

import pdb ## Run with python -i to be able to debug

# Switches for optional console tracing.
DEBUG_SCHEDULE = False           # dump the schedule at every breakpoint
DEBUG_SEND_CMDS = True           # trace outgoing packet / snapshot commands
DEBUG_SEND_CMDS_VERBOSE = False  # also trace LocalReady and zero-packet proceeds
DEBUG_RECV_CMDS = True           # trace sync proceeds and remote packets received
DEBUG_BARRIER_BUG = True         # report when the barrier still waits for packets

def secToNsec(secs):
    """Convert seconds to an integer number of nanoseconds.

    The product is truncated toward zero, not rounded.  On Python 2 the
    result is a ``long`` (as before); where ``long`` does not exist the
    plain ``int`` type is used instead.
    """
    try:
        toInteger = long
    except NameError:
        # Interpreters without a separate long type.
        toInteger = int
    return toInteger(secs * 1000000000)

def nsecToSec(nsecs):
    """Convert a nanosecond count to floating-point seconds."""
    nanosPerSec = 1000000000
    return float(nsecs) / nanosPerSec

class Scheduler(object):
    """Priority queue of SchedulerEvents, kept as a heapq min-heap.

    Ordering comes from the events' own comparison.  A scheduler can be
    serialized with encode() and rebuilt with decode(); event classes
    must be registered via _registerEventClass() so decode() can map the
    stored class code back to a class.
    """

    # Maps the string returned by an event class's _getCode() to that
    # class.  Shared by all Scheduler instances.
    __codeToEventClass = {}

    def __init__(self):
        self.__schedule = []

    def add(self, event):
        """Insert *event* into the schedule."""
        assert isinstance(event, SchedulerEvent)
        heapq.heappush(self.__schedule, event)

    def peek(self):
        """Return the earliest event without removing it, or None."""
        if self.__schedule:
            return self.__schedule[0]
        return None

    def removeEvent(self, event):
        """Remove the given event object from the schedule.

        Matching is by identity.  Event comparison only looks at time
        and priority, so an equality test would also throw away every
        unrelated event that happens to share the same slot.

        Costs a full pass over the schedule plus a re-heapify; use
        sparingly.
        """
        kept = []
        for queued in self.__schedule:
            if queued is event:
                print("Removing an event.")
            else:
                kept.append(queued)
        heapq.heapify(kept)
        self.__schedule = kept

    def pop(self):
        """Remove and return the earliest event, or None when empty."""
        try:
            return heapq.heappop(self.__schedule)
        except IndexError:
            return None

    def dump(self):
        """Print the schedule in dequeue order (debugging aid)."""
        # Work on a copy so the real schedule is left untouched.
        q = self.__schedule[:]
        if not q:
            # Nothing to pop; heappop would raise IndexError here.
            print("<Schedule>")
            return
        print("<Schedule")
        while q:
            elt = heapq.heappop(q)
            if q:
                print("  %s," % (elt,))
            else:
                print("  %s>" % (elt,))

    def encode(self, enc):
        """Write the schedule to *enc*.

        Layout: the 4-byte tag "schd", a 32-bit entry count, then for
        each entry its class code string followed by the event's own
        encoding.  A None entry is written as an empty code string.
        """
        enc.put_buffer("schd")
        enc.put_be32(len(self.__schedule))
        for elt in self.__schedule:
            if elt is None:
                enc.put_string("")
            else:
                enc.put_string(elt._getCode())
                elt._encode(enc)

    @classmethod
    def decode(cls, dec):
        """Rebuild a Scheduler from data produced by encode().

        Entries are appended in stored order, which is the heap order
        that encode() wrote out.
        """
        sig = dec.get_buffer(4)
        assert sig == "schd", "Expected 'schd', got '%s'" % sig
        ret = Scheduler()
        n = dec.get_be32()
        for _ in range(n):
            code = dec.get_string()
            if code == "":
                ret.__schedule.append(None)
            else:
                eventClass = cls.__codeToEventClass.get(code, None)
                assert eventClass is not None, \
                       "Could not find event class %s" % repr(code)
                event = eventClass._decode(dec)
                ret.__schedule.append(event)
        return ret

    @classmethod
    def _registerEventClass(cls, evt):
        """Register event class *evt* under its _getCode() string.

        Raises RuntimeError if that code is already taken.
        """
        code = evt._getCode()
        if code in cls.__codeToEventClass:
            raise RuntimeError("Event class '%s' already registered"
                               % repr(code))
        cls.__codeToEventClass[code] = evt

class SchedulerEvent(object):
    """Base class for everything that can be placed in a Scheduler.

    An event carries an integer ``time`` and a ``prio`` taken from the
    class's _getPriority().  Events order by time, then by priority.
    """

    def __init__(self, time):
        assert isinstance(time, (long, int))
        self.time = long(time)
        self.prio = self._getPriority()

    def __cmp__(self, other):
        # Earlier time wins; equal times fall back to the priority.
        return cmp(self.time, other.time) or cmp(self.prio, other.prio)

    def __repr__(self):
        return "<%d:SchedulerEvent>" % self.time

    def _encode(self, enc):
        """Serialize this event's payload (just the time) into *enc*."""
        enc.put_be64(self.time)

    @staticmethod
    def _decode(dec):
        """Read back what _encode() wrote and build the event."""
        return SchedulerEvent(dec.get_be64())

    @staticmethod
    def _getCode():
        """Return the string that identifies this class when serialized."""
        return "SchedulerEvent"

    @staticmethod
    def _getPriority():
        """Return the tie-break priority for events at the same time.

        Smaller values get descheduled earlier.
        """
        return 0
Scheduler._registerEventClass(SchedulerEvent)

class TraceEvent(SchedulerEvent):
    """Periodic progress-trace event.

    Construction and encoding are exactly those of SchedulerEvent (a
    single time value), so they are inherited rather than repeated.
    """

    def __repr__(self):
        return "<%d:TraceEvent>" % self.time

    @staticmethod
    def _decode(dec):
        """Rebuild a TraceEvent from its encoded time."""
        return TraceEvent(dec.get_be64())

    @staticmethod
    def _getCode():
        return "TraceEvent"

    @staticmethod
    def _getPriority():
        # Below the default of 0, so traces come out first at a given time.
        return -1
Scheduler._registerEventClass(TraceEvent)

class SyncPointEvent(SchedulerEvent):
    """Event marking the time at which the synchronization barrier is hit.

    Construction and encoding are exactly those of SchedulerEvent (a
    single time value), so they are inherited rather than repeated.
    """

    def __repr__(self):
        return "<%d:SyncPointEvent>" % self.time

    @staticmethod
    def _decode(dec):
        """Rebuild a SyncPointEvent from its encoded time."""
        return SyncPointEvent(dec.get_be64())

    @staticmethod
    def _getCode():
        return "SyncPointEvent"

    @staticmethod
    def _getPriority():
        # Above the default of 0, so other events at the same time are
        # dequeued before the sync point.
        return 100
Scheduler._registerEventClass(SyncPointEvent)

class PacketDeliveryEvent(SchedulerEvent):
    """Event that delivers an Ethernet packet at a given time.

    ``tries`` counts earlier delivery attempts; ``packet`` is the
    utils.EthernetPacket to deliver.
    """

    def __init__(self, time, tries, packet):
        SchedulerEvent.__init__(self, time)
        assert isinstance(tries, (int, long))
        assert isinstance(packet, utils.EthernetPacket)
        self.tries = tries
        self.packet = packet

    def __repr__(self):
        fields = (self.time, self.tries, self.packet)
        return "<%d:PacketDeliveryEvent(%d,%s)>" % fields

    def _encode(self, enc):
        """Write time (64 bit), tries (32 bit) and the raw packet bytes."""
        SchedulerEvent._encode(self, enc)
        enc.put_be32(self.tries)
        enc.put_string(self.packet.getBytes())

    @staticmethod
    def _decode(dec):
        """Read the fields in the order _encode() wrote them."""
        time = dec.get_be64()
        tries = dec.get_be32()
        rawBytes = dec.get_string()
        packet = utils.EthernetPacket(rawBytes)
        return PacketDeliveryEvent(time, tries, packet)

    @staticmethod
    def _getCode():
        return "PacketDeliveryEvent"
Scheduler._registerEventClass(PacketDeliveryEvent)

class SnapshotEvent(SchedulerEvent):
    """Event requesting a snapshot, either full or incremental."""

    def __init__(self, time, incremental):
        SchedulerEvent.__init__(self, time)
        assert isinstance(incremental, bool)
        self.incremental = incremental

    def __repr__(self):
        return "<%d:SnapshotEvent>" % self.time

    def _encode(self, enc):
        """Write the time followed by one byte: 1 = incremental, 0 = full."""
        SchedulerEvent._encode(self, enc)
        flag = 0
        if self.incremental:
            flag = 1
        enc.put_byte(flag)

    @staticmethod
    def _decode(dec):
        """Read the time and the incremental flag byte."""
        time = dec.get_be64()
        flag = dec.get_byte()
        return SnapshotEvent(time, flag != 0)

    @staticmethod
    def _getCode():
        return "SnapshotEvent"
Scheduler._registerEventClass(SnapshotEvent)

class LoadSnapshotEvent(SchedulerEvent):
    """Event requesting that the snapshot named ``ident`` be loaded."""

    def __init__(self, time, ident):
        SchedulerEvent.__init__(self, time)
        self.ident = ident

    def __repr__(self):
        return "<%d:LoadSnapshotEvent(%s)>" % (self.time, repr(self.ident))

    def _encode(self, enc):
        """Write the time followed by the snapshot identifier string."""
        SchedulerEvent._encode(self, enc)
        enc.put_string(self.ident)

    @staticmethod
    def _decode(dec):
        """Read the time and identifier in the order they were written."""
        time = dec.get_be64()
        snapshotId = dec.get_string()
        return LoadSnapshotEvent(time, snapshotId)

    @staticmethod
    def _getCode():
        return "LoadSnapshotEvent"
Scheduler._registerEventClass(LoadSnapshotEvent)

def testSchedulerSerialization():
    """Round-trip a small schedule through encode/decode and dump both.

    Manual smoke test: the "Before" and "After" dumps are meant to be
    compared by eye.
    """
    print("Testing scheduler serialization")
    packet = utils.EthernetPacket("\xff"*12+"\x00\x08"+"\x00"*4)
    original = Scheduler()
    events = (SyncPointEvent(10),
              PacketDeliveryEvent(20, 2, packet),
              SchedulerEvent(2),
              SyncPointEvent(8))
    for event in events:
        original.add(event)
    print("Before:")
    original.dump()

    encoder = QemuEncoder()
    original.encode(encoder)
    restored = Scheduler.decode(QemuDecoder(str(encoder)))
    print("After:")
    restored.dump()

class TestHandler(IQemuResponseHandler, ISynchronizable, INetworkEndPoint):
    """Controller for one Qemu instance.

    Keeps the instance's event schedule, runs the Qemu up to the next
    scheduled event, takes part in the synchronizer barrier, and models
    the outgoing link (bandwidth-limited send queue, per-destination
    latency and loss) for packets the Qemu emits.
    """

    def __init__(self, qemu, dumper):
        """Create a handler for *qemu*; *dumper* receives every packet
        sent or successfully delivered (via putPacket)."""
        self.qemu = qemu
        self.dumper = dumper
        self.synchronizer = None
        self.network = None
        self.scheduler = Scheduler()
        # Maps (sendTime, deliveryNsec, srcMAC) to the scheduled
        # PacketDeliveryEvent so dropPacket() can find it again.
        # NOTE(review): entries are only removed by dropPacket(), so
        # events for delivered packets stay in this map.
        self.packetEventsMap = {}

        self.closing = False
        self.sendQ = []
        self.sending = None             # Event currently being delivered
        self.sendFailures = 0
        self.curTSC = 0
        self.latencies = {}             # Maps dstOctet to ns latency
        self.lossRates = {}             # Maps dstOctet to loss frac
        # 5 ms quanta
        self.quanta = 5 * self.qemu.getTicksPerSec() / 1000
        self.sendBandwidth = 10000000000000
        self.sendQLen = 10000000000000
        self.sendQLastLen = 0           # Queue occupancy (bits) at ...
        self.sendQLastTime = 0          # ... this time (seconds)
        # Set by onQemuSaved(); None until a snapshot has completed.
        self.lastSnapshotIdent = None

        self.scheduler.add(TraceEvent(0))
        # Initialize things with a full snapshot
        #self.scheduler.add(SnapshotEvent(self.qemu.secToTicks(0), False))
        #self.scheduler.add(SnapshotEvent(self.qemu.secToTicks(5), True))
        #self.scheduler.add(LoadSnapshotEvent(self.qemu.secToTicks(10)))
        self.__needIncrementalSnapshot = False

    def close(self):
        """Ask the Qemu to quit and remember that the close is expected."""
        self.qemu.cmdQuit()
        self.closing = True

    def onQemuClose(self, qemu):
        """Qemu went away: leave the synchronizer and the network."""
        if not self.closing:
            print("Closed prematurely")
        self.synchronizer.leave(self)
        self.network.leave(self)

    #
    # Synchronization
    #

    def setSynchronizer(self, sync):
        """Join *sync* and immediately report the barrier as reached."""
        self.synchronizer = sync
        sync.join(self)
        sync.reached(self)

    def onSyncProceed(self):
        """Barrier released: schedule the next sync point one quantum
        ahead and resume execution."""
        self.scheduler.add(SyncPointEvent(self.curTSC + self.quanta))
        self.proceed()

    def onQemuBPReached(self, qemu, bp):
        """Qemu stopped at breakpoint time *bp*; process what is due."""
        if DEBUG_SCHEDULE:
            print("Breakpoint reached %d:" % bp)
            self.scheduler.dump()
        self.curTSC = bp
        self.proceed()

    def handleEvent(self, event):
        """Dispatch one due event according to its type."""
        if isinstance(event, TraceEvent):
            print("%s: Trace @ %d secs (%d ticks)" %
                  (self.qemu.getLabel(),
                   self.qemu.ticksToSec(event.time),
                   event.time))
            # Traces re-arm themselves one second later.
            self.scheduler.add(TraceEvent(event.time +
                                          self.qemu.secToTicks(1)))
            self.proceed()
        elif isinstance(event, SyncPointEvent):
            self.synchronizer.reached(self)
        elif isinstance(event, PacketDeliveryEvent):
            self.deliverPacket(event)
        elif isinstance(event, SnapshotEvent):
            self.save(event.incremental)
        elif isinstance(event, LoadSnapshotEvent):
            self.load(event.ident)
        else:
            print("BUG: Unknown event type %s" % repr(event))

    def proceed(self):
        """Handle the next event if it is due, else run Qemu up to it."""
        event = self.scheduler.peek()
        if event is None:
            print("BUG: No events in schedule!  Proceeding forever")
            self.qemu.cmdRun(0)
        elif event.time <= self.curTSC:
            event = self.scheduler.pop()
            if event.time < self.curTSC:
                print(("BUG: Breakpoint time disagreed with scheduler" +
                       " (tsc %d, event %d)") % (self.curTSC, event.time))
            self.handleEvent(event)
        else:
            self.qemu.cmdRun(event.time)

    #
    # Networking
    #

    def setNetwork(self, net):
        """Join *net* under this Qemu's MAC address."""
        self.network = net
        net.join(self, self.qemu.getMAC())

    def setLatency(self, dstOctet, latency):
        """Set the link latency (ns) toward destination *dstOctet*."""
        self.latencies[dstOctet] = latency

    def setLossRate(self, dstOctet, rate):
        """Set the loss fraction toward destination *dstOctet*."""
        self.lossRates[dstOctet] = rate

    def setSendBottleneck(self, bw, qlen):
        """Set the send bandwidth and the send queue capacity.

        The queue accounting in onQemuPacket() counts packets in bits.
        """
        self.sendBandwidth = bw
        self.sendQLen = qlen

    def onQemuPacket(self, qemu, tsc, macaddr, packet):
        """The Qemu emitted *packet* at tick *tsc*.

        Records the packet, drains and refills the modelled send queue,
        then either drops the packet (queue full or random loss) or
        hands it to the network with its queue-exit and delivery times.
        """
        self.dumper.putPacket(self.qemu.ticksToSec(tsc), packet)

        dstOctet = packet.getDestMAC().toTuple()[-1]
        try:
            linkLatency = self.latencies[dstOctet]
            lossRate = self.lossRates[dstOctet]
        except KeyError:
            print(("Warning: latency or loss rate not specified " +
                   "for %d, assuming 1s / no loss") % dstOctet)
            linkLatency = 1e9
            lossRate = 0.0

        # Update queue status: drain what was sent since the last packet.
        curTime = self.qemu.ticksToSec(tsc)
        packetLen = 8 * len(packet.getBytes()) # in bits
        self.sendQLastLen -= (curTime - self.sendQLastTime) * \
                             self.sendBandwidth
        if self.sendQLastLen <= 0:
            self.sendQLastLen = 0
        self.sendQLastTime = curTime

        if self.sendQLastLen + packetLen > self.sendQLen:
            print("Queue full, dropping packet %s" % (packet,))
            return

        self.sendQLastLen += packetLen
        # Force true division: after the reset to 0 above the queue
        # length is an integer, and integer / integer would truncate
        # the delay to whole seconds (i.e. usually to zero).
        qLatency = secToNsec(float(self.sendQLastLen) / self.sendBandwidth)
        # qLatency is in nanoseconds; the message reports seconds.
        print("Delaying packet %f seconds in queue" % nsecToSec(qLatency))
        totalLatency = linkLatency + qLatency

        # XXX Need to ensure RNG state is saved for deterministic
        # replay
        if random.random() < lossRate:
            # Packet was dropped.
            # XXX Report to master
            print("Dropping packet: %s" % (packet,))
        else:
            self.network.send(self.qemu.ticksToNsec(tsc),
                              self.qemu.ticksToNsec(tsc +
                                                    self.qemu.nsecToTicks(qLatency)),
                              self.qemu.ticksToNsec(tsc +
                                                    self.qemu.nsecToTicks(totalLatency)),
                              packet)

    def onNetworkPacket(self, net, st, nsec, packet):
        """A packet sent at *st* is due here at *nsec*: schedule it and
        remember the event so it can still be dropped."""
        tsc = self.qemu.nsecToTicks(nsec)
        pde = PacketDeliveryEvent(tsc, 0, packet)
        srcMac = packet.getSrcMAC()
        key = (st, nsec, srcMac)
        print("Adding packet event mapping  %s %s %s" % (st, nsec, srcMac))
        self.packetEventsMap[key] = pde
        self.scheduler.add(pde)

    def deliverPacket(self, packetEvent):
        """Inject the event's packet into the Qemu; the outcome arrives
        via onQemuPacketOk() or onQemuPacketErr()."""
        self.qemu.cmdPacket(0, packetEvent.packet)
        self.sending = packetEvent

    def dropPacket(self, st, rt, src, dst):
        """Cancel the pending delivery identified by (st, rt, src).

        *dst* is accepted for symmetry with the caller but is not part
        of the lookup key.
        """
        key = (st, rt, src)
        print("Looking for packet event mapping  %s %s %s" % (st, rt, src))
        try:
            pde = self.packetEventsMap[key]
        except KeyError:
            print("Couldn't find packet to drop!")
            return
        self.scheduler.removeEvent(pde)
        del self.packetEventsMap[key]

    def onQemuPacketOk(self, qemu, tsc, mid):
        """Delivery succeeded: record the packet and carry on."""
        self.dumper.putPacket(self.qemu.ticksToSec(tsc), self.sending.packet)
        self.sending = None
        self.proceed()

    def onQemuPacketErr(self, qemu, tsc, mid):
        """Delivery failed: retry later, or give up after too many tries."""
        if self.sending.tries <= 5:
            # Reschedule for 1 ms later (just retrying now won't help
            # because the Qemu can't make any progress)
            self.scheduler.add(
                PacketDeliveryEvent(tsc + long(self.qemu.nsecToTicks(1e6)),
                                    self.sending.tries + 1,
                                    self.sending.packet))
        else:
            # Give up
            print("WARNING: Packet delivery permanently failed")
            self.sending = None
        self.proceed()

    #
    # Serialization
    #

    def snapshotOnNext(self):
        """Schedule a snapshot at the current time.

        The first request yields a full snapshot, later ones are
        incremental.
        """
        self.scheduler.add(SnapshotEvent(self.curTSC,
                                         self.__needIncrementalSnapshot))
        self.__needIncrementalSnapshot = True

    def save(self, incremental):
        """Ask the Qemu to save, attaching the encoded schedule."""
        if incremental:
            print("Saving incremental...")
        else:
            print("Saving full...")
        enc = QemuEncoder()
        self.scheduler.encode(enc)
        self.qemu.cmdSave(incremental, str(enc))

    def onQemuSaved(self, qemu, ident):
        """Snapshot finished; remember its identifier and carry on."""
        print("Saved to %s" % repr(ident))
        # This ident will get picked up by the slaved when the next
        # barrier is reached
        self.lastSnapshotIdent = ident
        self.proceed()

    def loadOnNext(self, ident):
        """Schedule loading of snapshot *ident* at the current time."""
        self.scheduler.add(LoadSnapshotEvent(self.curTSC, ident))

    def load(self, ident):
        """Ask the Qemu to load snapshot *ident*."""
        print("Loading %s..." % repr(ident))
        self.qemu.cmdLoad(ident)

    def onQemuLoaded(self, qemu, tsc, extra):
        """Snapshot loaded: restore the schedule saved with it (carried
        in *extra*) and resume from *tsc*."""
        dec = QemuDecoder(extra)
        self.scheduler = Scheduler.decode(dec)
        self.curTSC = tsc
        print("Loaded")
        self.proceed()

    #
    # Misc
    #

    def getTimeNsecs(self):
        """Return the current virtual time in nanoseconds."""
        return self.qemu.ticksToNsec(self.curTSC)

class SlavedControlListenerHandler(xctl.IListenerHandler):
    """Listener callback that hands new control connections to a slave."""

    def __init__(self, slave):
        self.__owner = slave

    def onListenerNewConnection(self, chan):
        """Forward the freshly accepted channel to the owning slave."""
        owner = self.__owner
        owner.onControlListenerNewConnection(chan)

class SlavedNetListenerHandler(xctl.IListenerHandler):
    """Listener callback that hands new network connections to a slave."""

    def __init__(self, slave):
        self.__owner = slave

    def onListenerNewConnection(self, chan):
        """Forward the freshly accepted channel to the owning slave."""
        owner = self.__owner
        owner.onNetListenerNewConnection(chan)

class NetworkRemoteEndpoint(INetworkEndPoint, xctl.IChannelHandler):
    """Network endpoint that forwards virtual packets to another slave.

    Opens an outgoing connection to (host, port) and sends every packet
    the local network switch delivers to it as a "pckt" command.
    """

    def __init__(self, pollLoop, host, port):
        """Connect to the remote slave and register with *pollLoop*.

        Raises socket.error if the connection cannot be established; the
        socket is closed first so a failed attempt does not leak it.
        """
        self.__sock = socket.socket()
        self.__addr = host, port
        try:
            self.__sock.connect(self.__addr)
        except socket.error:
            self.__sock.close()
            raise
        self.__channel = xctl.Channel(self.__sock)
        self.__channel.registerHandler(self)
        pollLoop.register(self.__channel)

    # Handle a *real* network packet
    def onPacket(self, chan, packet):
        """Nothing is expected to arrive on this outgoing link; just
        report it."""
        print("Received incoming traffic on outgoing inter-slave link!")

    def cmdSendPacket(self, st, rt, packet):
        """Send *packet* with its send time *st* and receive time *rt*."""
        if DEBUG_SEND_CMDS:
            print("Delivering packet to slaved at %s" % str(self.__addr))
        enc = QemuEncoder()
        enc.put_be64(st)
        enc.put_be64(rt)
        enc.put_string(packet.getBytes())
        self.__channel.sendPacket(xctl.Packet("pckt", str(enc)))

    # Handle a *virtual* network packet.
    def onNetworkPacket(self, network, st, rt, packet):
        """Relay a packet from the local switch to the remote slave."""
        self.cmdSendPacket(st, rt, packet)

    def isNetworkLocalEndpoint(self):
        """This endpoint stands for a remote machine, so always False."""
        return False

class Slaved(xctl.IChannelHandler, ISynchronizable):
    """Slave daemon: hosts Qemu nodes on behalf of a master.

    Listens on a control port (commands from the master) and a network
    port (virtual packets from other slaves), creates one TestHandler
    per node, and takes part in the local synchronizer as the
    "remainder" participant so it can report to the master once per
    barrier.
    """

    MACADDR_BASE = "52:54:00:12:43:"

    def __init__(self, pl, ctlport, netport):
        """Set up state and start listening on *ctlport* and *netport*,
        both registered with poll loop *pl*."""
        self.__pollLoop = pl
        self.__ctlPort = ctlport
        self.__netPort = netport
        self.__handlers = []
        self.__sync = Synchronizer()
        self.__sync.join(self, remainder=True)
        self.__nodes = {}
        self.__macAddrMap = {}          # Maps octet to host/port
        self.__localMacAddrMap = {}     # Maps octet to handler
        self.__networkEndpointMap = {}  # Maps host/port to remote endpoint
        self.__myMacAddrs = set()
        self.__net = NetworkSwitch()
        self.__netLog = NetworkLogger()
        self.__net.addObserver(self.__netLog)
        self.__numPacketsRequired = -1
        self.__numPacketsReceived = 0

        self.__takingSnapshot = False
        # Defined up front so that packets arriving before the control
        # connection (or before the first proceed) cannot hit a missing
        # attribute.
        self.__ctlChannel = None
        self.__nextProceedIsSnapshot = False
        self.__quantum = None

        # Set up the listeners
        self.__ctlListener = xctl.Listener(ctlport)
        self.__ctlListener.registerHandler(SlavedControlListenerHandler(self))
        self.__pollLoop.register(self.__ctlListener)
        self.__netListener = xctl.Listener(netport)
        self.__netListener.registerHandler(SlavedNetListenerHandler(self))
        self.__pollLoop.register(self.__netListener)

    def onControlListenerNewConnection(self, chan):
        """Adopt *chan* as the control channel to the master."""
        self.__ctlChannel = chan
        self.__pollLoop.register(chan)
        chan.registerHandler(self)

    def onNetListenerNewConnection(self, chan):
        """Accept an incoming inter-slave network connection."""
        print("Incoming remote network connection established.")
        self.__pollLoop.register(chan)
        chan.registerHandler(self)

    def onPacket(self, chan, packet):
        """Route a packet to the control or the network decoder
        depending on the channel it arrived on."""
        if chan == self.__ctlChannel:
            self.onControlPacket(chan, packet)
        else:
            self.onNetPacket(chan, packet)

    def onControlPacket(self, chan, packet):
        """Decode a control command and call the matching onControl*
        method.  A packet of None means the channel was closed.

        Raises RuntimeError for an unknown command code.
        """
        if packet is None:
            method = "Close"
            args = ()
        else:
            cmd = packet.cmd
            dec = QemuDecoder(packet.args)
            if cmd == "sypr":
                method = "SynchronizerProceed"
                numPackets = dec.get_be32()
                isSnapshotByte = dec.get_byte()
                isSnapshot = isSnapshotByte != 0
                args = (numPackets, isSnapshot)
            elif cmd == "init":
                method = "InitSlave"
                quantum = dec.get_be32()
                args = (quantum,)
            elif cmd == "node":
                method = "NodeCreate"
                nodeID = dec.get_be32()
                ipOctet = dec.get_byte()
                hda = dec.get_string()
                kernel = dec.get_string()
                append = dec.get_string()
                cdrom = dec.get_string()
                cpuSpeed = dec.get_be32()
                args = (nodeID, ipOctet, hda, kernel, append, cdrom, cpuSpeed)
            elif cmd == "strt":
                method = "Start"
                args = ()
            elif cmd == "macm":
                method = "MacMap"
                octet = dec.get_byte()
                host = dec.get_string()
                port = dec.get_be16()
                args = (octet, host, port)
            elif cmd == "ltnc":
                method = "Latency"
                srcOctet = dec.get_byte()
                dstOctet = dec.get_byte()
                latency = dec.get_be64()
                args = (srcOctet, dstOctet, latency)
            elif cmd == "loss":
                method = "LossRate"
                srcOctet = dec.get_byte()
                dstOctet = dec.get_byte()
                # Sent as parts-per-million; convert to a fraction.
                rate = dec.get_be32() / 1e6
                args = (srcOctet, dstOctet, rate)
            elif cmd == "snbn":
                method = "SendBottleneck"
                srcOctet = dec.get_byte()
                bw = dec.get_be32()
                qlen = dec.get_be32()
                args = (srcOctet, bw, qlen)
            elif cmd == "roll":
                method = "Rollback"
                nids = dec.get_be32()
                ids = [dec.get_string() for _ in range(nids)]
                args = (ids,)
            elif cmd == "drpp":
                method = "DropPacket"
                sendTime = dec.get_be64()
                deliveryTime = dec.get_be64()
                srcMac = utils.MacAddr(dec.get_string())
                dstMac = utils.MacAddr(dec.get_string())
                args = (sendTime, deliveryTime, srcMac, dstMac)
            elif cmd == "dlyp":
                # NOTE(review): this class defines no
                # onControlDelayPacket, so the getattr below raises
                # AttributeError for this command.
                method = "DelayPacket"
                sendTime = dec.get_be64()
                deliveryTime = dec.get_be64()
                srcMac = utils.MacAddr(dec.get_string())
                dstMac = utils.MacAddr(dec.get_string())
                newDeliveryTime = dec.get_be64()
                args = (sendTime, deliveryTime, srcMac, dstMac,
                        newDeliveryTime)
            else:
                raise RuntimeError("Unknown command received %s" % repr(cmd))
        func = getattr(self, "onControl" + method)
        func(*args)

    def onControlInitSlave(self, quantum):
        """Record the quantum sent by the master."""
        print("Initializing slave")
        self.__quantum = quantum

    def onControlNodeCreate(self, nodeID, ipOctet, hda, kernel, append,
                            cdrom, cpuSpeed):
        """Create a Qemu node plus its handler and wire it into the poll
        loop, the synchronizer and the network switch.

        The node's MAC is MACADDR_BASE followed by *ipOctet* in hex; its
        traffic is written to "<octet hex>.pcap".
        """
        ipOctetHex = "%02x" % ipOctet
        macAddr = utils.MacAddr(self.MACADDR_BASE + ipOctetHex)
        print("Creating node %d (%s)" % (nodeID, macAddr))
        self.__myMacAddrs.add(ipOctet)
        qemu = Qemu(cpuSpeed, macAddr, hda, kernel, append, cdrom)
        self.__nodes[nodeID] = qemu
        handler = TestHandler(qemu, utils.PcapWriter(ipOctetHex + ".pcap"))
        qemu.registerHandler(handler)
        self.__handlers.append(handler)
        self.__pollLoop.register(qemu)
        handler.setSynchronizer(self.__sync)
        handler.setNetwork(self.__net)
        self.__localMacAddrMap[ipOctet] = handler

    def onControlStart(self):
        """Start the local synchronizer."""
        print("Starting execution")
        self.__sync.start()

    def onControlClose(self):
        """Control channel closed: shut down all nodes and start over
        with a fresh synchronizer and network switch."""
        print("Closing down")
        for handler in self.__handlers:
            handler.close()
        self.__handlers = []
        self.__nodes = {}
        self.__sync.leave(self)
        self.__sync = Synchronizer()
        self.__sync.join(self, remainder=True)
        self.__net = NetworkSwitch()
        self.__netLog = NetworkLogger()
        self.__net.addObserver(self.__netLog)

    def __maybeSynchronizerProceed(self):
        """Release the barrier once every announced remote packet has
        arrived, requesting snapshots first if the master asked."""
        if self.__numPacketsRequired == self.__numPacketsReceived:
            if self.__nextProceedIsSnapshot:
                print("Taking snapshot")
                self.__takingSnapshot = True
                for handler in self.__handlers:
                    handler.snapshotOnNext()

            self.__sync.reached(self)
            self.__numPacketsReceived = 0
            # -1 never equals a received count, i.e. "no proceed pending".
            self.__numPacketsRequired = -1
        elif DEBUG_BARRIER_BUG:
            print("Not proceeding yet: req=%d got=%d" %
                  (self.__numPacketsRequired, self.__numPacketsReceived))

    def onControlSynchronizerProceed(self, numPackets, isSnapshot):
        """Master allows the next quantum once *numPackets* remote
        packets have been received."""
        if DEBUG_RECV_CMDS and (numPackets > 0 or DEBUG_SEND_CMDS_VERBOSE):
            print("Received control sync proceed (wait for %d packets)" %
                  numPackets)

        self.__nextProceedIsSnapshot = isSnapshot
        self.__numPacketsRequired = numPackets
        self.__maybeSynchronizerProceed()

    def onControlMacMap(self, octet, host, port):
        """Learn where the node with MAC octet *octet* lives and join a
        remote endpoint for it to the switch.  Endpoints are shared per
        (host, port); mappings for local nodes are ignored."""
        ipOctetHex = "%02x" % octet
        macAddr = utils.MacAddr(self.MACADDR_BASE + ipOctetHex)

        if octet in self.__myMacAddrs:
            print("Ignoring MAC mapping for %s (it's local)" % str(macAddr))
            return

        print("MAC %s is at %s:%d" % (str(macAddr), host, port))
        addr = (host, port)
        self.__macAddrMap[octet] = addr
        if addr in self.__networkEndpointMap:
            endpoint = self.__networkEndpointMap[addr]
        else:
            endpoint = NetworkRemoteEndpoint(self.__pollLoop, host, port)
            self.__networkEndpointMap[addr] = endpoint
        self.__net.join(endpoint, macAddr)

    def onControlLatency(self, srcOctet, dstOctet, latency):
        """Apply a link latency if the source node is hosted here."""
        print("Setting latency for %d -> %d to %d ns" % (srcOctet,
                                                         dstOctet,
                                                         latency))
        if srcOctet in self.__myMacAddrs:
            self.__localMacAddrMap[srcOctet].setLatency(dstOctet, latency)

    def onControlLossRate(self, srcOctet, dstOctet, rate):
        """Apply a loss rate if the source node is hosted here."""
        print("Setting loss rate for %d -> %d to %f" % (srcOctet,
                                                      dstOctet,
                                                      rate))
        if srcOctet in self.__myMacAddrs:
            self.__localMacAddrMap[srcOctet].setLossRate(dstOctet, rate)

    def onControlSendBottleneck(self, srcOctet, bw, qlen):
        """Apply a send bottleneck if the source node is hosted here."""
        print("Setting send bottleneck for %d to bw=%d qlen=%d" % (srcOctet,
                                                                   bw, qlen))
        if srcOctet in self.__myMacAddrs:
            self.__localMacAddrMap[srcOctet].setSendBottleneck(bw, qlen)

    def onControlRollback(self, ids):
        """Load snapshot ids[i] into the i-th handler and reset the
        synchronizer."""
        print("Rolling back to %s" % repr(ids))
        for handler, i in zip(self.__handlers, ids):
            handler.load(i)
        self.__sync.reset()

    def onControlDropPacket(self, sendTime, deliveryTime, srcMac, dstMac):
        """Cancel a pending packet delivery on the responsible handler
        (on all handlers for a multicast destination)."""
        if dstMac.isMulticast():
            hList = self.__handlers
        else:
            try:
                hList = [self.__localMacAddrMap[dstMac.toTuple()[-1]]]
            except KeyError:
                print("Couldn't find responsible handler for dropping packet")
                # Nothing to do; without this return the loop below
                # would hit an unbound hList.
                return

        for handler in hList:
            handler.dropPacket(sendTime, deliveryTime, srcMac, dstMac)

    def cmdControlLocalReady(self, nsecs, log):
        """Tell the master that all local nodes reached the barrier at
        *nsecs*, along with the packets logged during the quantum."""
        if DEBUG_SEND_CMDS_VERBOSE:
            print("slaved->master LocalReady(%d, %d messages)" % (nsecs,
                                                                  len(log)))
        enc = QemuEncoder()
        enc.put_be64(nsecs)
        enc.put_be32(len(log))
        for x in log:
            # x[0..2] are three 64-bit values, x[3] is the packet.
            enc.put_be64(x[0])
            enc.put_be64(x[1])
            enc.put_be64(x[2])
            enc.put_string(str(x[3].getSrcMAC()))
            enc.put_string(str(x[3].getDestMAC()))
            # XXX The following is totally the wrong way to push this
            # information over.  The master should request the packet
            # contents if it wants details like this and do its own
            # just-in-time decoding.
            enc.put_string(x[3].classify())
        self.__ctlChannel.sendPacket(xctl.Packet("lrdy", str(enc)))

    def onNetPacket(self, chan, packet):
        """Decode a command from another slave and call the matching
        onNet* method.  A packet of None means the channel was closed.

        Raises RuntimeError for an unknown command code.
        """
        if packet is None:
            method = "Close"
            args = ()
        else:
            cmd = packet.cmd
            dec = QemuDecoder(packet.args)
            if cmd == "pckt":
                method = "RemotePacket"
                st = dec.get_be64()
                rt = dec.get_be64()
                ethPacket = utils.EthernetPacket(dec.get_string())
                args = (st, rt, ethPacket)
            else:
                raise RuntimeError("Unknown command received %s" % repr(cmd))
        func = getattr(self, "onNet" + method)
        func(*args)

    def onNetClose(self):
        """An inter-slave connection closed; nothing to clean up."""
        pass

    def onNetRemotePacket(self, st, rt, packet):
        """Inject a packet from another slave into the local switch and
        count it toward the packets the barrier is waiting for."""
        if DEBUG_RECV_CMDS:
            print("Received remote packet %s (rt = %d)" % (str(packet), rt))
        self.__net.send(st, 0, rt, packet, onlyLocal=True)
        self.__numPacketsReceived += 1
        self.__maybeSynchronizerProceed()

    def onSyncRemainder(self):
        """Called by the synchronizer for the remainder participant:
        report snapshot ids (if one was taken) and the packet log to the
        master, then clear the log."""
        nsecs = self.__handlers[0].getTimeNsecs()

        if self.__takingSnapshot:
            self.__takingSnapshot = False
            # Report snapshot identifiers
            ids = [handler.lastSnapshotIdent for handler in self.__handlers]
            self.cmdSnapshotTaken(ids)

        log = self.__netLog.getLog()
        self.cmdControlLocalReady(nsecs, log)
        self.__netLog.resetLog()

    def cmdSnapshotTaken(self, ids):
        """Send the list of snapshot identifiers to the master."""
        if DEBUG_SEND_CMDS:
            print("slaved->master SnapshotTaken([%d ids])" % len(ids))
        enc = QemuEncoder()
        enc.put_be32(len(ids))
        for s in ids:
            enc.put_string(s)
        self.__ctlChannel.sendPacket(xctl.Packet("snap", str(enc)))

def main():
    """Start a slave daemon.

    Expects the control port as the only command-line argument; the
    network port is the control port plus one.  Prints a usage message
    and exits with status 2 if the argument is missing or not an
    integer.
    """
    usage = "Usage: %s CONTROL_PORT\n" % sys.argv[0]
    if len(sys.argv) < 2:
        sys.stderr.write(usage)
        sys.exit(2)
    try:
        port = int(sys.argv[1])
    except ValueError:
        sys.stderr.write(usage)
        sys.exit(2)

    pl = xctl.PollLoop()
    print("Starting slaved on ports %d and %d" % (port, port + 1))
    # Keep a reference for the lifetime of the main loop.
    slaved = Slaved(pl, port, port + 1)
    xctl.mainLoop(pl)

# Run the slave daemon only when executed as a script, not on import.
if __name__ == "__main__":
    main()
