import socket, select, errno
from utils import hexDump
from codec import QemuEncoder, QemuDecoder

# When true, Channel prints a hex dump of every chunk of raw bytes it
# receives (_tickRecv) or sends (_tickSend).
DEBUG_CHANNEL_DATA = False
# When true, Channel prints the command of every packet it queues or
# receives, plus a "(close)" line when the connection goes away.
DEBUG_CHANNEL_CMDS = False

class PollLoop(object):
    def __init__(self):
        self.__channels = []
        self.__idlers = []

    def register(self, chan):
        assert isinstance(chan, IPollable)
        self.__channels.append(chan)

    def unregister(self, chan):
        self.__channels.remove(chan)

    def registerIdler(self, idler, maxdelay):
        assert isinstance(idler, IPollIdler)
        self.__idlers.append((idler, maxdelay))

    def unregisterIdler(self, idler):
        for i in self.__idlers:
            if i[0] == idler:
                self.__idlers.remove(i)
                return
        raise ValueError, "Idler %s not found" % `idler`

    def poll(self, timeout = None):
        # Construct poll object
        p = select.poll()
        fdmap = {}
        empty = True
        for c in self.__channels:
            chan = c.getChannel()
            if chan is None:
                continue
            if chan.closed():
                self.unregister(chan)
                continue
            mask = select.POLLIN
            if chan._canSend():
                mask |= select.POLLOUT
            p.register(chan.fileno(), mask)
            empty = False
            fdmap[chan.fileno()] = chan
        # Is there anything to poll?
        if empty:
            return False
        # See if any idlers are interested
        interestedIdlers = []
        idlerDelays = []
        for idler, idlerDelay in self.__idlers:
            if idler.canPollIdle():
                interestedIdlers.append(idler)
                idlerDelays.append(idlerDelay)
        # Compute my timeout
        if len(idlerDelays):
            if timeout is None:
                timeout = min(idlerDelays)
            else:
                timeout = min(min(idlerDelays), timeout)
        # Poll!
        try:
            fds = p.poll(timeout)
        except select.error, (err, errstr):
            if err == errno.EINTR:
                fds = ()
            else:
                raise
        # Invoke idlers
        if len(fds) == 0:
            for idler in interestedIdlers:
                idler.onPollIdle()
        # Invoke FD handlers
        for fd, event in fds:
            if event & (select.POLLIN | select.POLLHUP):
                fdmap[fd]._tickRecv()
                event &= ~(select.POLLIN | select.POLLHUP)
            if event & select.POLLOUT:
                fdmap[fd]._tickSend()
                event &= ~select.POLLOUT
            if event:
                print "Unknown poll event 0x%x" % event
        return True

class IPollable(object):
    """Something that can be registered with a PollLoop.

    PollLoop.poll() calls getChannel() every round; a None result means
    there is nothing to poll for this object in that round.
    """

    def getChannel(self):
        """Return the channel to poll, or None (the default)."""
        return None

class IPollIdler(object):
    """Callback interface for work done when a poll round sees no events.

    While canPollIdle() returns true, PollLoop.poll() caps its timeout at
    the idler's registered delay and calls onPollIdle() after a round in
    which no descriptor reported an event.
    """

    def canPollIdle(self):
        """Return whether this idler wants idle callbacks; default: always."""
        return True

    def onPollIdle(self):
        """Run after an event-free poll round; the default does nothing."""
        return None

class IChannel(IPollable):
    """Pollable object that is its own channel.

    PollLoop.poll() calls closed(), fileno() and _canSend() to decide what
    to poll, then _tickRecv()/_tickSend() when the descriptor is ready.
    The defaults below do nothing and return None.
    """

    def getChannel(self):
        """A channel polls itself."""
        return self

    def closed(self):
        """Return true once the channel should be dropped from the loop."""
        return None

    def fileno(self):
        """Return the descriptor to hand to poll()."""
        return None

    def _canSend(self):
        """Return true when there is output waiting to be written."""
        return None

    def _tickRecv(self):
        """Handle the descriptor becoming readable (or hung up)."""
        return None

    def _tickSend(self):
        """Handle the descriptor becoming writable."""
        return None

class Packet(object):
    """One protocol message: a command of at most 4 characters plus args.

    The encoded form is the command as 4 bytes, right-padded with spaces,
    followed by *args* written with QemuEncoder.put_buffer().
    """
    __slots__ = ["cmd", "args"]

    def __init__(self, cmd, args = ""):
        # Check the command before storing anything.
        if len(cmd) > 4:
            raise ValueError("Packet cmd must be <= 4 chars long")
        self.cmd = cmd
        self.args = args

    def encode(self):
        """Return the encoded packet as a string."""
        enc = QemuEncoder()
        # Always exactly 4 command bytes: pad short commands and truncate
        # (should cmd have been reassigned to something longer).
        for ch in self.cmd.ljust(4)[:4]:
            enc.put_byte(ord(ch))
        enc.put_buffer(self.args)
        return str(enc)

    @classmethod
    def decode(cls, buf):
        """Build a Packet from a string produced by encode()."""
        dec = QemuDecoder(buf)
        cmd = "".join([chr(dec.get_byte()) for _ in range(4)])
        # encode() pads on the right only, so strip only trailing padding;
        # stripping both ends would drop meaningful leading characters.
        cmd = cmd.rstrip()
        args = dec.get_buffer()
        return cls(cmd, args)

    def __repr__(self):
        return "<packet %s(%s)>" % (self.cmd, hexDump(self.args))

class Channel(IChannel):
    def __init__(self, fd):
        self._fd = fd
        self.__recvBuf = ""
        self.__recvCmdLen = None
        self.__sendBuf = ""
        self.__closed = False
        self.__handlers = []

    def getChannel(self):
        return self

    #
    # Send/receive queuing and processing
    #
        
    def sendPacket(self, packet):
        enc = QemuEncoder()
        payload = packet.encode()
        enc.put_be32(len(payload))
        enc.put_buffer(payload)
        self.__sendBuf += str(enc)
        if DEBUG_CHANNEL_CMDS:
            print "xctl client -> %s" % packet.cmd

    def __recvPacket(self):
        if self.__recvCmdLen is None and len(self.__recvBuf) >= 4:
            # Have packet size
            dec = QemuDecoder(self.__recvBuf)
            self.__recvCmdLen = dec.get_be32()
            self.__recvBuf = self.__recvBuf[4:]
        if self.__recvCmdLen is not None and \
           len(self.__recvBuf) >= self.__recvCmdLen:
            # Have payload
            payload = self.__recvBuf[:self.__recvCmdLen]
            self.__recvBuf = self.__recvBuf[self.__recvCmdLen:]
            self.__recvCmdLen = None
            return Packet.decode(payload)
        return None

    def closed(self):
        return len(self.__recvBuf) == 0 and self.__closed

    def fileno(self):
        return self._fd.fileno()

    #
    # Low-level send/receive
    #

    def _tickRecv(self):
        try:
            r = self._fd.recv(128)
        except socket.error, (err, errstr):
            if err == errno.ECONNRESET:
                print errstr
                r = ""
            else:
                raise
        if len(r) == 0:
            print "Connection closed"
            self.__closed = True
        else:
            if DEBUG_CHANNEL_DATA:
                print "Received"
                print hexDump(r, terse = False, prefix = "  ")
            self.__recvBuf += r
        self.__invokeHandlers()

    def _canSend(self):
        return len(self.__sendBuf) > 0

    def _tickSend(self):
        try:
            n = self._fd.send(self.__sendBuf)
            if DEBUG_CHANNEL_DATA:
                print "Sent"
                print hexDump(self.__sendBuf[:n], terse = False, prefix = "  ")
            self.__sendBuf = self.__sendBuf[n:]
        except socket.error, arg:
            if arg[0] == 32:
                # Broken pipe
                self.__closed = True
                print "Broken pipe"
            else:
                raise

    #
    # Handlers
    #

    def registerHandler(self, handler):
        assert isinstance(handler, IChannelHandler)
        self.__handlers.append(handler)

    def unregisterHandler(self, handler):
        self.__handlers.remove(handler)

    def __invokeHandlers(self):
        while True:
            packet = self.__recvPacket()
            if packet is not None:
                if DEBUG_CHANNEL_CMDS:
                    print "xctl client <- %s" % packet.cmd
                for handler in self.__handlers:
                    handler.onPacket(self, packet)
            else:
                break
        if self.closed():
            if DEBUG_CHANNEL_CMDS:
                print "xctl client <- (close)"
            for handler in self.__handlers:
                handler.onPacket(self, None)

class IChannelHandler(object):
    """Receiver for traffic arriving on a Channel."""

    def onPacket(self, chan, packet):
        """Handle *packet* from *chan*; a None packet means *chan* closed.

        The default implementation ignores everything.
        """
        return None

class Listener(IChannel):
    def __init__(self, port):
        self.__port = port
        self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.__sock.bind(("", port))
        self.__sock.listen(5)
        self.__handlers = []

    def closed(self):
        return False

    def fileno(self):
        return self.__sock.fileno()

    def _canSend(self):
        return False

    def _tickRecv(self):
        sock, addr = self.__sock.accept()
        print "Got new connection from %s" % repr(addr)
        channel = Channel(sock)
        for handler in self.__handlers:
            handler.onListenerNewConnection(channel)

    def _tickSend(self):
        print "BUG: Listener cannot send"

    def registerHandler(self, handler):
        assert isinstance(handler, IListenerHandler)
        self.__handlers.append(handler)

    def unregisterHandler(self, handler):
        self.__handlers.remove(handler)

class IListenerHandler(object):
    """Receiver for connections accepted by a Listener."""

    def onListenerNewConnection(self, chan):
        """Take ownership of the freshly accepted Channel *chan*.

        The default implementation ignores it.
        """
        return None

# Exit request flag: set by exitMainLoop(), checked by mainLoop() before
# each poll round.
_exitMainLoop = False

def exitMainLoop():
    """Ask the running mainLoop() to return after its current poll round."""
    global _exitMainLoop
    _exitMainLoop = True

def mainLoop(pl):
    """Call pl.poll() repeatedly until exitMainLoop() is called or poll()
    returns false (nothing left to poll).

    The exit flag is cleared on entry, so a request made before the loop
    starts has no effect.
    """
    global _exitMainLoop
    _exitMainLoop = False
    while True:
        if _exitMainLoop:
            break
        stillPolling = pl.poll()
        if not stillPolling:
            print("No connections left.  Shutting down client")
            exitMainLoop()
