#!/usr/bin/env python

from msg import *
from server import *
from network import Network
from time import time
from heapq import heappush, heappop, heapify
import sys, os

class AIMDData:
    """Container for the state of an AIMD-style congestion-control
    transport protocol.

    Everything L{ccserver.AIMDReliableServer} needs in order to manage
    its congestion window is stored here rather than on the server.

    Keeping the state in its own object lets several servers, each
    bound to a different network layer, share a single congestion
    window.  That emulates a transport protocol that does not know its
    packets are being split across different paths somewhere in the
    network.

    @cvar SSTHRESH_MAX: The initial default value for ssthresh
    """

    SSTHRESH_MAX = 4

    def __init__( self ):
        # A brand-new object starts from exactly the same state as a reset one.
        self.resetDataFields()

    def resetDataFields( self ):
        """Return every field used by L{ccserver.AIMDReliableServer} to its
        starting value.
        """
        # Congestion window and slow-start threshold (read via self so a
        # subclass may override SSTHRESH_MAX).
        self.cwnd, self.ssthresh = 1.0, self.SSTHRESH_MAX
        # Smoothed round-trip-time estimate, and the time.time() value of
        # the most recent fast-retransmit window decrease.
        self.rtt, self.lastCAMD = 1.5, 0
        # Counters.
        self.outstanding = 0
        self.packetsSent = 0
        # Maps each tracked data message to its number of out-of-order ACKs.
        self.outstandingAcks = dict()
        # time.time() of the first data packet; 0 until one has been sent.
        self.firstPacketTime = 0

class AIMDReliableServer(ReliableServer):
    """This class implements an AIMD-style congestion control transport
    protocol.  It extends L{server.ReliableServer} and provides the
    following TCP-style mechanisms in the context of RMTP:
      1. Slow start
      2. Additive increase / multiplicative decrease of the congestion window
      3. Fast retransmissions
    It does so by overriding key methods of L{server.ReliableServer},
    such as L{handleAck <server.ReliableServer.handleAck>} and
    L{handleTimeout <server.ReliableServer.handleTimeout>}.

    Note: all congestion-control data lives in the L{AIMDData} object
    (C{self.data}) so that it can be shared across multiple connections.
    That object also carries an C{outstanding} counter intended to be kept
    in step with L{ReliableServer.outstanding}.
    """

    def __init__( self, data, network, window=4, chunk_size=1000,
                  fast_rxmit=True ):
        """Initialize server data members, and register network functions.

        @param data: An L{AIMDData} object
        @param network: A configured L{network.Network} object
        @param window: Irrelevant for an L{AIMDReliableServer} object
        @param chunk_size: The size of the data chunks sent to the client.
                           (default: 1000)
        @param fast_rxmit: Whether or not to do fast retransmissions

        """
        # self.data is assigned before the parent constructor runs, since
        # resetFields() (presumably invoked from it) dereferences self.data.
        self.data = data
        self.fast_rxmit = fast_rxmit
        ReliableServer.__init__( self, network, window, chunk_size )

    def resetFields( self ):
        """Reset data members to their default values.

        Call this on initialization, and between client connections.  A
        congestion-window log left open by a previous connection is closed
        explicitly before its handle is dropped.
        """
        ReliableServer.resetFields( self )
        self.data.resetDataFields()
        # On the very first call the attribute does not exist yet.
        previousLog = getattr( self.data, 'cwndlog', 0 )
        if previousLog:
            previousLog.close()
        # File handle to dump cwnd into; 0 means "not opened yet".
        self.data.cwndlog = 0

    def sendData( self, host, port, msg ):
        """Send as many data messages to the client as the congestion
        window allows.

        The sending itself is delegated to L{server.ReliableServer.sendData}.
        For the duration of that call C{self.network.sendMessage} is replaced
        by a wrapper that records each outgoing message in
        C{self.data.outstandingAcks} (so out-of-order ACKs can be counted
        against it) and increments C{self.data.packetsSent}.  The original
        C{sendMessage} is always restored, even if sending raises.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK or type ACK

        """
        # Reference point for the elapsed-time column of the cwnd log.
        if self.data.firstPacketTime == 0:
            self.data.firstPacketTime = time()

        # ReliableServer.sendData cannot be modified, so intercept the
        # network layer to learn which messages it sends.
        originalSendMessage = self.network.sendMessage

        def trackingSendMessage( dest_host, dest_port, data_msg,
                                 *args, **kwargs ):
            self.data.outstandingAcks[data_msg] = 0
            self.data.packetsSent += 1
            # Forward the caller's arguments untouched, so the network's own
            # default timeout applies when the caller does not pass one.
            return originalSendMessage( dest_host, dest_port, data_msg,
                                        *args, **kwargs )

        self.network.sendMessage = trackingSendMessage
        try:
            ReliableServer.sendData( self, host, port, msg )
        finally:
            self.network.sendMessage = originalSendMessage

    def handleAck( self, host, port, msg ):
        """Receive a message of type ACK from the client

        The client is acknowledging that it received a data message.
        Grow the congestion window, update out-of-order ACK counts (fast
        retransmitting where due), update the RTT estimate, cancel the
        message's timeout and send more data.  ACKs outside the data
        phase are ignored.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type ACK

        """
        if self.state != self.STATE_DATA:
            return

        self._growCwnd()
        self._scanOutOfOrderAcks( host, port, msg )

        # The rest mirrors ReliableServer.handleAck instead of calling it, so
        # that self.data.outstanding is decremented together with
        # self.outstanding.
        rtt_est = time() - msg.timestamp   # in seconds
        self.updateRTT( host, port, rtt_est )

        print("Got ACK for %s" % (msg.seqno,))
        # A None result from cancelTimeout marks a duplicate ACK; only a
        # fresh ACK advances the window.
        cancel = self.network.cancelTimeout( host, port, msg.seqno )
        if cancel is not None:
            # NOTE(review): nothing visible in this class ever increments
            # self.data.outstanding, so it can drift negative; confirm
            # before relying on it.
            self.data.outstanding -= 1
            self.outstanding -= 1
            self.sendData( host, port, msg )

    def _growCwnd( self ):
        """Grow the congestion window for one new ACK and log the change."""
        cwndOld = self.data.cwnd
        if cwndOld < self.data.ssthresh:
            # Slow start: one full segment per ACK.
            self.data.cwnd = cwndOld + 1.0
            print("Slow-start ack, cwnd %s -> %s" % (cwndOld, self.data.cwnd))
        else:
            # Congestion avoidance: 1/cwnd of a segment per ACK.
            self.data.cwnd = cwndOld + 1.0 / cwndOld
            print("Cong-avoid ack, cwnd %s -> %s" % (cwndOld, self.data.cwnd))
        self.dumpCwndLog()

    def _scanOutOfOrderAcks( self, host, port, msg ):
        """Update the per-packet out-of-order ACK counts for ACK C{msg}.

        The tracked packet with the acknowledged seqno stops being tracked.
        Every other tracked packet sent earlier (by timestamp) than the one
        C{msg} refers to gets one more out-of-order ACK; on the third it is
        dropped from tracking and, if enabled, fast-retransmitted.
        """
        acks = self.data.outstandingAcks
        # Iterate over a snapshot: entries are deleted / re-inserted below.
        for sent in list( acks ):
            if sent.seqno == msg.seqno:
                # Checked first: this packet is acknowledged however its
                # timestamp compares (presumably the ACK may echo a
                # retransmitted copy), so its own ACK must never be
                # counted as out-of-order.
                del acks[sent]
            elif sent.timestamp < msg.timestamp:
                acks[sent] += 1
                if acks[sent] >= 3:
                    print("Third out-of-order ack for seqno %s" % (sent.seqno,))
                    del acks[sent]
                    if self.fast_rxmit:
                        self._fastRetransmit( host, port, sent )

    def _fastRetransmit( self, host, port, sent ):
        """Resend C{sent} now and, at most once per RTT, halve cwnd."""
        sent.timestamp = time()
        self.network.cancelTimeout( host, port, sent.seqno )
        self.network.sendMessage( host, port, sent,
                                  self.calcTimeout( host, port ) )
        # Track the retransmitted copy from a zero count again.
        self.data.outstandingAcks[sent] = 0
        now = time()
        if self.data.lastCAMD + self.data.rtt < now:
            self.data.lastCAMD = now
            cwndOld = self.data.cwnd
            # Never drop below one segment: calcWindow truncates cwnd with
            # int(), so a cwnd under 1.0 would give a window of 0.
            self.data.cwnd = max( 1.0, cwndOld / 2.0 )
            print("Fast retransmit, cwnd %s -> %s" % (cwndOld, self.data.cwnd))
            self.dumpCwndLog()

    def handleTimeout( self, host, port, msg ):
        """A message has timed out.

        During the data phase, set ssthresh to half the current window and
        collapse cwnd to 1.  The actual handling of the timed-out message
        (e.g. resending it) is left to
        L{server.ReliableServer.handleTimeout}, which is called in every
        state.

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """
        if self.state == self.STATE_DATA:
            # No once-per-RTT guard is applied here, so back-to-back
            # timeouts each recompute ssthresh from the already-collapsed
            # cwnd (e.g. 1.0 -> ssthresh 0.5).
            cwndOld = self.data.cwnd
            ssthreshOld = self.data.ssthresh
            self.data.ssthresh = cwndOld / 2.0
            self.data.cwnd = 1.0
            print("TIMEOUT, cwnd %s -> %s ssthresh %s -> %s"
                  % (cwndOld, self.data.cwnd, ssthreshOld, self.data.ssthresh))
            self.dumpCwndLog()

        ReliableServer.handleTimeout( self, host, port, msg )

    def calcWindow( self, host, port ):
        """Calculate the current window size for a client

        Overrides the fixed window of L{server.ReliableServer} with the
        shared congestion window from the L{AIMDData} object.  The host
        and port arguments are not used.

        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated window size (cwnd truncated to an integer)

        """
        return int( self.data.cwnd )

    def updateRTT( self, host, port, rtt_est ):
        """Update the round trip time estimate to the client

        Uses an exponentially weighted moving average with weight
        C{self.alpha} on the previous estimate, stored in the shared
        L{AIMDData} object.

        @param host: The client host
        @param port: The UDP port of the client
        @param rtt_est: A new measurement of the round trip time

        """
        self.data.rtt = self.alpha * self.data.rtt + (1 - self.alpha) * rtt_est

    def calcTimeout( self, host, port ):
        """Calculate a timeout for a message to a client

        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated timeout (in seconds): C{beta} times the
                 shared RTT estimate

        """
        return self.beta * self.data.rtt

    def openCwndLog( self ):
        """Open the congestion-window log file, unless it is already open.

        File name, with C{CWNDPREFIX} optional (empty if unset):
          - if C{CWNDSCREW} is set:
            C{<CWNDPREFIX>loss<LL>-rate<RR>-queue<QQ>-delay<D.D>}, built
            from the network's loss rate, queue rate, queue size and delay;
          - otherwise: C{<CWNDPREFIX>cwnd<server id>}.
        The file is opened in 'w' mode, so an existing file is truncated.
        """
        if self.data.cwndlog != 0:
            return

        prefix = os.environ.get( "CWNDPREFIX", "" )
        if "CWNDSCREW" in os.environ:
            ratedelay = self.network.queue.ratedelay
            rate = 0 if ratedelay == 0 else 1.0 / ratedelay
            filename = "%sloss%02d-rate%02d-queue%02d-delay%3.1f" % \
                       (prefix, int(self.network.loss * 100), rate,
                        self.network.queue.queue_size, self.network.delay)
        else:
            filename = prefix + 'cwnd' + str( self.getID() )
        self.data.cwndlog = open( filename, 'w' )
        print("opened file " + filename)

    def dumpCwndLog( self ):
        """Append one line to the cwnd log, opening the log first if needed.

        Each line holds: time() minus C{self.data.firstPacketTime}, the
        window from calcWindow, and the number of packets sent through
        sendData so far.
        """
        self.openCwndLog()
        # This class's calcWindow ignores host/port, so placeholders are
        # passed here.
        elapsed = time() - self.data.firstPacketTime
        self.data.cwndlog.write( str( elapsed ) + ' ' +
                                 str( self.calcWindow( None, 0 ) ) + ' ' +
                                 str( self.data.packetsSent ) + '\n' )
        print("Dumping onto cwnd log")
        self.data.cwndlog.flush()

    def closeCwndLog( self ):
        """Close (and thereby flush) the congestion-window log, if open.

        The closed handle is left in C{self.data.cwndlog}.  Resetting it to
        0 would make the next dumpCwndLog() reopen the file in 'w' mode and
        truncate it.  As a result, a dump after closing fails on the closed
        file instead.
        """
        if self.data.cwndlog:
            self.data.cwndlog.close()
    
