#!/usr/bin/env python

from socket import *
from msg import *
from optparse import OptionParser
from network import Network
from random import random
from time import time
import sys

class ReliableServer:
    """A class implementing the server side of the reliable 6.829 message
    transfer protocol.

    This class is capable of serving one L{client.ReliableClient} at a time.
    From the client, it receives a request for a file during the SYN handshake,
    sends the data to the client one chunk at a time, and then closes the
    connection.

    The server keeps a weighted-average estimate of the round trip time to
    the client.  To keep this up-to-date, it sets the timestamp
    of each message it sends with the current time (i.e., calls time()).
    When it receives a message, the server calculates a round-trip time
    sample from the value echoed by the client and passes it to
    L{updateRTT}.

    The server sets timeouts of messages using the L{calcTimeout}
    method.

    @cvar STATE_SENT_SYNACK:  Corresponds to a server state in which the
                              server sent the SYNACK, but no data
    @cvar STATE_DATA: Corresponds to a server state in which the server is
                      actively sending data packets
    @cvar STATE_EOF: Corresponds to a server state in which the last data
                     chunk has been sent, but not every chunk has been
                     acknowledged yet
    @cvar STATE_SENT_FIN: Corresponds to a server state in which the server is
                          finished sending data, and is trying to close the
                          connection

    @ivar network: The server's L{network.Network} object
    @ivar window: How many messages can be outstanding to the client at
                  any one time
    @ivar chunk_size: The size of the data chunks sent to the client.
    @ivar state: The current server state (e.g., STATE_SENT_SYNACK); 0 while
                 no connection is in progress
    @ivar filename: The name of the file requested by the client
    @ivar fd: The read-only file object corresponding to the filename, or 0
              when no file is open
    @ivar outstanding: The number of outstanding (un-ACKed) messages
    @ivar seqno: The sequence number of the next message to the client
    @ivar seqno_start: The sequence number used for the SYNACK; data chunk
                       offsets are computed relative to it
    @ivar done: Initialised to 0 by L{resetFields}; not otherwise updated
                by this class
    @ivar rtt: The estimated round trip time to the client
    @ivar alpha: How much to weight the current rtt estimate
    @ivar beta: Timeout factor
    """

    STATE_SENT_SYNACK = 1
    STATE_DATA = 2
    STATE_EOF = 3
    STATE_SENT_FIN = 4

    #####################################################################
    ##  Initialization
    #####################################################################
    def __init__( self, network, window=4, chunk_size=700 ):
        """Initialize server data members, and register network functions.

        @param network: A configured L{network.Network} object
        @param window: How many messages can be outstanding to the client at
                       any one time (default: 4)
        @param chunk_size: The size of the data chunks sent to the client.
                           (default: 700)

        """
        self.window = window
        self.chunk_size = chunk_size
        self.network = network
        # Also leaves self.state at 0 (idle).
        self.resetFields()

        self.network.registerReceiveFunc( self.receiveMessage )
        self.network.registerTimeoutFunc( self.handleTimeout )

    def _closeFile( self ):
        """Close the currently open file, if any, and mark it as closed.

        Safe to call at any time, including before L{fd} has ever been
        assigned and after the file has already been closed.
        """
        fd = getattr(self, "fd", None)
        if fd:
            fd.close()
        self.fd = 0

    def resetFields( self ):
        """Reset data members to their default values.

        Call this on initialization, and between client connections.  Any
        file left open by the previous connection is closed first so that
        file handles are not leaked across connections.
        """
        self._closeFile()

        self.filename = ""
        self.outstanding = 0
        self.seqno = 0
        self.seqno_start = 0
        self.done = 0
        self.state = 0

        # rtt/timeout estimates default values
        self.rtt = 1.5
        self.alpha = .9
        self.beta = 2

    #####################################################################
    ##  User-defined functions
    #####################################################################

    def handleSyn( self, host, port, msg ):
        """Receive a message of type SYN from a client

        The message contains the name of a file to open (in the
        variable msg.data).  Open the file for reading; if the file
        opens successfully, respond to the client with a SYNACK.
        Otherwise, respond with a SYNRETRY.

        The server also picks a random sequence number for the
        SYNACK packet and increments this sequence number for every subsequent
        data packet.

        A SYN received while a transfer is already past the SYNACK stage is
        ignored.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type SYN

        """

        if self.state == ReliableServer.STATE_SENT_SYNACK:
            # Already sent a synack. Assume it was lost and the client
            # is retransmitting.  resetFields() also closes the file that
            # was opened for the earlier SYN.
            self.resetFields()
        elif self.state != 0:
            return

        self.filename = msg.data
        try:
            self.fd = open(self.filename, "r")
        except (IOError, OSError, TypeError, ValueError):
            # Only a failure to open the requested file is reported as a
            # SYNRETRY; other errors (and Ctrl-C) are allowed to propagate.
            synRetry = ReliableMessage(TYPE_SYNRETRY, 0, time())
            self.network.sendMessage(host, port, synRetry)
            return

        self.seqno_start = int(random() * 65536)
        self.seqno = self.seqno_start + 1
        self.state = ReliableServer.STATE_SENT_SYNACK
        # The SYNACK carries the chunk size as its data and is sent without
        # a timeout; a lost SYNACK is recovered by the client resending SYN.
        synAck = ReliableMessage(TYPE_SYNACK, self.seqno_start,
                                 time(), self.chunk_size)
        self.network.sendMessage(host, port, synAck)

    def handleSynAckAck( self, host, port, msg ):
        """Receive a message of type SYNACKACK from the client

        The client is confirming that it received the SYNACK.  Thus, the server
        updates the RTT estimate and begins sending data packets to the client.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK

        """
        if self.state != ReliableServer.STATE_SENT_SYNACK:
            return

        # NOTE(review): the RTT sample is taken from msg.data, which assumes
        # the client echoes the server's timestamp in the data field --
        # confirm against the client / msg implementation.
        self.updateRTT(host, port, time() - msg.data)
        self.state = ReliableServer.STATE_DATA
        self.sendData(host, port, msg)

    def handleAck( self, host, port, msg ):
        """Receive a message of type ACK from the client

        The client is acknowledging that it received a data message.  Update
        the relevant data members (like RTT estimate), cancel any timeouts,
        and send more data.  Once the last chunk has been sent and every
        chunk acknowledged, send the FIN.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type ACK

        """
        if self.state not in (ReliableServer.STATE_DATA,
                              ReliableServer.STATE_EOF):
            return

        if self.network.cancelTimeout(host, port, msg.seqno) == 0:
            # Dup ack
            return

        # NOTE(review): as in handleSynAckAck, this assumes msg.data holds
        # the echoed timestamp -- confirm.
        self.updateRTT(host, port, time() - msg.data)
        self.outstanding -= 1

        if self.state == ReliableServer.STATE_DATA:
            self.sendData(host, port, msg)
        elif self.outstanding == 0:
            # STATE_EOF with nothing left un-ACKed: start closing.
            fin = ReliableMessage(TYPE_FIN, self.seqno, time())
            self.state = ReliableServer.STATE_SENT_FIN
            self.network.sendMessage(host, port, fin,
                                     self.calcTimeout(host, port))

    def sendData( self, host, port, msg ):
        """Send data messages to the client

        This method sends as many data messages to the client as possible,
        while maintaining the fixed L{window} size.  It sends the messages
        reliably (each with a timeout).  The file is closed once the last
        chunk has been read and sent.

        A chunk shorter than L{chunk_size} (possibly empty) marks the end of
        the file and moves the server to STATE_EOF; the FIN itself is sent
        by L{handleAck} once all chunks are acknowledged.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type
                    SYNACKACK or type ACK

        """
        if self.state != ReliableServer.STATE_DATA:
            return

        while self.outstanding < self.calcWindow(host, port):
            # seqno_start + 1 is the first data chunk, at file offset 0.
            offset = (self.seqno - self.seqno_start - 1) * self.chunk_size
            print("Sending from %s" % (offset,))
            self.fd.seek(offset)
            data = self.fd.read(self.chunk_size)
            dataMsg = ReliableMessage(TYPE_DATA, self.seqno, time(),
                                      data)
            self.network.sendMessage(host, port, dataMsg,
                                     self.calcTimeout(host, port))
            self.seqno += 1
            self.outstanding += 1

            if len(data) < self.chunk_size:
                # Short read: that was the final chunk.  The file is not
                # read again after this point (retransmissions resend the
                # stored message), so it can be closed now.
                self.state = ReliableServer.STATE_EOF
                self._closeFile()
                break

    def handleFinAck( self, host, port, msg ):
        """Receive a message of type FINACK from the client

        The client is acknowledging that it received a FIN.  Thus, the
        connection is officially closed.  Reset all necessary state

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type FINACK

        """
        if self.state != ReliableServer.STATE_SENT_FIN:
            return

        self.resetFields()
        self.network.cancelTimeout(host, port, msg.seqno)

    def handleTimeout( self, host, port, msg ):
        """A message has timed out.

        A sent message has not yet been acknowledged, or else someone
        forgot to call L{cancelTimeout <network.Network.cancelTimeout>}.
        The message is resent with a fresh timestamp and a new timeout.

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """
        msg.timestamp = time()
        self.network.sendMessage(host, port, msg, self.calcTimeout(host, port))

    def updateRTT( self, host, port, rtt_est ):
        """Update the round trip time estimate to the client

        The new estimate is the weighted average
        C{alpha * rtt + (1 - alpha) * rtt_est}.

        @param host: The client host
        @param port: The UDP port of the client
        @param rtt_est: A new measurement of the round trip time

        """

        self.rtt = self.alpha*self.rtt + (1-self.alpha)*rtt_est

    def calcTimeout( self, host, port ):
        """Calculate a timeout for a message to a client

        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated timeout (in seconds), C{beta * rtt}

        """

        return self.beta*self.rtt

    def calcWindow( self, host, port ):
        """Calculate the current window size for a client

        @param host: The client host
        @param port: The UDP port of the client
        @return: the calculated window size (currently the fixed L{window})

        """

        return self.window

    #####################################################################
    ##  Message processing
    #####################################################################

    def receiveMessage( self, host, port, msg ):
        """A L{ReliableMessage <msg.ReliableMessage>} was received from
        a client. The corresponding handler functions are called depending
        on the type of the message.

        @param host: The host that sent this message
        @param port: The UDP port of the client that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>}

        """

        print("Received message: %s" % (msg.seqno,))

        if msg.type == TYPE_SYN:
            self.handleSyn( host, port, msg )

        elif msg.type == TYPE_SYNACKACK:
            self.handleSynAckAck( host, port, msg )

        elif msg.type == TYPE_ACK:
            self.handleAck( host, port, msg )

        elif msg.type == TYPE_FINACK:
            self.handleFinAck( host, port, msg )

        else:
            # %s formatting works whether the type is a string or a number;
            # plain concatenation would raise TypeError for a non-string.
            print("Unexpected message type: %s" % (msg.type,))

#####################################################################
##  Parse options / start network
#####################################################################

def makeServerParser():
    """Build the command-line option parser for the server.

    The parser understands the window size (-w), channel loss rate (-l),
    listening port (-p), data chunk size (-s), a debug flag (-d) and the
    host name to bind to (-n, defaulting to this machine's host name).

    @return: the configured C{OptionParser}
    """

    # One entry per option: (flag strings, keyword settings), registered in
    # the order listed here.
    option_table = [
        (("-w", "--window"),
         {"dest": "window", "type": "int", "default": 4,
          "action": "store",
          "help": "fixed window size (in packets)"}),
        (("-l", "--loss"),
         {"dest": "loss", "type": "float", "default": 0,
          "action": "store",
          "help": "loss rate of channel between client and server"}),
        (("-p", "--port"),
         {"dest": "port", "type": "int", "default": 6829,
          "action": "store",
          "help": "port to listen on"}),
        (("-s", "--size"),
         {"dest": "size", "type": "int", "default": 700,
          "action": "store",
          "help": "data chunk size"}),
        (("-d", "--debug"),
         {"dest": "debug", "default": False,
          "action": "store_true",
          "help": "print network debug info"}),
        (("-n", "--name"),
         {"dest": "host", "type": "string", "default": gethostname(),
          "action": "store",
          "help": "host name to bind to"}),
    ]

    server_parser = OptionParser()
    for flags, settings in option_table:
        server_parser.add_option(*flags, **settings)
    return server_parser

def main(argv=None):
    """Run a server from the command line.

    Parse options, and enter the L{Network <network.Network>} event loop.
    The loop runs until interrupted with Ctrl-C.

    @param argv: The full argument vector, program name first
                 (default: C{sys.argv})
    """

    if argv is None:
        argv = sys.argv

    # Get command line options.  Parse the argv that was passed in rather
    # than always falling back to sys.argv; element 0 is the program name.
    parser = makeServerParser()
    (options, _) = parser.parse_args(argv[1:])

    # create a network
    network = Network( options.host, options.loss, options.port,
                       options.debug )

    # The server registers its receive/timeout callbacks on the network in
    # its constructor, so it must exist before the loop starts.
    server = ReliableServer( network, options.window, options.size )

    try:
        network.loop()
    except KeyboardInterrupt:
        print("Exiting on command.  Goodbye.")

# Script entry point: run the server and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
