from pickle import *
import asyncore
import socket
import sys
from heapq import heappush, heappop, heapify
from msg import *
from time import time
from random import random

# Number of bytes requested per recvfrom() call when reading a datagram,
# i.e. the 16K limit on packet sizes.
MAX_BUF_SIZE = 16 * 1024
# Seconds Network.loop() blocks in asyncore.poll() when no timeout is pending.
DEFAULT_TIMEOUT = 30

class Network:
    """Sends and receives packets, and handles timeouts

    To use a network object, construct one and register receive and
    timeout functions.  After sending any initial startup messages
    (such as a SYN packet from a client), call L{loop}().  loop does not
    exit until L{close}() has been called, or ctrl-C typed.

    B{Do not} alter this class for the problem set.

    """

    def __init__( self, host, loss=0, port=-1, debug=False ):
        """Initialize Network data members

        @param host: The network hostname to bind to on the local machine
        @param loss: The loss rate of the network, between 0 and 1.0, where
                     1.0 means that all packets are dropped.  (default: 0)
        @param port: The network UDP port to bind to on the local machine.
                     If port is -1, choose a dynamic port (e.g., for a client).
                     (default: -1)
        @param debug: If true, the underlying socket prints DEBUG lines
                      tracing outgoing DATA and incoming ACK messages.
                      (default: False)
        """
        # port == -1 indicates a client
        self.socket = NetworkSocket( host, port )
        self.loss = loss      # simulated packet loss rate
        self.timeouts = []    # heapq-managed list of pending TimeoutData
        self.registerReceiveFunc( None )
        self.registerTimeoutFunc( None )
        self.exit = False     # set once close() has been called
        self.socket.debug = debug

    def sendMessage( self, host, port, msg, timeout=0 ):
        """Send a message over the network. The message can be sent
        reliably by using a non-zero timeout value. When a message
        sent using this method and a non-zero timeout value is
        correctly received, the corresponding L{cancelTimeout}
        function needs to be called to cancel the scheduled timeout.

        @param host: The DNS name or IP address of the machine to receive
                     the package.
        @param port: The UDP port on the receiver machine that is listening
                     for ReliableMessages.
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} to send to
                    the receiver
        @param timeout: Schedule a timeout for this message in seconds from
                        now.  If the timeout has not been cancelled by that
                        deadline, the registered timeout function will be
                        invoked.  If timeout is 0, no timeout is scheduled.
                        (default: 0)
        """
        # Simulated loss: random() is in [0, 1), so ">=" sends with
        # probability exactly 1 - loss (loss=0 never drops, loss=1 always
        # drops).  A dropped message still gets its timeout scheduled below,
        # just as if it had been lost in transit.
        if( random() >= self.loss ):
            self.socket.sendTo( host, port, msg )
        else:
            print( "Dropping msg %s (type = %s )" % (msg.seqno, msg.type) )

        # timeout == 0 means no timeout is scheduled; otherwise it is a
        # (possibly fractional) number of seconds from now.
        if( timeout != 0 ):
            deadline = time() + timeout
            heappush( self.timeouts, TimeoutData( deadline, host, port, msg ) )

    def registerReceiveFunc( self, recv_func ):
        """Register a function to be called whenever a packet is received

        @param recv_func: The function to call.  This function should
                          take a hostname, a port number, and a
                          L{ReliableMessage <msg.ReliableMessage>} as its
                          parameters (in that order).
        """
        self.socket.registerReceiveFunc( recv_func )

    def registerTimeoutFunc( self, timeout_func ):
        """Register a function to be called whenever a packet timeouts

        If you call L{sendMessage} with a timeout, Network will call
        your timeout function after that amount of time, unless you
        cancel it with L{cancelTimeout}.

        @param timeout_func: The function to call.  This function should
                             take a hostname, a port number, and a
                             L{ReliableMessage <msg.ReliableMessage>} as its
                             parameters (in that order).
        """
        self.timeout_func = timeout_func

    def cancelTimeout( self, host, port, seqno ):
        """Cancel a timeout that was schedule previously through L{sendMessage}

        This function needs to be called on every message that has
        been sent with a timeout value using the L{sendMessage} method
        after the message has been successfully received. This
        function currently uses only the seqno parameter to locate the
        timeout corresponding to the message and cancel it.

        @param host: The host to which the message was sent (currently unused)
        @param port: The port to which the message was sent (currently unused)
        @param seqno: The sequence number of the sent message
        @return: 1 if the timeout was removed successfully, 0 if no such
                 timeout was found.
        """
        for index, td in enumerate( self.timeouts ):
            # Only seqno is compared, since there could be a discrepancy
            # between the hostname the caller uses and the IP address
            # reported for received packets.
            if( td.msg.seqno == seqno ):
                del self.timeouts[index]
                heapify( self.timeouts )   # restore the heap invariant
                return 1  # this is not duplicate removal
        print( "Dup ACK - corresponding timer not found" )
        return 0

    def close( self ):
        """Close this network connection.

        Causes L{loop} to exit once every queued outgoing message has
        been sent.
        """
        self.exit = True
        self.socket.exit = True
        # The socket normally closes itself in handle_write after flushing
        # its queue.  If the queue is already empty, writable() is False and
        # handle_write never runs, so close it now; otherwise the socket would
        # stay in asyncore.socket_map and loop() would never return.
        if( self.socket.canExit() ):
            self.socket.close()

    def _fireExpiredTimeouts( self ):
        """Pop and dispatch every scheduled timeout whose deadline has passed."""
        # The head is re-read each time because the timeout function may
        # schedule (or cancel) timeouts itself, e.g. when retransmitting.
        while( self.timeouts and self.timeouts[0].deadline < time() ):
            td = heappop( self.timeouts )
            self.timeout_func( td.host, td.port, td.msg )

    def loop( self ):
        """Listen for receive events, and generate timeout events

        This is the heart of the Network class.  Once called, it will
        not exit until the application calls L{close} or the user types
        ctrl-C.  Each time a network message arrives, it calls the registered
        receive function.  Each time an uncancelled timeout occurs, it calls
        the registered timeout function.  Any initial starting messages
        must be sent B{before} loop is called.
        """
        while asyncore.socket_map:
            self._fireExpiredTimeouts()
            if( self.timeouts ):
                # Block only until the next deadline.  Clamp at zero: the
                # deadline may already have passed since the check above.
                wait = max( 0.0, self.timeouts[0].deadline - time() )
            else:
                # Nothing scheduled; this is a single select(), not a busy loop
                wait = DEFAULT_TIMEOUT
            asyncore.poll( wait )

class NetworkSocket(asyncore.dispatcher):
    """An internal class used to send and receive asynchronous network events.

    Outgoing messages are queued by L{sendTo} and written one datagram per
    C{handle_write} call, each to the destination it was queued with.
    Incoming datagrams are unpickled and handed to the registered receive
    function.

    B{Do not modify or directly use this class for the problem set.}

    """

    def __init__( self, host, port=-1 ):
        asyncore.dispatcher.__init__(self)
        self.host = host
        self.port = port
        self.create_socket( socket.AF_INET, socket.SOCK_DGRAM )
        # port == -1 means "let the OS pick a port" (no explicit bind)
        if( port != -1 ):
            self.bind( (host, port) )
        self.buffer = []        # FIFO of (msg, (host, port)) awaiting send
        self.exit = False       # set by Network.close()
        self.debug = False
        self.window = 0         # debug only: DATA sent minus ACKs received
        self.recv_func = None

    def handle_connect(self):
        # UDP has no connection; asyncore calls this on the first write event.
        pass

    def handle_read( self ):
        data, addr = self.socket.recvfrom( MAX_BUF_SIZE )
        # SECURITY: unpickling bytes taken straight off the network lets any
        # sender run arbitrary code in this process.  Only acceptable on a
        # trusted network; a real protocol should use a safe wire format.
        msg = loads( data )
        if( self.debug and msg.type == TYPE_ACK ):
            self.window = self.window - 1
            print( "DEBUG: ACK %s" % (self.window,) )
        # Without a registered handler the datagram is dropped; calling None
        # would raise, and asyncore's default handle_error closes the socket.
        if( self.recv_func is not None ):
            self.recv_func( addr[0], addr[1], msg )

    def sendTo( self, host, port, msg ):
        """Queue msg for delivery to (host, port) on the next write event."""
        # Most recent destination, kept for backward compatibility only;
        # each queued message carries its own address.
        self.host_to = host
        self.port_to = port
        self.buffer.append( (msg, (host, port)) )

    def writable (self):
        return (len(self.buffer) > 0)

    def handle_write (self):
        msg, addr = self.buffer.pop(0)
        if( self.debug and msg.type == TYPE_DATA ):
            self.window = self.window + 1
            print( "DEBUG: DATA %s %s %s" % (len(msg.data), self.window, msg.seqno) )
        # UDP: the datagram is sent whole or the call raises
        self.socket.sendto( dumps(msg), addr )

        if( self.canExit() ):
            self.close()

    def registerReceiveFunc( self, recv_func ):
        # recv func should take (host, port, msg) as its parameters
        self.recv_func = recv_func

    def canExit( self ):
        # we can exit if close() has been called, and there are no messages
        # to send (who cares if there are timeouts left, close() was called)
        return self.exit and not len( self.buffer )



class TimeoutData:
    """An internal data structure used by L{Network <Network>} to organize
    data associated with timeouts.

    Instances are ordered by C{deadline} alone, so a C{heapq} heap of them
    keeps the earliest-expiring timeout at index 0.  Equality compares
    every field.

    B{Do not modify or directly use this class for the problem set.}
    """

    deadline = 0

    def __init__( self, deadline, host, port, msg ):
        self.deadline = deadline  # absolute expiry time (Network uses time() + timeout)
        self.host = host          # destination host the message was sent to
        self.port = port          # destination port the message was sent to
        self.msg = msg            # the message whose timer this is

    # Rich comparisons are what heapq (and Python 3) actually use.
    def __lt__( self, other ):
        return self.deadline < other.deadline

    def __le__( self, other ):
        return self.deadline <= other.deadline

    def __gt__( self, other ):
        return self.deadline > other.deadline

    def __ge__( self, other ):
        return self.deadline >= other.deadline

    def __cmp__( self, other ):
        # Must return an int.  Returning the raw float difference lets
        # Python 2 truncate e.g. -0.5 to 0, so deadlines less than a second
        # apart would compare equal and the heap would misorder them.
        return (self.deadline > other.deadline) - (self.deadline < other.deadline)

    def __eq__( self, other ):
        return (self.deadline == other.deadline and
                self.host == other.host and
                self.port == other.port and
                self.msg == other.msg)

    def __ne__( self, other ):
        return (self.deadline != other.deadline or
                self.host != other.host or
                self.port != other.port or
                self.msg != other.msg)
