#!/usr/bin/env python

from msg import *
from server import *
from network import Network
from time import time
from heapq import heappush, heappop, heapify
import os, sys

# class ACKSet:
#     """Efficiently represents a set of ACKs"""

#     def __init__(self):
#         self.__ranges = []

#     def add(self, seqno):
#         """Add a received seqno to the set"""

#         for i, r in enumerate(self.__ranges):
#             if r[0] <= seqno <= r[1]:
#                 print "Warning: Seqno %d already in set" % seqno
#                 return
#             if seqno < r[0]:
#                 high = low = seqno
#                 # Merge with next range?
#                 if seqno + 1 == r[0]:
#                     self.__ranges.remove(r)
#                     high = r[1]
#                 # Merge with previous range?
#                 if i > 0:
#                     rpre = self.__ranges[i-1]
#                     if rpre[1] + 1 == seqno:
#                         self.__ranges.remove(rpre)
#                         i -= 1
#                     low = rpre[0]
#                 # Insert
#                 self.__ranges.insert(i, (low, high))
#                 return
#         else:
#             # Append
#             high = low = seqno
#             if len(self.__ranges) > 0:
#                 last = self.__ranges[-1]
#                 if last[1] + 1 == seqno:
#                     self.__ranges.remove(last)
#                     low = last[0]
#             self.__ranges.append((low, high))

#     def __contains__(self, seqno):
#         for r in self.__ranges:
#             if r[0] <= seqno <= r[1]:
#                 return True
#             if seqno < r[0]:
#                 break
#         return False

#     def missed(self):
#         """Returns a list of missed ACKs

#         These are ACKs that should have been acknowledged.  A missed ACK
#         is an ACK that is not in the set, but a later ACK is.
#         """

#         ret = []
#         expecting = 0
#         for r in self.__ranges:
#             ret.extend(range(expecting, r[0]))
#             expecting = r[1] + 1
#         return ret

#     def _test(cls):
#         s = cls()
#         assert s.__ranges == []
#         assert s.missed() == []
#         s.add(3)
#         assert s.__ranges == [(3,3)]
#         assert s.missed() == [0,1,2]
#         s.add(1)
#         assert s.__ranges == [(1,1),(3,3)]
#         assert s.missed() == [0,2]
#         s.add(2)
#         assert s.__ranges == [(1,3)]
#         assert s.missed() == [0]
#         s.add(4)
#         assert s.__ranges == [(1,4)]
#         assert s.missed() == [0]
#         s.add(6)
#         assert s.__ranges == [(1,4),(6,6)]
#         assert s.missed() == [0,5]
#         s.add(0)
#         assert s.__ranges == [(0,4),(6,6)]
#         assert s.missed() == [5]
#     _test = classmethod(_test)

class OutstandingMessage(object):
    """Bookkeeping for a single sent-but-unacknowledged data message.

    Instances are stored in C{AIMDReliableServer.unacknowledged}, keyed by
    the message's sequence number.

    @ivar msg: The message that was sent.
    @ivar missedcount: How many ACKs have arrived for messages sent after
        this one while this one remained unacknowledged.  The server treats
        a count of 3 or more as a signal to fast-retransmit.
    @ivar retransmitted: Whether this message has been retransmitted.
        Initialized to False so the slot is always readable.
    """

    __slots__ = ["msg", "missedcount", "retransmitted"]

    def __init__(self, msg):
        self.msg = msg
        self.missedcount = 0
        # Every declared slot is initialized; an unset slot would raise
        # AttributeError on first read.
        self.retransmitted = False

class AIMDData(object):
    """Container for the congestion-control state of an AIMD transport.

    This container holds every piece of data that
    L{ccserver.AIMDReliableServer} uses to manage its congestion window.

    Keeping this state outside the server lets several servers, each using
    a different network layer, share one congestion window.  This emulates
    a transport protocol that does not know its packets are being split
    down different paths somewhere in the network.

    @cvar SSTHRESH_MAX: The initial default value for ssthresh.
    @ivar cwnd: Congestion window (a float, so additive increase can add
        fractions; the server truncates it to an int for the window size).
    @ivar lastcwndhalftime: time() of the most recent cwnd halving, or 0.
    @ivar ssthresh: Slow-start threshold.
    @ivar rtt: Smoothed round-trip-time estimate.
    @ivar outstanding: Count of sent, not yet acknowledged messages.
    @ivar cwndlog: Open congestion-window log file, or 0 when none is open.
    """

    SSTHRESH_MAX = 4
    __slots__ = ["cwnd", "lastcwndhalftime", "ssthresh", "rtt",
                 # Fields added by the staff code
                 "outstanding", "cwndlog"]

    def __init__( self ):
        # Initialize the log slot so it is readable before any server has
        # reset it.  resetDataFields deliberately leaves it alone; the
        # server manages the log's lifetime.
        self.cwndlog = 0
        self.resetDataFields()

    def resetDataFields( self ):
        """Reset the congestion-control fields to their initial values.

        These are the fields used by L{ccserver.AIMDReliableServer}.
        C{cwndlog} is not touched here.
        """
        self.cwnd = 1.0
        self.lastcwndhalftime = 0
        self.ssthresh = self.SSTHRESH_MAX
        self.rtt = 1.5
        self.outstanding = 0

class AIMDReliableServer(ReliableServer):
    """A L{server.ReliableServer} with AIMD (TCP-style) congestion control.

    On top of the parent's reliable transfer, this class implements the
    following TCP-style provisions in the context of RMTP:
      1. Slow start: while cwnd < ssthresh, every ACK grows cwnd by one.
      2. Additive increase / multiplicative decrease: at or above ssthresh,
         every ACK grows cwnd by 1/cwnd.  cwnd is halved (at most once per
         RTT, never below 1) when fast retransmits are sent, and reset to
         1 on a timeout.
      3. Fast retransmit: on the next sendData call, a message is resent
         once three ACKs for later-sent messages have arrived while it is
         still unacknowledged.

    The behavior is changed by overriding parent methods such as
    L{handleAck <server.ReliableServer.handleAck>} and
    L{handleTimeout <server.ReliableServer.handleTimeout>}.

    All congestion-control state lives in the L{AIMDData} object
    (C{self.data}), so it can be shared across multiple connections.  The
    parent's L{ReliableServer.outstanding} count is mirrored in
    C{self.data.outstanding}; L{handleAck} decrements both.
    """

    def __init__( self, data, network, window=4, chunk_size=1000,
                  fast_rxmit=True ):
        """Initialize server data members and register network functions.

        @param data: An L{AIMDData} object holding the congestion state
        @param network: A configured L{network.Network} object
        @param window: Passed to the parent, but effectively unused:
                       L{calcWindow} derives the window from cwnd
        @param chunk_size: The size of the data chunks sent to the client.
                           (default: 1000)
        @param fast_rxmit: Whether or not to do fast retransmissions

        """
        # These attributes are set before the parent constructor runs,
        # which presumably calls the overridden resetFields (needs self.data).
        self.data = data
        self.fast_rxmit = fast_rxmit
        # Messages sent by sendData and not yet ACKed: seqno ->
        # OutstandingMessage.
        # NOTE(review): unverified whether this is correct for multi-homed
        # (shared AIMDData) setups.
        self.unacknowledged = {}
        # Number of data messages sent, and time() of the first sendData
        # call (0 until then).  Both are used by calcRate.
        self.sent = 0
        self.start = 0
        ReliableServer.__init__( self, network, window, chunk_size )

    def resetFields( self ):
        """Reset data members to their default values.

        Call this on initialization and between client connections.  Any
        congestion-window log left open by a previous connection is closed
        before it is forgotten, so the file handle is not leaked.
        """
        ReliableServer.resetFields(self)
        self.data.resetDataFields()
        # The slot may still be unset the first time this runs.
        previouslog = getattr(self.data, "cwndlog", 0)
        if previouslog:
            previouslog.close()
        # 0 means "no log open"; openCwndLog opens one on demand.
        self.data.cwndlog = 0

    def sendData( self, host, port, msg ):
        """Send as many data messages as the congestion window allows.

        The method works in two steps:
          1. It sends any pending fast retransmits, in seqno order, and
             halves cwnd if it has not been halved within the last RTT.
          2. It delegates to L{server.ReliableServer.sendData} to send new
             file data.  While the parent runs, C{self.network.sendMessage}
             is temporarily wrapped so that every message it sends is
             recorded in C{self.unacknowledged} and counted in C{self.sent}.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK or type ACK

        """
        # Remember when sending began; calcRate measures from here.
        if self.start == 0:
            self.start = time()

        # Step 1: fast retransmits of messages that three later-sent
        # messages have overtaken.
        rexmits = []
        if self.fast_rxmit:
            rexmits = [unack for unack in self.unacknowledged.values()
                       if unack.missedcount >= 3]
            rexmits.sort(key=lambda unack: unack.msg.seqno)

        # Multiplicative decrease, at most once per RTT.
        if rexmits:
            now = time()
            if now - self.data.lastcwndhalftime >= self.data.rtt:
                # Keep cwnd >= 1: calcWindow truncates cwnd to an int, so
                # any value below 1 would shrink the window to 0 packets.
                newcwnd = max(self.data.cwnd / 2, 1.0)
                print("Halving cwnd due to retransmits (%g -> %g)" %
                      (self.data.cwnd, newcwnd))
                self.data.cwnd = newcwnd
                self.data.lastcwndhalftime = now
                self.dumpCwndLog()
            else:
                print("Not halving cwnd (too early)")

        for unack in rexmits:
            print("Fast retransmitting message %d" % unack.msg.seqno)
            # Replace the pending timeout with a fresh one for the resend.
            self.network.cancelTimeout(host, port, unack.msg.seqno)
            unack.msg.timestamp = time()
            unack.missedcount = 0
            self.network.sendMessage(host, port, unack.msg,
                                     self.calcTimeout(host, port))

        # Step 2: send new file data.  The parent sends through
        # self.network.sendMessage, so that method is wrapped for the
        # duration of the call to capture the messages fast retransmit
        # will need.  The original method is always restored.
        realSendMessage = self.network.sendMessage
        def trackingSendMessage(host, port, msg, timeout=0):
            self.unacknowledged[msg.seqno] = OutstandingMessage(msg)
            self.sent += 1
            return realSendMessage(host, port, msg, timeout)
        self.network.sendMessage = trackingSendMessage
        try:
            ReliableServer.sendData(self, host, port, msg)
            self.showProgress(port)
        finally:
            self.network.sendMessage = realSendMessage

    def handleAck( self, host, port, msg ):
        """Receive a message of type ACK from the client.

        This method:
          - grows cwnd (by 1 in slow start, by 1/cwnd in congestion
            avoidance);
          - updates the fast-retransmit bookkeeping;
          - while in STATE_DATA, updates the RTT estimate, cancels the
            message's timeout and, unless the ACK is a duplicate, releases
            one outstanding slot and sends more data.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type ACK

        """
        # NOTE(review): cwnd grows on every ACK, including duplicate and
        # unknown ones.
        if self.data.cwnd < self.data.ssthresh:
            self.data.cwnd += 1
        else:
            self.data.cwnd += 1.0 / self.data.cwnd
        self.dumpCwndLog()

        # Only messages sent through sendData are tracked, so an ACK may be
        # for some other message or be a duplicate.
        entry = self.unacknowledged.pop(msg.seqno, None)
        if entry is not None:
            ackedts = entry.msg.timestamp
            # Each still-unacknowledged message sent before the one just
            # ACKed has been overtaken once more.
            for unack in self.unacknowledged.values():
                if unack.msg.timestamp < ackedts:
                    unack.missedcount += 1
        else:
            print("%d: Got unknown ACK %d" % (self.getID(), msg.seqno))

        # ReliableServer.handleAck is deliberately not called.  It
        # decrements self.outstanding and then calls self.sendData, which
        # MultiHomedReliableServer overrides in a way that overwrites
        # self.outstanding.  The parent's logic is therefore duplicated
        # here.  It decrements both self.outstanding and
        # self.data.outstanding, so it works both multi-homed and
        # single-homed.  In the single-homed case self.data.outstanding
        # drifts from the real count and is only printed by showProgress.
        if self.state == self.STATE_DATA:
            rtt_est = time() - msg.timestamp   # in seconds
            self.updateRTT( host, port, rtt_est )

            print("Got ACK for %s" % msg.seqno)
            # A None result from cancelTimeout is treated as a duplicate
            # ACK; only a non-duplicate advances the window.
            cancel = self.network.cancelTimeout( host, port, msg.seqno )
            if cancel is not None:
                self.data.outstanding -= 1
                self.outstanding -= 1
                self.sendData( host, port, msg )

    def handleTimeout( self, host, port, msg ):
        """React to a sent message that was never acknowledged.

        This halves ssthresh (floored at 2), resets cwnd to 1 to re-enter
        slow start, logs the new window, and lets
        L{server.ReliableServer.handleTimeout} do the actual recovery
        (e.g. resending).

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """
        print("%d: Entering slow start due to timeout of %d" %
              (self.getID(), msg.seqno))
        # RFC 5681: ssthresh must not drop below 2 segments.  Otherwise a
        # timeout with cwnd < 4 would set it to 0 or 1 and, with cwnd = 1,
        # skip slow start entirely.
        self.data.ssthresh = max(int(self.data.cwnd / 2), 2)
        self.data.cwnd = 1.0
        self.dumpCwndLog()

        ReliableServer.handleTimeout(self, host, port, msg)

    def calcWindow( self, host, port ):
        """Return the current window size: cwnd truncated to an int.

        The window comes from the shared L{AIMDData} rather than the fixed
        window of L{server.ReliableServer}.

        @param host: The client host (unused)
        @param port: The UDP port of the client (unused)
        @return: the calculated window size (as an integer)

        """
        return int(self.data.cwnd)

    def updateRTT( self, host, port, rtt_est ):
        """Fold a new RTT sample into the shared smoothed estimate.

        The update is an exponentially weighted moving average using the
        parent's C{self.alpha}:
        C{rtt = alpha * rtt + (1 - alpha) * rtt_est}.

        @param host: The client host (unused)
        @param port: The UDP port of the client (unused)
        @param rtt_est: A new measurement of the round trip time

        """
        self.data.rtt = self.alpha*self.data.rtt + (1-self.alpha)*rtt_est

    def calcTimeout( self, host, port ):
        """Calculate a timeout for a message to a client.

        The timeout is the parent's C{self.beta} times the shared RTT
        estimate.

        @param host: The client host (unused)
        @param port: The UDP port of the client (unused)
        @return: the calculated timeout (in seconds)

        """
        return self.beta*self.data.rtt

    def calcRate(self):
        """Return the average send rate, in messages per second.

        The rate is measured from the first sendData call.  It is 0.0
        before sending starts, or when no measurable time has elapsed
        (coarse clocks can return the same time() twice).
        """
        if self.start == 0:
            return 0.0
        elapsed = time() - self.start
        if elapsed <= 0:
            return 0.0
        return self.sent / elapsed

    def openCwndLog(self):
        """Open the congestion-window log file, unless one is already open.

        The file name starts with $CWNDPREFIX plus "-" when that variable
        is non-empty.  If $CWNDSCREW is set, the name encodes the network's
        loss, rate, queue size and delay; otherwise it is "cwnd<ID>".
        """
        if not self.data.cwndlog:
            prefix = os.environ.get("CWNDPREFIX", "")
            if prefix:
                prefix += "-"
            if "CWNDSCREW" in os.environ:
                ratedelay = self.network.queue.ratedelay
                if ratedelay == 0:
                    rate = 0
                else:
                    rate = 1.0/ratedelay
                filename = "%sloss%02d-rate%02d-queue%02d-delay%3.1f" % \
                           (prefix, int(self.network.loss*100), rate,
                            self.network.queue.queue_size, self.network.delay)
            else:
                filename = "%scwnd%d" % (prefix, self.getID())
            self.data.cwndlog = open(filename, 'w')
            print("opened file " + filename)

    def dumpCwndLog(self):
        """Append one line to the congestion-window log, opening it if needed.

        Each line is: time, cwnd, ssthresh, send rate, RTT estimate.  The
        file is flushed after every line.
        """
        self.openCwndLog()
        logfile = self.data.cwndlog
        logfile.write("%s %g %d %g %g\n" %
                      (time(), self.data.cwnd, self.data.ssthresh,
                       self.calcRate(), self.data.rtt))
        logfile.flush()

    def closeCwndLog(self):
        """Close (and thereby flush) the congestion-window log, if open.

        NOTE(review): the closed file object is intentionally left in
        self.data.cwndlog, so a later dumpCwndLog would fail on a closed
        file.  Resetting it to 0 instead would make the next dump reopen
        the file with mode 'w' and truncate it.  resetFields does reset it.
        """
        if self.data.cwndlog:
            self.data.cwndlog.close()

    def showProgress(self, port):
        """Print a debugging snapshot of the sender state (lines start '!!!').

        The snapshot contains:
          - a map of the unacknowledged range.  It starts at the lowest
            unacknowledged seqno rounded down to a multiple of 8; '+' marks
            an unacknowledged seqno and '=' any other;
          - cwnd, ssthresh, RTT and send rate;
          - both outstanding counters.

        Diagnostics only: an error while formatting is reported and
        otherwise ignored.

        @param port: Unused.
        """
        try:
            seqnos = sorted(self.unacknowledged.keys())
            if seqnos:
                pending = set(seqnos)
                marks = []
                first = int(seqnos[0] / 8) * 8
                for seqno in range(first, seqnos[-1] + 1):
                    if seqno in pending:
                        marks.append("+")
                    else:
                        marks.append("=")
                print("!!! %d: %s (%s)" % (self.getID(), "".join(marks),
                                           " ".join(map(str, seqnos))))
            else:
                print("!!! %d: No outstanding" % self.getID())
            print("!!! %d: cwnd: %3g  ssthresh: %1d  rtt: %3g"
                  "  rate: %3g p/s" %
                  (self.getID(), self.data.cwnd, self.data.ssthresh,
                   self.data.rtt, self.calcRate()))
            print("!!! %d: oustanding: %d, %d" % (self.getID(),
                                                  self.outstanding,
                                                  self.data.outstanding))
        except Exception:
            # Diagnostics must never interrupt the transfer, but do not
            # swallow KeyboardInterrupt/SystemExit.
            print("!!! showProgress exception")

