import utils

class NetworkSwitch(object):
    """Simulated switch that forwards Ethernet packets between endpoints.

    Endpoints attach with join() under one or more MAC addresses and detach
    with leave().  Observers registered with addObserver() are told about
    every packet passed to send() unless the send is marked onlyLocal.
    """

    def __init__(self):
        # Joined endpoints, each listed exactly once, in order of first
        # join; this is also the multicast delivery order.
        self.__objects = []
        # MAC address -> endpoint that owns it.
        self.__mactoobj = {}
        # Endpoint -> list of MAC addresses it joined with.
        self.__objtomacs = {}
        # Registered INetworkObserver instances, in registration order.
        self.__observers = []

    def join(self, obj, mac):
        """Attach endpoint *obj* to the switch under MAC address *mac*.

        The same endpoint may join several times with different MACs.

        Raises:
            RuntimeError: if *mac* is already registered on this switch.
        """
        assert isinstance(obj, INetworkEndPoint)
        assert isinstance(mac, utils.MacAddr)
        if mac in self.__mactoobj:
            raise RuntimeError("Duplicate MAC address %s" % mac)
        if obj not in self.__objtomacs:
            # List each endpoint only once, so that leave() removes it
            # completely and multicast visits it a single time.
            self.__objects.append(obj)
        self.__mactoobj[mac] = obj
        self.__objtomacs.setdefault(obj, []).append(mac)

    def leave(self, obj):
        """Detach endpoint *obj* and release every MAC address it joined with.

        Raises:
            ValueError: if *obj* is not currently joined.
        """
        # remove() raises before any other state is touched if obj is unknown.
        self.__objects.remove(obj)
        for mac in self.__objtomacs.pop(obj):
            del self.__mactoobj[mac]

    def addObserver(self, obs):
        """Register *obs* to be told about every non-local packet sent."""
        assert isinstance(obs, INetworkObserver)
        self.__observers.append(obs)

    def removeObserver(self, obs):
        """Unregister *obs*; raises ValueError if it was never added."""
        self.__observers.remove(obs)

    def send(self, st, qt, rt, packet, onlyLocal=False):
        """Deliver *packet* according to its destination MAC.

        st, qt and rt are passed through unchanged.  Observers receive all
        three; endpoints receive only st and rt.

        A multicast destination delivers to every joined endpoint except the
        one owning the packet's source MAC.  Any other destination delivers
        to the endpoint owning that MAC, or prints a warning if none does.

        If onlyLocal is true, observers are not notified and only endpoints
        whose isNetworkLocalEndpoint() returns true receive the packet.
        """
        assert isinstance(packet, utils.EthernetPacket)
        src = packet.getSrcMAC()
        dest = packet.getDestMAC()

        if not onlyLocal:
            for obs in self.__observers:
                obs.onNetworkPacket(self, st, qt, rt, packet)

        if dest.isMulticast():
            # __objects holds each endpoint once (see join), so no de-dup
            # bookkeeping is needed here.
            for obj in self.__objects:
                if src in self.__objtomacs[obj]:
                    continue  # don't echo the packet back to its sender
                if not onlyLocal or obj.isNetworkLocalEndpoint():
                    obj.onNetworkPacket(self, st, rt, packet)
            return

        obj = self.__mactoobj.get(dest)
        if obj is None:
            print("Warning: Packet to unknown dest %s" % dest)
        elif not onlyLocal or obj.isNetworkLocalEndpoint():
            obj.onNetworkPacket(self, st, rt, packet)

class INetworkEndPoint(object):
    """Base class for anything that can be attached to a NetworkSwitch.

    Subclasses override the hooks below; the defaults ignore packets and
    report the endpoint as local.
    """

    def onNetworkPacket(self, network, st, rt, packet):
        """Handle *packet* delivered by *network*.  The default does nothing."""

    def isNetworkLocalEndpoint(self):
        """Return whether this endpoint receives onlyLocal sends (default True)."""
        return True

class INetworkObserver(object):
    """Base class for passive listeners registered with a NetworkSwitch."""

    def onNetworkPacket(self, network, st, qt, rt, packet):
        """Observe *packet* sent on *network*.  The default does nothing."""

class NetworkLogger(INetworkObserver):
    """Observer that records every packet reported to it.

    Each record is the tuple ``(st, qt, rt, packet)``, kept in the order
    onNetworkPacket was called.
    """

    def __init__(self):
        # Recorded (st, qt, rt, packet) tuples.
        self.__log = list()

    def resetLog(self):
        """Start a fresh, empty log.

        The old list is replaced rather than cleared, so a list previously
        obtained from getLog() keeps its contents.
        """
        self.__log = list()

    def getLog(self):
        """Return the live log list itself (not a copy)."""
        return self.__log

    def onNetworkPacket(self, network, st, qt, rt, packet):
        """Append ``(st, qt, rt, packet)``; *network* is not recorded."""
        entry = (st, qt, rt, packet)
        self.__log.append(entry)
