#!/usr/bin/env python

from socket import *
from vserver import *
from msg import *
from optparse import OptionParser
from network import Network
from random import random
from time import time
import sys

class ReliableServer(VirtualServer):
    """A class implementing the server side of the reliable 6.829 message
    transfer protocol.

    This class is capable of serving one L{client.ReliableClient} at a time.
    From the client, it receives a request for a file during the SYN handshake,
    send the data to the client one chunk at a time, and then closes the
    connection.

    The server keeps a weighted-average estimate of the round trip time to
    the client.  To keep this up-to-date, it should set the timestamp
    of each message it sends with the current time (i.e., call time()).
    When it receives a message, the server should calculate the round-trip
    time using the timestamp of the new message (which is simply echoed by the
    client).  The server should call L{updateRTT} with this new measurement.

    The server should set timeouts of messages using the L{calcTimeout}
    method.

    @cvar STATE_SENT_SYNACK:  Corresponds to a server state in which the
                              server sent the SYNACK, but no data
    @cvar STATE_DATA: Corresponds to a server state in which the server is
                      actively sending data packets
    @cvar STATE_SENT_FIN: Corresponds to a server state in which the server is
                          finished sending data, and is trying to close the
                          connection

    @ivar network: The server's L{network.Network} object
    @ivar window: How many messages can be outstanding to the client at
                  any one time
    @ivar chunk_size: The size of the data chunks sent to the client.
    @ivar state: The current server state (e.g., STATE_SENT_SYNACK)
    @ivar filename: The name of the file requested by the client
    @ivar fd: The read-only file descriptor corresponding to the filename
    @ivar outstanding: The number of outstanding (un-ACKed) messages
    @ivar seqno: The sequence number of the next message to the client
    @ivar done: True if the server has sent all data chunks to the client
    @ivar rtt: The estimated round trip time to the client
    @ivar alpha: How much to weight the current rtt estimate
    @ivar beta: Timeout factor
    """

    STATE_SENT_SYNACK = 1
    STATE_DATA = 2
    STATE_SENT_FIN = 3
    

    #####################################################################
    ##  Initialization
    #####################################################################
    def __init__( self, network, window=4, chunk_size=1000 ):
        """Initialize server data members, and register network functions.

        @param network: A configured L{network.Network} object
        @param window: How many messages can be outstanding to the client at
                       any one time (default: 4)
        @param chunk_size: The size of the data chunks sent to the client.
                           (default: 1000)

        """
        self.window = window
        self.chunk_size = chunk_size
        self.network = network
        self.resetFields()

        self.network.registerReceiveFunc( self.receiveMessage )
        self.network.registerTimeoutFunc( self.handleTimeout )
        self.state = 0
        
    def resetFields( self ):
        """Reset data members to their default values.

        Call this on initialization, and between client connections.
        """
        self.filename = ""
        self.fd = 0
        self.outstanding = 0
        self.seqno = 0
        self.done = False
        self.state = 0
        
        # rtt/timeout estimates default values
        self.rtt = 1.5
        self.alpha = 0.9
        self.beta = 4

    #####################################################################
    ##  User-defined functions
    #####################################################################

    def handleSyn( self, host, port, msg ):
        """Receive a mesage of type SYN from a client

        The message contains the name of a file to open (in the
        variable msg.data).  Open the file for reading; if the file
        opens successfully, respond to the client with a SYNACK.
        Otherwise, respond with a SYNRETRY.

        The server also picks a (possibly random) sequence number for the
        SYNACK packet and increments this sequence number for every subsequent
        data packet.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type SYN

        """

        if( not (self.state == 0 or self.state == self.STATE_SENT_SYNACK) ):
            return
        
        self.setSeqNo( host, port )

        # The file to transport
        filename = msg.data
        msgtype = TYPE_SYNACK
        self.state = self.STATE_SENT_SYNACK
        
        if( not self.openFile( host, port, filename ) ):
            msgtype = TYPE_SYNRETRY
                
        synack = ReliableMessage( msgtype, self.nextSeqNo( host, port ),
                                  time(), self.chunk_size )
        # don't send reliably, since it's up to the client to retransmit
        # the SYN if the handshake doesn't work
        self.network.sendMessage( host, port, synack )        
        if(msgtype == TYPE_SYNRETRY):
            self.resetFields()
            
    def handleSynAckAck( self, host, port, msg ):
        """Receive a mesage of type SYNACKACK from the client

        The client is confirming that it received the SYNACK.  Thus, the server
        updates the RTT estimate and begins sending data packets to the client.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK

        """
        
        if(self.state == self.STATE_SENT_SYNACK):
            rtt_est = time() - msg.timestamp   # in seconds
            self.updateRTT( host, port, rtt_est )
            self.sendData( host, port, msg )

    def handleAck( self, host, port, msg ):
        """Receive a mesage of type ACK from the client

        The client is acknowledging that it received a data message.  Update
        the relevant data members (like RTT estimate), cancel any timeouts,
        and send more data.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type ACK

        """
        if( self.state == self.STATE_DATA ):
            rtt_est = time() - msg.timestamp   # in seconds
            self.updateRTT( host, port, rtt_est )

            # check if dup ack and advance window only if it is not dup ack
            print "Got ACK for",msg.seqno
            cancel = self.network.cancelTimeout( host, port, msg.seqno )
            if( cancel != None ): # if not dup ack
                self.outstanding = self.outstanding-1
                self.sendData( host, port, msg )

    def sendData( self, host, port, msg ):
        """Send data messages to the client

        This method sends as many data messages to the client as possible,
        while maintaining the fixed L{window} size.  It sends the messages
        reliably.  Close the file once all data has been sent.

        If all the file's data chunks have been sent and acknowledged,
        initialize the closing phase of the connection by sending the FIN
        packet.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK or type ACK

        """

        # Keep on sending the file
        while( self.outstanding < self.calcWindow( host, port )
               and not self.done ):
            data = self.nextDataChunk( host, port )
            if( data != "" ):
                self.state = self.STATE_DATA
                datamsg = ReliableMessage( TYPE_DATA,
                                           self.nextSeqNo( host, port ),
                                           time(), data )
                self.network.sendMessage( host, port, datamsg,
                                          self.calcTimeout( host, port ) )
                print "Sent data for", datamsg.seqno, "on server", self.getID()
                self.outstanding = self.outstanding+1
                
            else:
                # end of file
                self.done = True
                print "Closing file"
                self.closeFile( host, port )

        # if we've got all the acks back, send FIN
        if( self.canSendFin( host, port ) ):
            fin = ReliableMessage( TYPE_FIN, self.nextSeqNo( host, port ) )
            self.state = self.STATE_SENT_FIN
            self.network.sendMessage( host, port, fin,
                                      self.calcTimeout( host, port ) )
            print "Sent FIN!"

    def handleFinAck( self, host, port, msg ):
        """Receive a mesage of type FINACK from the client

        The client is acknowledging that it received a FIN.  Thus, the
        connection is officially closed.  Reset all necessary state

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type FINACK

        """

        if(self.state == self.STATE_SENT_FIN):
            # done! delete connection state
            print "File download successful. Connection to client closed."
            self.network.cancelTimeout( host, port, msg.seqno )
            self.resetFields()

    def handleTimeout( self, host, port, msg ):
        """A message has timed out.

        A sent message has not yet been acknowledged, or else someone
        forgot to call L{cancelTimeout <network.Network.cancelTimeout>}.
        Take the appropriate action (eg: resending the message). 

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """

        print "This message timed out:", msg.seqno
        msg.timestamp = time()
        self.network.sendMessage( host, port, msg,
                                  self.calcTimeout( host, port ) )

    def updateRTT( self, host, port, rtt_est ):
        """Update the round trip time estimate to the client
        
        @param host: The client host
        @param port: The UDP port of the client
        @param rtt_est: A new measurement of the round trip time
        
        """

        self.rtt = self.alpha*self.rtt + (1-self.alpha)*rtt_est

    def calcTimeout( self, host, port ):
        """Calculate a timeout for a message to a client
        
        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated timeout (in seconds)

        """

        return self.beta*self.rtt

    def calcWindow( self, host, port ):
        """Calculate the current window size for a client
        
        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated window size

        """
        
        return self.window

    def openFile( self, host, port, filename ):
        """Open the given file for the given host.

        @param host: The client host
        @param port: The UDP port of the client
        @param filename: The file to open
        @return: True if the file was successfully opened.
        """
        
        try:
            self.fd = open( filename, "r" )
            return True
        except:
            print "Error: File doesn't exist. Sending retry to client"
            return False

    def closeFile( self, host, port ):
        """Close the file associated with the host

        @param host: The client host
        @param port: The UDP port of the client
        """
        self.fd.close()

    def nextDataChunk( self, host, port ):
        """Read the next chunk of data for this host

        @param host: The client host
        @param port: The UDP port of the client
        """
        return self.fd.read( self.chunk_size )

    def setSeqNo( self, host, port ):
        """Make a new starting sequence number for this host

        @param host: The client host
        @param port: The UDP port of the client
        """
        
        # Pick a random sequence number
        self.seqno = int( random() * 16384 )

    def getCurrSeqNo(self,host,port):
        """What is this host's current sequence number?

        @param host: The client host
        @param port: The UDP port of the client
        @return: The sequence number in question
        """
        return self.seqno

    def nextSeqNo( self, host, port ):
        """Increment and return this host's sequence number

        @param host: The client host
        @param port: The UDP port of the client
        @return: The sequence number in question
        """
        currnum = self.seqno
        self.seqno = self.seqno+1
        return currnum

    def canSendFin( self, host, port ):
        """Is it all right to send a FIN packet to this host?

        @param host: The client host
        @param port: The UDP port of the client
        @return: True if it is all right
        """
        return (self.done and self.outstanding == 0)

    def getID( self ):
        """Returns a unique ID for this server.  Currently based on the
        server's listening port.

        @return: the unique ID
        """
       	return (self.network.socket.getsockname())[1]
    
    #####################################################################
    ##  Message processing
    #####################################################################

    def receiveMessage( self, host, port, msg ):
        """A L{ReliableMessage <msg.ReliableMessage>} was received from
        a client. The corresponding handler functions are called depending
        on the type of the message. 

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>}

        """
        
        print "Received message:", msg.seqno, " of type ",msg.type, \
              " on server ", self.getID()
        
        if( msg.type == TYPE_SYN ):
            self.handleSyn( host, port, msg )

        elif( msg.type == TYPE_SYNACKACK ):
            self.handleSynAckAck( host, port, msg )
            
        elif( msg.type == TYPE_ACK ):
            self.handleAck( host, port, msg )

        elif( msg.type == TYPE_FINACK ):
            self.handleFinAck( host, port, msg )

        else:
            print "Unexpected message type: " + msg.type

#####################################################################
##  Parse options / start network
#####################################################################

def makeServerParser():
    """Build the command-line option parser for a server.

    Options (destination attribute and default in parentheses):
      -w/--window (window, 4)       fixed window size in packets
      -l/--loss   (loss, 0)         loss rate of the client/server channel
      -p/--port   (port, 6829)      port to listen on
      -s/--size   (size, 1000)      data chunk size
      -d/--debug  (debug, False)    print network debug info
      -n/--name   (host, hostname)  host name to bind to

    @return: a configured C{optparse.OptionParser}
    """
    defaults = {
        "window": 4,
        "loss": 0,
        "port": 6829,
        "size": 1000,
        "debug": False,
        "host": gethostname(),
    }

    parser = OptionParser()
    add = parser.add_option
    add("-w", "--window", dest="window", type="int", action="store",
        default=defaults["window"], help="fixed window size (in packets)")
    add("-l", "--loss", dest="loss", type="float", action="store",
        default=defaults["loss"],
        help="loss rate of channel between client and server")
    add("-p", "--port", dest="port", type="int", action="store",
        default=defaults["port"], help="port to listen on")
    add("-s", "--size", dest="size", type="int", action="store",
        default=defaults["size"], help="data chunk size")
    add("-d", "--debug", dest="debug", action="store_true",
        default=defaults["debug"], help="print network debug info")
    add("-n", "--name", dest="host", type="string", action="store",
        default=defaults["host"], help="host name to bind to")
    return parser

def main(argv=None):
    """Run a server from the command line.

    Parse options, and enter the L{Network <network.Network>} event loop.
    """
    
    if argv is None:
        argv = sys.argv

    # Get command line options
    parser = makeServerParser()
    (options, args) = parser.parse_args()

    window = options.window
    loss = options.loss
    port = options.port
    chunk_size = options.size
    host = options.host
    debug = options.debug

    # create a network
    network = Network( host, loss, port, debug ) 

    server = ReliableServer( network, window, chunk_size )

    try:
        Network.loop()
    except KeyboardInterrupt:
        print "Exiting on command.  Goodbye."

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
