#!/usr/bin/env python

from socket import *
from msg import *
from optparse import OptionParser
from network import Network
from time import time
import sys

class ReliableClient:
    """A class implementing the client side of the reliable 6.829 message
    transfer protocol.

    This class is capable of asking a L{server.ReliableServer} for one file,
    receiving that file, and writing it to a local output file.

    @cvar default_timeout: Where to start the timeouts (in seconds)
    @cvar default_chunk_size: Chunk size used when the SYNACK does not carry
                              a usable one (missing, non-numeric, or not
                              positive)
    @cvar finack_copies: How many times the FINACK is sent before the
                         connection is closed (the client does not wait for
                         it to be acknowledged)
    @cvar STATE_SENT_SYN:  Corresponds to a client state in which the
                           client sent the SYN, but not the SYNACKACK
    @cvar STATE_SENT_SYNACKACK:  Corresponds to a client state in which the
                                 client sent the SYNACKACK, but has not yet
                                 received any data
    @cvar STATE_DATA: Corresponds to a client state in which the client is
                      actively receiving data packets


    @ivar network: The client's L{network.Network} object
    @ivar filename: The filename to request from the server
    @ivar output: The local output file for the received server file
    @ivar chunk_size: The expected size of the data chunks sent by the server.
                      This is set during the SYN handshake.
    @ivar fd: The writable file object for the output file (None until the
              SYNACK has been handled)
    @ivar seqno_start: The sequence number of the first data packet
    @ivar state: This variable can be used to keep track of the current state
                 of the client (eg: STATE_SENT_SYN)
    """

    default_timeout = 3

    STATE_SENT_SYN = 1
    STATE_SENT_SYNACKACK = 2
    STATE_DATA = 3

    default_chunk_size = 700

    finack_copies = 5

    #####################################################################
    ##  Initialization
    #####################################################################
    def __init__( self, network, filename, output ):
        """Initialize client data members, and register network functions.

        @param network: A configured L{network.Network} object
        @param filename: The name of the remote file to request
        @param output: The name of the local file to write to.

        """
        self.network = network
        self.filename = filename
        self.output = output
        self.chunk_size = 0  # Must be set by us during the SYN handshake
        self.fd = None       # Opened once the SYNACK is handled
        self.seqno_start = 0
        self.state = 0       # No handshake started yet

        self.network.registerReceiveFunc( self.receiveMessage )
        self.network.registerTimeoutFunc( self.handleTimeout )

    #####################################################################
    ##  User-defined functions
    #####################################################################

    def sendSyn( self, host, port ):
        """Start a SYN handshake process with a remote host by sending an
        initial SYN packet

        @param host: The DNS name or IP address of the remote server
        @param port: The UDP port the remote server is listening on
        """
        # Start off with a SYN (seqno 0), along with a filename
        syn = ReliableMessage( TYPE_SYN, 0, 0, self.filename )
        self.network.sendMessage( host, port, syn, self.default_timeout )
        self.state = self.STATE_SENT_SYN

    @classmethod
    def _chunkSizeFromPayload( cls, data ):
        """Interpret the chunk size carried in a SYNACK payload.

        @param data: The SYNACK payload (C{msg.data}). It should hold the
                     negotiated chunk size, but may be missing or zero.
        @return: C{data} converted to a positive int, or
                 L{default_chunk_size} when the payload is missing,
                 non-numeric, or not positive.
        """
        try:
            size = int( data )
        except ( TypeError, ValueError ):
            # e.g. None or "" -- the server did not set a chunk size
            return cls.default_chunk_size
        if size <= 0:
            return cls.default_chunk_size
        return size

    def handleSynAck( self, host, port, msg ):
        """Receive a message of type SYNACK from the server

        The chunk size is negotiated between the client and the server
        in the payload of the SYNACK packet (eg: self.chunk_size =
        msg.data). If the server does not set the chunk size in the
        SYNACK packet (or sets a zero value), use the default chunk
        size.

        The timeout for the SYN packet must be cancelled, since the
        SYN has been successfully received.

        Must respond with a SYNACKACK, and open the output file for writing.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type SYNACK

        """

        print( "Got SYNACK" )
        # handle the SYNACK only if we have sent a SYN and are waiting for it
        if( self.state == self.STATE_SENT_SYN ):
            # the SYN (seqno 0) was answered; stop retransmitting it
            self.network.cancelTimeout( host, port, 0 )
            # chunk size negotiated in the payload, defaulted if unusable
            self.chunk_size = self._chunkSizeFromPayload( msg.data )

            # data packets are numbered starting right after the SYNACK
            self.seqno_start = msg.seqno + 1

            synackack = ReliableMessage( TYPE_SYNACKACK, msg.seqno,
                                         msg.timestamp )
            self.network.sendMessage( host, port, synackack,
                                      self.default_timeout )
            # binary mode: received chunks are written verbatim at offsets
            self.fd = open( self.output, "wb" )
            self.state = self.STATE_SENT_SYNACKACK

    def handleSynRetry( self, host, port, msg ):
        """Receive a message of type SYNRETRY from the server

        This message indicates that the server did not have the requested file.
        This method should request a different file from the user (perhaps
        using the raw_input function) and try to establish the connection
        again by resending the SYN.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of
                    type SYNRETRY

        """

        if( self.state == self.STATE_SENT_SYN ):
            # The file requested wasn't found, so retransmit SYN
            # with a new file name
            print( "Received Syn retry. Retransmitting SYN" )
            self.network.cancelTimeout( host, port, 0 )  # cancel old timeout
            self.filename = raw_input( "Requested file not found. " +
                                       "Enter new filename:" )
            self.sendSyn( host, port )

    def handleData( self, host, port, msg ):
        """Receive a message of type DATA from the server

        The message should contain a chunk of the requested file.  Write
        the data to disk at the correct location, and send an acknowledgement.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type DATA

        """

        # The first data packet proves the SYNACKACK arrived, so cancel
        # its retransmission timeout and move into the data state.
        if( msg.seqno >= self.seqno_start and
            self.state == self.STATE_SENT_SYNACKACK ):

            self.network.cancelTimeout( host, port, self.seqno_start-1 )
            self.state = self.STATE_DATA

        if( msg.seqno >= self.seqno_start and
            self.state == self.STATE_DATA ):

            print( "Received data, seqno: %s" % ( msg.seqno, ) )
            # each seqno maps to a fixed-size slot in the output file
            self.writeFile( msg.data,
                            (msg.seqno-self.seqno_start)*self.chunk_size )
            ack = ReliableMessage( TYPE_ACK, msg.seqno, msg.timestamp )
            self.network.sendMessage( host, port, ack )

    def handleFin( self, host, port, msg ):
        """Receive a message of type FIN from the server

        The message indicates that the server has sent the entire file,
        and that the connection is over.  Acknowledge the FIN, close the
        output file, and close the connection.  You can close the connection,
        and exit the program, using L{Network.close}.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type FIN

        """

        # A FIN before any DATA means the file had no chunks to send.
        # Presumably the server only sends FIN after receiving the
        # SYNACKACK (confirm against the server), so stop retransmitting it.
        if( self.state == self.STATE_SENT_SYNACKACK ):
            self.network.cancelTimeout( host, port, self.seqno_start-1 )
            self.state = self.STATE_DATA

        if( self.state == self.STATE_DATA ):

            self.fd.close()
            finack = ReliableMessage( TYPE_FINACK, msg.seqno )
            print( "File download complete. GoodBye" )
            # Send multiple times, just to be sure, because we're not
            # waiting around for the FINACK to be acknowledged
            for _ in range( self.finack_copies ):
                self.network.sendMessage( host, port, finack )
            self.network.close()

    def handleTimeout( self, host, port, msg ):
        """A message has timed out.

        A sent message has not yet been acknowledged, or else someone
        forgot to call L{cancelTimeout <network.Network.cancelTimeout>}.
        Restamp the message and retransmit it with a fresh timeout.

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """

        print( "This message timed out: %s" % ( msg.seqno, ) )
        msg.timestamp = time()
        self.network.sendMessage( host, port, msg, self.default_timeout )

    #####################################################################
    ##  Message processing
    #####################################################################

    def receiveMessage( self, host, port, msg ):
        """A L{ReliableMessage <msg.ReliableMessage>} was received from
        the server.

        Dispatch it to the handler for its message type.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>}

        """

        print( "Received message: %s server port= %s" % ( msg.seqno, port ) )

        if( msg.type == TYPE_SYNACK ):
            self.handleSynAck( host, port, msg )

        elif( msg.type == TYPE_SYNRETRY ):
            self.handleSynRetry( host, port, msg )

        elif( msg.type == TYPE_DATA ):
            self.handleData( host, port, msg )

        elif( msg.type == TYPE_FIN ):
            self.handleFin( host, port, msg )

        else:
            # %s formatting works whatever type msg.type has
            print( "Unexpected message type: %s" % ( msg.type, ) )

    def writeFile( self, data, offset ):
        """Write data to the output file at some offset.

        @param data: The data to write to the file (a string).
        @param offset: Seek this many bytes in the file before writing

        """

        self.fd.seek( offset, 0 )  # absolute seek from start of file
        self.fd.write( data )

#####################################################################
##  Parse options / start network
#####################################################################

def makeClientParser():
    """Build (but do not run) the client's command-line option parser.

    Recognised options, as flag -> destination attribute (default):
      -s/--server          -> host     (this machine's hostname)
      -p/--port            -> port     (6829)
      -f/--filename        -> filename ("server.py")
      -o/--output_filename -> output   ("downloaded_file")
      -l/--loss            -> loss     (0)

    @return: a configured L{OptionParser}
    """

    # (short flag, long flag, dest, optparse type, default, help text),
    # listed in the order the options are registered
    option_specs = (
        ("-s", "--server", "host", "string", gethostname(),
         "hostname or IP of server"),
        ("-p", "--port", "port", "int", 6829,
         "UDP port of server"),
        ("-f", "--filename", "filename", "string", "server.py",
         "name of file to request"),
        ("-o", "--output_filename", "output", "string", "downloaded_file",
         "where to store the received file"),
        ("-l", "--loss", "loss", "float", 0,
         "loss rate of channel between server and client"),
    )

    parser = OptionParser()
    for short_flag, long_flag, dest_name, kind, fallback, text in option_specs:
        parser.add_option(short_flag, long_flag, dest=dest_name, type=kind,
                          default=fallback, action="store", help=text)
    return parser

def main(argv=None):
    """Run a client from the command line.

    Parse options, send a SYN packet, and enter the
    L{Network <network.Network>} event loop.

    @param argv: Full argument vector with the program name first
                 (defaults to C{sys.argv}); C{argv[1:]} is parsed as options.
    """

    if argv is None:
        argv = sys.argv

    # Get command line options from the argv we were given
    # (skipping the program name at argv[0])
    parser = makeClientParser()
    options, _ = parser.parse_args( argv[1:] )

    host = options.host
    port = options.port
    loss = options.loss
    filename = options.filename
    output = options.output

    # create a network
    network = Network( gethostname(), loss )

    client = ReliableClient( network, filename, output )
    client.sendSyn( host, port )

    Network.loop()


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    exit_status = main()
    sys.exit(exit_status)



