import os
import signal
import socket
import sys

import xctl
import utils
from codec import QemuEncoder, QemuDecoder

# Debug tracing switches, grouped by direction.
#
#   DEBUG_SEND_CMDS  -- print each command this process sends to QEMU.
#   DEBUG_RECV_CMDS  -- attach a LoggingHandler to every Qemu instance so
#                       that notifications received from QEMU are printed.
#
# The *_VERBOSE variants are meant for the high-frequency time-sync
# traffic (the "run" command going out and the "bp" notification coming
# back), which is usually too noisy to leave on.
DEBUG_SEND_CMDS = True
DEBUG_SEND_CMDS_VERBOSE = False

DEBUG_RECV_CMDS = True
DEBUG_RECV_CMDS_VERBOSE = False

class Qemu(xctl.IPollable, xctl.IChannelHandler):
    """One QEMU instance plus its album helper, driven over an xctl channel.

    Commands are sent with the cmd*() methods.  Notifications arriving on
    the channel are decoded by onPacket() and fanned out to every handler
    registered with registerHandler(), as onQemu<Event>(qemu, ...) calls.
    """

    # Every per-instance attribute assigned by this class must be listed.
    # A partial list only works while some base class still supplies a
    # __dict__; if the bases use __slots__ as well, assigning an unlisted
    # attribute raises AttributeError.
    __slots__ = ["__pid", "__channel", "__handlers", "__macaddr",
                 "__speed", "__hda", "__kernel", "__append", "__cdrom",
                 "__label"]

    # Per-class counter handed to each new instance's __startQemu (which
    # does not currently use it).
    __albumPort = 5000

    def __init__(self, speed, macaddr, hda, kernel, append, cdrom,
                 label = None, qemu_bin = "../../bin/qemu",
                 album_bin = "../album/album"):
        """Launch album and QEMU (QEMU starts paused because of -S).

        speed      -- clock rate in ticks per second, used by the
                      tick/time conversion helpers.
        macaddr    -- MAC address object; toString() is passed to
                      -macaddr and str() supplies the default label.
        hda, kernel, append, cdrom -- passed through to the QEMU command
                      line.
        label      -- prefix for debug output; defaults to the last two
                      characters of str(macaddr).
        qemu_bin, album_bin -- paths of the executables to launch.
        """
        self.__speed = speed
        self.__macaddr = macaddr
        self.__hda = hda
        self.__kernel = kernel
        self.__append = append
        self.__cdrom = cdrom
        if label is None:
            label = str(macaddr)[-2:]
        self.__label = label

        self.__handlers = []
        self.__channel = None
        self.__startQemu(qemu_bin, album_bin, Qemu.__albumPort)
        Qemu.__albumPort += 1

        if DEBUG_RECV_CMDS:
            self.registerHandler(LoggingHandler(self.getLabel()+":"))

    #
    # Qemu launching
    #

    @staticmethod
    def __execChild(banner, path, argv):
        """Replace the current (freshly forked) process with path/argv.

        Never returns.  If the exec fails, the error is written to stderr
        and the child exits with status 127, instead of unwinding into
        the parent's code path and running it a second time.
        """
        try:
            try:
                print(banner)
                # exec discards Python-level stdio buffers; flush so the
                # banner is not lost when stdout is a pipe or a file.
                sys.stdout.flush()
                os.execv(path, argv)
            except OSError:
                sys.stderr.write("%s: exec failed: %s\n"
                                 % (path, sys.exc_info()[1]))
                sys.stderr.flush()
        finally:
            os._exit(127)

    def __startQemu(self, qemu_bin, album_bin, album_port):
        """Fork/exec album and QEMU and connect them with socketpairs.

        album <-> QEMU share one socketpair; this process <-> QEMU share a
        second one, whose local end becomes self.__channel.  Each child is
        told its end as a numeric fd on the command line, which relies on
        the socket fds surviving exec (the default for Python 2 sockets).

        album_port is accepted but currently unused.
        """
        album2qemu, qemu2album = socket.socketpair()
        # Flush before forking so pending output is not copied into the
        # child's buffer and emitted twice.
        sys.stdout.flush()
        apid = os.fork()
        if apid == 0:
            qemu2album.close()
            self.__execChild("Starting album", album_bin,
                             [album_bin, "%d" % album2qemu.fileno()])
        # Parent: the album child owns album2qemu now.
        album2qemu.close()
        # NOTE(review): apid is not kept, so wait() never reaps the album
        # process; confirm it exits on its own once QEMU goes away.

        me2qemu, qemu2me = socket.socketpair()
        sys.stdout.flush()
        pid = os.fork()
        if pid == 0:
            me2qemu.close()
            self.__execChild("Starting qemu", qemu_bin, [
                qemu_bin,
                # Don't start the VM until we say so
                "-S",
                # Alternatives: disable graphics (must come before
                # -monitor) with "-nographic", or the monitor with
                # "-monitor", "null".
                "-monitor", "vc",
                # Enable xctl and connect it to our socketpair
                "-xctl", "fd%d" % qemu2me.fileno(),
                # Set MAC address
                "-macaddr", self.__macaddr.toString(),
                # Album socket fd
                "-albumfd", "%d" % qemu2album.fileno(),
                # Disk image, written copy-on-write via -snapshot
                "-snapshot",
                "-hda", self.__hda,
                "-kernel", self.__kernel,
                "-append", self.__append,
                "-cdrom", self.__cdrom,
                ])
        # Parent: QEMU owns qemu2album and qemu2me now.
        self.__pid = pid
        qemu2album.close()
        qemu2me.close()
        self.__channel = xctl.Channel(me2qemu)
        self.__channel.registerHandler(self)

    #
    # Process control
    #

    def wait(self):
        """Block until the QEMU process exits.

        Returns its exit status, or -1 if it did not exit normally (for
        example, it was terminated by a signal).
        """
        _, status = os.waitpid(self.__pid, 0)
        if os.WIFEXITED(status):
            return os.WEXITSTATUS(status)
        return -1

    def kill(self):
        """Send SIGINT to the QEMU process (album is not signalled)."""
        os.kill(self.__pid, signal.SIGINT)

    #
    # Command senders
    #

    def cmdRun(self, bp):
        """Send "run" with bp encoded as a big-endian 64-bit value."""
        if DEBUG_SEND_CMDS_VERBOSE:
            print("%s:xctl client -> run(%d)" % (self.getLabel(), bp))
        enc = QemuEncoder()
        enc.put_be64(bp)
        self.__channel.sendPacket(xctl.Packet("run", str(enc)))

    def cmdQuit(self):
        """Send "quit" (no arguments)."""
        if DEBUG_SEND_CMDS:
            print("%s:xctl client -> quit()" % self.getLabel())
        self.__channel.sendPacket(xctl.Packet("quit"))

    def cmdPacket(self, mid, packet):
        """Send "pkt": mid (be64), payload length (be32), payload bytes."""
        if DEBUG_SEND_CMDS:
            print("%s:xctl client -> packet(%d, %s)"
                  % (self.getLabel(), mid, packet))
        pdata = packet.getBytes()
        enc = QemuEncoder()
        enc.put_be64(mid)
        enc.put_be32(len(pdata))
        enc.put_buffer(pdata)
        self.__channel.sendPacket(xctl.Packet("pkt", str(enc)))

    def cmdSave(self, incremental, extra):
        """Send "save": incremental flag (be32) followed by the extra string."""
        if DEBUG_SEND_CMDS:
            print("%s:xctl client -> save(%d, [%d])"
                  % (self.getLabel(), incremental, len(extra)))
            print(utils.hexDump(extra, prefix = "  "))

        enc = QemuEncoder()
        enc.put_be32(incremental)
        enc.put_string(extra)
        self.__channel.sendPacket(xctl.Packet("save", str(enc)))

    def cmdLoad(self, ident):
        """Send "load" with ident encoded as a string."""
        if DEBUG_SEND_CMDS:
            print("%s:xctl client -> load(%s)"
                  % (self.getLabel(), repr(ident)))

        enc = QemuEncoder()
        enc.put_string(ident)
        self.__channel.sendPacket(xctl.Packet("load", str(enc)))

    #
    # Channel
    #

    def getChannel(self):
        """Return the xctl channel, or None once it has been closed."""
        return self.__channel

    def onPacket(self, chan, packet):
        """Decode one notification from QEMU and dispatch it to handlers.

        packet is None when the channel closed; handlers then receive
        onQemuClose and the channel reference is dropped.

        Raises RuntimeError for an unrecognised command.
        """
        if packet is None:
            method = "Close"
            args = ()
        else:
            cmd = packet.cmd
            dec = QemuDecoder(packet.args)
            if cmd == "bp":
                method = "BPReached"
                tsc = dec.get_be64()
                args = (tsc,)
            elif cmd == "pkt":
                method = "Packet"
                tsc = dec.get_be64()
                macaddr = utils.MacAddr([dec.get_byte() for _ in range(6)])
                size = dec.get_be32()
                pkt = utils.EthernetPacket(dec.get_buffer(size))
                args = (tsc, macaddr, pkt)
            elif cmd == "pok":
                method = "PacketOk"
                tsc = dec.get_be64()
                mid = dec.get_be64()
                args = (tsc, mid)
            elif cmd == "perr":
                method = "PacketErr"
                tsc = dec.get_be64()
                mid = dec.get_be64()
                args = (tsc, mid)
            elif cmd == "save":
                method = "Saved"
                ident = dec.get_string()
                args = (ident,)
            elif cmd == "load":
                method = "Loaded"
                tsc = dec.get_be64()
                extra = dec.get_string()
                args = (tsc, extra)
            else:
                raise RuntimeError("Unknown command received %s" % repr(cmd))

        # Iterate over a snapshot: a handler may unregister itself (e.g. on
        # Close) while being called, which would otherwise skip the next
        # handler in the list.
        for handler in list(self.__handlers):
            func = getattr(handler, "onQemu" + method)
            func(self, *args)

        if packet is None:
            # Close the channel
            self.__channel = None

    #
    # Handlers
    #

    def registerHandler(self, handler):
        """Add an IQemuResponseHandler to receive decoded notifications."""
        assert isinstance(handler, IQemuResponseHandler)
        self.__handlers.append(handler)

    def unregisterHandler(self, handler):
        """Remove a handler; raises ValueError if it was not registered."""
        self.__handlers.remove(handler)

    #
    # Time
    #

    @staticmethod
    def __exactMulDiv(value, num, den):
        """Return value*num/den truncated toward zero, using integers only.

        Returns None unless all three operands are integers, in which case
        the caller falls back to float math.  Truncation toward zero
        matches what long() does to the float result.
        """
        integral = (int, long)
        if not (isinstance(value, integral) and isinstance(num, integral)
                and isinstance(den, integral)):
            return None
        prod = value * num
        quot = abs(prod) // abs(den)
        if (prod < 0) != (den < 0):
            quot = -quot
        return long(quot)

    def getTicksPerSec(self):
        """Return the clock rate passed to the constructor."""
        return self.__speed

    def ticksToNsec(self, ticks):
        """Convert ticks to nanoseconds (truncated toward zero).

        Integer inputs are converted exactly; the float formula loses
        precision once ticks*1e9 exceeds 2**53 and can truncate an exact
        result down by one.
        """
        exact = self.__exactMulDiv(ticks, 1000000000, self.__speed)
        if exact is not None:
            return exact
        return long(ticks*1e9/self.__speed)

    def nsecToTicks(self, nsec):
        """Convert nanoseconds to ticks (truncated toward zero).

        Exact for integer inputs; see ticksToNsec.
        """
        exact = self.__exactMulDiv(nsec, self.__speed, 1000000000)
        if exact is not None:
            return exact
        return long(nsec*self.__speed/1e9)

    def ticksToSec(self, ticks):
        """Convert ticks to (floating-point) seconds."""
        return float(ticks)/self.__speed

    def secToTicks(self, sec):
        """Convert seconds to ticks.

        Avoid this: sec is typically a float, so the result is subject to
        floating-point inaccuracy.  Prefer nsecToTicks with an integer.
        """
        return long(sec*self.__speed)

    #
    # Misc
    #

    def getMAC(self):
        """Return the MAC address object passed to the constructor."""
        return self.__macaddr

    def getLabel(self):
        """Return the label used to prefix debug output."""
        return self.__label

class IQemuResponseHandler(object):
    """Callback interface for notifications coming back from a Qemu.

    Every hook is a no-op that returns None; subclasses override the ones
    they care about and are attached with Qemu.registerHandler().  Each
    hook receives the originating Qemu instance as its first argument.
    """

    def __init__(self):
        """Nothing to initialise."""

    def onQemuClose(self, qemu):
        """The xctl channel to qemu was closed."""

    def onQemuBPReached(self, qemu, tsc):
        """A "bp" notification arrived carrying the tsc value."""

    def onQemuPacket(self, qemu, tsc, macaddr, packet):
        """A "pkt" notification arrived.

        macaddr is a utils.MacAddr and packet a utils.EthernetPacket, both
        decoded from the notification together with tsc.
        """

    def onQemuPacketOk(self, qemu, tsc, mid):
        """A "pok" notification arrived for message id mid (presumably the
        mid given to Qemu.cmdPacket)."""

    def onQemuPacketErr(self, qemu, tsc, mid):
        """A "perr" notification arrived for message id mid (presumably the
        mid given to Qemu.cmdPacket)."""

    def onQemuSaved(self, qemu, ident):
        """A "save" reply arrived carrying the string ident."""

    def onQemuLoaded(self, qemu, tsc, extra):
        """A "load" reply arrived carrying tsc and the extra string."""

class LoggingHandler(IQemuResponseHandler):
    """Response handler that prints every notification received from QEMU.

    Each line goes to stdout as "<prefix>xctl client <- <Event>(...)".
    The high-frequency BPReached notification is only printed when
    DEBUG_RECV_CMDS_VERBOSE is set.
    """

    def __init__(self, prefix = "", hexdump = True):
        """prefix  -- string prepended to every log line.
        hexdump -- when true (the default), also print hex dumps of packet
                   payloads and of the extra data in Loaded notifications.
        """
        IQemuResponseHandler.__init__(self)
        self.prefix = prefix
        self.hexdump = hexdump

    def onQemuClose(self, qemu):
        self.log("Close()")

    def onQemuBPReached(self, qemu, tsc):
        # "bp" is traffic received from QEMU, so it is governed by the
        # receive-side verbosity flag.
        if DEBUG_RECV_CMDS_VERBOSE:
            self.log("BPReached(%d)" % tsc)

    def onQemuPacket(self, qemu, tsc, macaddr, packet):
        self.log("Packet(%d, %s, %s)" % (tsc, macaddr, packet))
        if self.hexdump:
            # Skip the first 12 bytes -- presumably the Ethernet destination
            # and source MAC addresses; confirm against utils.EthernetPacket.
            print(utils.hexDump(packet.getBytes()[12:], prefix="  ... "))

    def onQemuPacketOk(self, qemu, tsc, mid):
        self.log("PacketOk(%d, %d)" % (tsc, mid))

    def onQemuPacketErr(self, qemu, tsc, mid):
        self.log("PacketErr(%d, %d)" % (tsc, mid))

    def onQemuSaved(self, qemu, ident):
        self.log("Saved(%s)" % repr(ident))

    def onQemuLoaded(self, qemu, tsc, extra):
        self.log("Loaded(%d, [%d])" % (tsc, len(extra)))
        if self.hexdump:
            print(utils.hexDump(extra, prefix="  ... "))

    def log(self, msg):
        """Print msg tagged with the prefix and the receive direction."""
        print("%sxctl client <- %s" % (self.prefix, msg))
