from pickle import *
import asyncore
import socket
import sys
from heapq import heappush, heappop, heapify
from msg import *
from time import time
from random import random

# Largest datagram (in bytes) that NetworkSocket.handle_read will read.
MAX_BUF_SIZE = 16 * 1024        # 16K limit on packet sizes
# Seconds: fallback poll interval / timeout horizon when no timers are pending.
DEFAULT_TIMEOUT = 30

class Callable:
    """Hack to make class methods work.  From
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52304, 9/2005.

    Wraps a plain function so it can be stored as a class attribute and
    called through the class without an instance (used like C{staticmethod}).
    """
    def __init__( self, anycallable ):
        self._anycallable = anycallable
        # Python 2 classic instances dispatch calls through this instance
        # attribute; it is kept so existing readers of .__call__ still work.
        self.__call__ = anycallable

    def __call__( self, *args, **kwargs ):
        # New-style (and all Python 3) instances look __call__ up on the
        # type and ignore the instance attribute above, so delegate here.
        return self._anycallable( *args, **kwargs )

class Network:
    """Sends and receives packets, and handles timeouts

    To use a network object, construct one and register receive and
    timeout functions.  After sending any initial startup messages
    (such as a SYN packet from a client), call L{loop}().  loop does not
    exit until L{close}() has been called, or ctrl-C typed.

    B{Do not} alter this class for the problem set.

    """
    # Both heaps are class attributes, so they are shared by every Network
    # instance in the process.  Entries are TimeoutData objects, ordered by
    # deadline.
    timeouts = []
    heapify(timeouts)
    # we need a separate heap for delayed packets, since we should never
    # have two packets with the same sequence number in the timeouts heap
    delayed_packets = []
    heapify(delayed_packets)

    def __init__( self, host, loss=0, port=-1, debug=False,
                  queue=None, delay=0 ):
        """Initialize Network data members

        @param host: The network hostname to bind to on the local machine
        @param loss: The loss rate of the network, between 0 and 1.0, where
                     1.0 means that all packets are dropped.  (default: 0)
        @param port: The network UDP port to bind to on the local machine.
                     If port is -1, choose a dynamic port (e.g., for a client).
                     (default: -1)
        @param debug: print out janky window and chunk size info
        @param queue: A L{queue.Queue} object, representing a network
                      bottleneck.  Its schedulePacket() result is treated
                      as: negative = drop, 0 = send now, positive = absolute
                      time (as from time()) at which the packet may leave.
        @param delay: A delay (floating point, in seconds) to add to each
                      outgoing packet
        """
        # port == -1 indicates a client
        self.socket = NetworkSocket( host, port )
        self.loss = loss   # this is the simulated packet loss rate
        self.registerReceiveFunc( None )
        self.registerTimeoutFunc( None )
        self.exit = False
        self.socket.debug = debug
        self.queue = queue
        self.delay = delay

    def sendMessage( self, host, port, msg, timeout=0 ):
        """Send a message over the network. The message can be sent
        reliably by using a non-zero timeout value. When a message
        sent using this method and a non-zero timeout value is
        correctly received, the corresponding L{cancelTimeout}
        function needs to be called to cancel the scheduled timeout.

        @param host: The DNS name or IP address of the machine to receive
                     the packet.
        @param port: The UDP port on the receiver machine that is listening
                     for ReliableMessages.
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} to send to
                    the receiver
        @param timeout: Schedule a timeout for this message in seconds from
                        now.  If the timeout has not been cancelled by that
                        deadline, the registered timeout function will be
                        invoked.  If timeout is 0, no timeout is scheduled.
                        (default: 0)
        """
        # sched_time says what to do with the packet:
        #   0   -> hand it to the socket right away
        #   > 0 -> absolute time (compared with time()) at which it may leave
        #   < 0 -> the bottleneck queue refused it, so it is dropped
        sched_time = 0
        if( self.queue is not None ):
            sched_time = self.queue.schedulePacket()

        # Start the artificial delay from "now" only if the queue has not
        # already scheduled the packet.  A queue-assigned departure time is
        # kept, and the delay is added on top of it below instead of being
        # discarded.
        if( self.delay != 0 and sched_time == 0 ):
            sched_time = time()

        if( sched_time >= 0 and random() > self.loss ):
            if( sched_time == 0 ): # don't delay the packet at all
                self.socket.sendTo( host, port, msg )
            else:
                # we're being rate limited and/or delayed, so park the packet
                # in the delayed-packet heap; fireTimeouts hands it to the
                # socket once its deadline passes
                td = TimeoutData( sched_time + self.delay,
                                  host, port, msg,
                                  self.socket.sendTo )
                heappush( Network.delayed_packets, td )
        else:
            sys.stdout.write( "Dropping msg %s (type = %s )\n"
                              % (msg.seqno, msg.type) )

        # timeout == 0 means no timeout is scheduled; otherwise timeout is a
        # (floating point) number of seconds from now
        if( timeout != 0 ):
            td = TimeoutData( time() + timeout, host, port, msg,
                              self.timeout_func )
            heappush( Network.timeouts, td )

    def registerReceiveFunc( self, recv_func ):
        """Register a function to be called whenever a packet is received

        @param recv_func: The function to call.  This function should
                          take a hostname, a port number, and a
                          L{ReliableMessage <msg.ReliableMessage>} as its
                          parameters (in that order).
        """
        # recv func should take (host, port, msg) as its parameters
        self.socket.registerReceiveFunc( recv_func )

    def registerTimeoutFunc( self, timeout_func ):
        """Register a function to be called whenever a packet times out

        If you call L{sendMessage} with a timeout, Network will call
        your timeout function after that amount of time, unless you
        cancel it with L{cancelTimeout}.

        @param timeout_func: The function to call.  This function should
                             take a hostname, a port number, and a
                             L{ReliableMessage <msg.ReliableMessage>} as its
                             parameters (in that order).
        """
        # timeout func should take (host, port, msg) as its parameters
        self.timeout_func = timeout_func

    def cancelTimeout( self, host, port, seqno ):
        """Cancel a timeout that was scheduled previously through L{sendMessage}

        This function needs to be called on every message that has
        been sent with a timeout value using the L{sendMessage} method
        after the message has been successfully received. This
        function currently uses only the seqno parameter to locate the
        timeout corresponding to the message and cancel it.

        @param host: The host to which the message was sent (unused)
        @param port: The port to which the message was sent (unused)
        @param seqno: The sequence number of the sent message
        @return: the message cancelled if the timeout was removed
                 successfully, None if no such timeout was found.
        """
        for index, td in enumerate( Network.timeouts ):
            # For now, only compare seqno, since there could be a
            # discrepancy between hostname and IP address
            if( td.msg.seqno == seqno ):
                # delete by position (no second scan), then restore the
                # heap invariant
                del Network.timeouts[index]
                heapify( Network.timeouts )
                return td.msg  # this is not duplicate removal
        sys.stdout.write( "Dup ACK - corresponding timer not found\n" )
        return None

    def close( self ):
        """Close this network connection.

        Flags the socket for closing.  The socket is closed once it has
        nothing left to send, either after its last buffered write or by
        L{loop}.  When all sockets are closed, L{loop} exits.
        """
        self.socket.exit = True

    def fireTimeouts( heap ):
        """Fire every expired entry of C{heap} (static method).

        Pops each TimeoutData whose deadline has passed and calls its
        C{func(host, port, msg)}.  The heap is re-checked after every
        callback, so entries a callback pushes are also fired if they are
        already overdue.

        @param heap: A heap (list) of L{TimeoutData} objects
        @return: the deadline of the earliest entry left in the heap, or
                 now + DEFAULT_TIMEOUT if the heap is empty
        """
        while heap and heap[0].deadline < time():
            td = heappop( heap )
            td.func( td.host, td.port, td.msg )

        if heap:
            return heap[0].deadline
        # no more timeouts, so just wait for the usual
        return time() + DEFAULT_TIMEOUT
    fireTimeouts = staticmethod(fireTimeouts)

    def _closeIdleSockets():
        """Close every dispatcher whose owner called close() and that has
        no buffered or delayed output left (internal static helper for
        L{loop})."""
        for disp in list( asyncore.socket_map.values() ):
            can_exit = getattr( disp, 'canExit', None )
            if can_exit is None or not can_exit():
                continue
            # a packet still waiting in the delayed heap will later be moved
            # into this socket's buffer, so keep the socket open until then
            waiting = any( getattr( td.func, '__self__', None ) is disp
                           for td in Network.delayed_packets )
            if not waiting:
                disp.close()
    _closeIdleSockets = staticmethod(_closeIdleSockets)

    def loop():
        """Listen for receive events, and generate timeout events
        (static method)

        This is the heart of the Network class.  Once called, it will
        not exit until the application calls L{close} for all sockets
        or the user types ctrl-C.  Each time a network message
        arrives, it calls the registered receive function.  Each time
        an uncancelled timeout occurs, it calls the registered timeout
        function.  Any initial starting messages must be sent
        B{before} loop is called.
        """
        while asyncore.socket_map:
            # close() only sets a flag, and a socket whose buffer is already
            # empty would never reach handle_write to close itself; close
            # such sockets here so the loop can actually terminate.
            Network._closeIdleSockets()
            if not asyncore.socket_map:
                break

            if( Network.timeouts or Network.delayed_packets ):
                tdead = Network.fireTimeouts( Network.timeouts )
                pdead = Network.fireTimeouts( Network.delayed_packets )
                deadline = min( tdead, pdead )

                # time keeps moving after fireTimeouts returns, so the
                # deadline may already be in the past; poll without blocking
                # in that case instead of sleeping for abs(deadline - now)
                asyncore.poll( timeout=max( 0.0, deadline - time() ) )
            else:
                # don't worry, this isn't really looping and polling
                # it's just a select
                asyncore.poll( DEFAULT_TIMEOUT )
    loop = staticmethod(loop)

class NetworkSocket(asyncore.dispatcher):
    """An internal class used to send and receive asynchronous network events.

    Wraps a UDP socket in an asyncore dispatcher.  Outgoing messages are
    queued by L{sendTo} and pickled onto the wire one per L{handle_write}.
    Incoming datagrams are unpickled by L{handle_read` and handed to the
    registered receive function.

    B{Do not modify or directly use this class for the problem set.}

    """

    def __init__( self, host, port=-1 ):
        asyncore.dispatcher.__init__(self)
        self.host, self.port = host, port
        self.lastsent = 0
        self.create_socket( socket.AF_INET, socket.SOCK_DGRAM )
        # port -1 means no explicit bind: the OS picks a dynamic port
        if port != -1:
            self.bind( (host, port) )
        self.buffer = []      # pending SocketMessage objects, sent FIFO
        self.exit = False     # set once close() was requested; see canExit
        self.debug = False
        self.window = 0       # debug only: DATA packets sent minus ACKs seen

    def handle_connect(self):
        pass

    def handle_read( self ):
        data, sender = self.recvfrom( MAX_BUF_SIZE )
        # SECURITY(review): unpickling bytes received from the network lets
        # any sender run arbitrary code in this process; only use on a
        # trusted network.
        incoming = loads( data )
        if self.debug and incoming.type == TYPE_ACK:
            self.window -= 1
            sys.stdout.write( "DEBUG: ACK %s\n" % (self.window,) )
        self.recv_func( sender[0], sender[1], incoming )

    def sendTo( self, host, port, msg ):
        # only queue here; asyncore calls handle_write when writable()
        self.buffer.append( SocketMessage( host, port, msg ) )

    def writable (self):
        return bool( self.buffer )

    def handle_write (self):
        pending = self.buffer.pop( 0 )
        payload = pending.msg
        # just assume we can send the whole message in one datagram
        if self.debug and payload.type == TYPE_DATA:
            self.window += 1
            sys.stdout.write( "DEBUG: DATA %s %s %s\n"
                              % (len(payload.data), self.window,
                                 payload.seqno) )
        self.sendto( dumps( payload ), (pending.host, pending.port) )
        self.lastsent = time()

        if self.canExit():
            self.close()

    def registerReceiveFunc( self, recv_func ):
        # recv_func is invoked as recv_func(host, port, msg)
        self.recv_func = recv_func

    def canExit( self ):
        # we can exit if close() has been called, and there are no messages
        # to send (who cares if there are timeouts left, close() was called)
        return self.exit and not self.buffer



class TimeoutData:
    """An internal data structure used by L{Network <Network>} to organize
    data associated with timeouts.

    Ordering (used by the heapq heaps in Network) looks only at
    C{deadline}; equality also compares host, port and msg.

    B{Do not modify or directly use this class for the problem set.}
    """

    deadline = 0.0

    def __init__( self, deadline, host, port, msg, func ):
        self.deadline = deadline   # absolute time, compared against time()
        self.host = host
        self.port = port
        self.msg = msg
        self.func = func           # invoked as func(host, port, msg) on expiry

    def __cmp__( self, other ):
        # Python 2 three-way comparison, written without the cmp() builtin
        # (which does not exist in Python 3).
        return (self.deadline > other.deadline) - \
               (self.deadline < other.deadline)

    # Python 3 ignores __cmp__, and heapq needs '<'.  Defining the rich
    # comparisons explicitly keeps the heaps working on both versions.
    def __lt__( self, other ):
        return self.deadline < other.deadline

    def __le__( self, other ):
        return self.deadline <= other.deadline

    def __gt__( self, other ):
        return self.deadline > other.deadline

    def __ge__( self, other ):
        return self.deadline >= other.deadline

    def __eq__( self, other ):
        return (self.deadline == other.deadline and
                self.host == other.host and
                self.port == other.port and
                self.msg == other.msg)

    def __ne__( self, other ):
        return (self.deadline != other.deadline or
                self.host != other.host or
                self.port != other.port or
                self.msg != other.msg)

class SocketMessage:
    """One outgoing datagram queued by L{NetworkSocket <NetworkSocket>}:
    the destination host and port plus the message object to send.


    B{Do not modify or directly use this class for the problem set.}
    """

    def __init__( self, host, port, msg ):
        # destination address, then the payload object
        self.host, self.port = host, port
        self.msg = msg

