#!/usr/bin/env python

# When True, ReliableClient.handleSynAck raises RuntimeError instead of
# opening (and truncating) an output file that already exists.
PREVENT_OVERWRITE = False

from socket import *
from msg import *
from optparse import OptionParser
from network import Network
from time import time
import sys
import os

# Client connection states held in ReliableClient.state.  They are numbered
# from 1 so that the value 0 assigned in ReliableClient.__init__ (before any
# SYN has been sent) is distinct from all of them.
(STATE_SENT_SYN,
 STATE_SENT_SYNACKACK,
 STATE_DATA,
 STATE_FIN) = range(1, 5)

def ppm( msg ):
    """Render a message as a short one-line summary for logging.

    The result looks like C{<TYPE @seqno data>}.  TYPE is the symbolic
    message type, and data is the repr of the payload, truncated to 40
    characters plus "..." when longer.  An unrecognised message type is
    rendered as C{UNKNOWN(<type>)} instead of raising KeyError.  This matters
    because ReliableClient.receiveMessage calls this on every incoming message
    before it dispatches on the type.

    @param msg: A L{ReliableMessage <msg.ReliableMessage>}
    @return: A string summary of msg
    """
    types = {TYPE_SYN:"SYN",
             TYPE_SYNACK:"SYNACK",
             TYPE_SYNRETRY:"SYNRETRY",
             TYPE_SYNACKACK:"SYNACKACK",
             TYPE_DATA:"DATA",
             TYPE_ACK:"ACK",
             TYPE_FIN:"FIN",
             TYPE_FINACK:"FINACK"}
    type_name = types.get(msg.type)
    if type_name is None:
        # Don't let a bogus/unknown type crash the logging path.
        type_name = "UNKNOWN(%r)" % (msg.type,)
    # repr() instead of backticks: identical in Python 2, valid in Python 3.
    data = repr(msg.data)
    if len(data) > 40:
        data = data[:40] + "..."
    return "<%s @%d %s>" % (type_name, msg.seqno, data)

class PendingTimeout:
    """Bookkeeping record for one reliably-sent message.

    It stores the destination the message was sent to and the callback that
    ReliableClient.handleTimeout invokes if the message's timeout fires
    before it is cancelled.
    """

    def __init__(self, host, port, func):
        # Destination of the outstanding message, then its timeout callback.
        self.host, self.port, self.func = host, port, func

class ReliableClient:
    """A class implementing the client side of the reliable 6.829 message
    transfer protocol.

    This class is capable of asking a L{server.ReliableServer} for one file,
    receiving that file, and writing it to a local output file.

    The client's progress is tracked in C{self.state} using the module-level
    constants:
      - STATE_SENT_SYN: the SYN was sent, but no SYNACK has been received yet.
      - STATE_SENT_SYNACKACK: the SYNACKACK was sent, but no DATA or FIN has
        been received yet.
      - STATE_DATA: the client is actively receiving DATA packets.
      - STATE_FIN: a FIN was received and FINACKs are being sent.

    @cvar default_timeout: Retransmission timeout, in seconds, used for
                           reliably-sent messages unless another is given
    @cvar default_chunk_size: Chunk size used when the server does not supply
                              a valid one in the SYNACK
    @cvar default_finack_count: How many FINACKs to send
    @cvar default_finack_separation: Seconds between sending FINACKs

    @ivar network: The client's L{network.Network} object
    @ivar filename: The filename to request from the server
    @ivar output: The local output file for the received server file
    @ivar chunk_size: The expected size of the data chunks sent by the server.
                      This is set during the SYN handshake.
    @ivar fd: The writable file object for the output file, or None
    @ivar seqno_start: The sequence number of the first data packet
                       (initialised to 0; not otherwise updated here)
    @ivar state: The current state of the client: one of the STATE_*
                 constants, or 0 before the SYN is sent
    @ivar pendingTimeouts: Mapping of sequence numbers to
                           L{PendingTimeout} objects for each pending message
    @ivar pendingSynMessage: The SYN message that is awaiting response or None
    @ivar pendingSynAckAckMessage: The SYNACKACK message awaiting response
    @ivar pendingFinacks: Number of FINACK copies still to send after a FIN,
                          or None before a FIN is received
    @ivar synackSeqno: The sequence number of the SYNACK, or None

    """

    default_timeout = 3

    default_chunk_size = 700

    default_finack_count = 5
    default_finack_separation = 0.2

    #####################################################################
    ##  Initialization
    #####################################################################
    def __init__( self, network, filename, output ):
        """Initialize client data members, and register network functions.

        @param network: A configured L{network.Network} object
        @param filename: The name of the remote file to request
        @param output: The name of the local file to write to.

        """
        self.network = network
        self.filename = filename
        self.output = output
        self.chunk_size = 0  # Must be set by us during the SYN handshake
        self.fd = None
        self.seqno_start = 0
        self.state = 0

        self.network.registerReceiveFunc( self.receiveMessage )
        self.network.registerTimeoutFunc( self.handleTimeout )

        self.pendingTimeouts = {}

        self.pendingSynMessage = None
        self.pendingSynAckAckMessage = None

        self.pendingFinacks = None

        self.synackSeqno = None

    #####################################################################
    ##  User-defined functions
    #####################################################################

    def sendSyn( self, host, port ):
        """Start a SYN handshake with a remote host by sending an initial
        SYN packet.

        The SYN carries the requested filename and is retransmitted until
        a SYNACK or SYNRETRY cancels its timeout.

        @param host: The DNS name or IP address of the remote server
        @param port: The UDP port the remote server is listening on
        """
        # Send SYN with retry
        out = ReliableMessage(TYPE_SYN, 0, time(), self.filename)
        self.__sendReliably(host, port, out)
        self.pendingSynMessage = out
        self.state = STATE_SENT_SYN

    def __handlePostSyn( self ):
        """Handle processing that needs to occur when any response to
        a SYN (either a SYNACK or a SYNRETRY) is received.

        """
        # Cancel the SYN timeout
        if self.pendingSynMessage is not None:
            self.__cancelTimeout(self.pendingSynMessage)
        self.pendingSynMessage = None

    def __resolveChunkSize( self, data ):
        """Decide which chunk size to use, given a SYNACK payload.

        @param data: The SYNACK payload; a positive int is expected
        @return: A C{(chunk_size, valid)} tuple.  If data is a positive int,
                 returns C{(data, True)}.  Otherwise returns the class's
                 default_chunk_size with valid False.
        """
        if isinstance(data, int) and data > 0:
            return data, True
        return self.default_chunk_size, False

    def handleSynAck( self, host, port, msg ):
        """Receive a message of type SYNACK from the server.

        The chunk size is negotiated between the client and the server
        in the payload of the SYNACK packet (eg: self.chunk_size =
        msg.data).  If the server does not set the chunk size in the
        SYNACK packet, or sets a non-positive value, the default chunk
        size is used.

        The timeout for the SYN packet is cancelled, since the SYN has
        been successfully received.  The method then opens the output
        file for writing and responds with a (reliably sent) SYNACKACK.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type SYNACK
        @raise RuntimeError: If PREVENT_OVERWRITE is set and the output
                             file already exists

        """
        # Ignore duplicate SYNACKs
        if self.state != STATE_SENT_SYN:
            print("Warning: Stray SYNACK %s" % ppm(msg))
            return
        self.__handlePostSyn()

        # Validate and record the chunk size.  Fall back to the class default
        # (previously a NameError: the bare name was used without "self.").
        self.chunk_size, valid = self.__resolveChunkSize(msg.data)
        if not valid:
            print("Warning: Server returned invalid chunk size %s" % ppm(msg))

        # Record the sequence number
        self.synackSeqno = msg.seqno

        # Open output file
        if self.fd is None:
            if PREVENT_OVERWRITE and os.path.exists(self.output):
                raise RuntimeError("I will not overwrite an existing file")
            # open() rather than the Python-2-only file() builtin.
            self.fd = open(self.output, "wb")
        else:
            print("Bug: Output file already open")

        # Send SYNACKACK with retry
        out = ReliableMessage(TYPE_SYNACKACK, msg.seqno, time())
        self.__sendReliably(host, port, out)
        self.pendingSynAckAckMessage = out
        self.state = STATE_SENT_SYNACKACK

    def handleSynRetry(self, host, port, msg ):
        """Receive a message of type SYNRETRY from the server.

        This message indicates that the server did not have the requested
        file.  The user is asked for a different filename on stdin, and the
        handshake is restarted by sending a new SYN.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of
                    type SYNRETRY

        """
        # Ignore duplicate SYNRETRYs
        if self.state != STATE_SENT_SYN:
            print("Warning: Stray SYNRETRY %s" % ppm(msg))
            return
        self.__handlePostSyn()

        # Restart with a different file
        print("Error retrieving file.  Please specify a different file.")
        try:
            read_line = raw_input  # Python 2
        except NameError:
            read_line = input  # Python 3 renamed raw_input to input
        self.filename = read_line()
        self.sendSyn(host, port)

    def __handlePostSynAckAck(self):
        """Handle processing that needs to occur when any response to
        a SYNACKACK (either a DATA or a FIN) is received.

        Only the first such response has an effect: it cancels the SYNACKACK
        timeout and moves the client into STATE_DATA.

        """
        # Cancel SYNACKACK timeout
        if self.state == STATE_SENT_SYNACKACK:
            self.__cancelTimeout(self.pendingSynAckAckMessage)
            self.pendingSynAckAckMessage = None
            self.state = STATE_DATA

    def handleData( self, host, port, msg ):
        """Receive a message of type DATA from the server.

        The message payload is an C{(offset, data)} pair.  The data is
        written to the output file at that offset, and an ACK is sent
        unreliably (echoing the DATA's seqno and timestamp).

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type DATA

        """
        self.__handlePostSynAckAck()

        # Ignore out-of-order DATA messages
        if self.state != STATE_DATA:
            print("Warning: Stray DATA message %s" % ppm(msg))
            return

        # Check sequence number (logically, this should be <=)
        if msg.seqno < self.synackSeqno:
            print("Warning: Invalid DATA sequence number %s" % ppm(msg))
            return

        # A malformed payload from the network must not kill the event loop.
        try:
            offset, data = msg.data
        except (TypeError, ValueError):
            print("Warning: Malformed DATA payload %s" % ppm(msg))
            return

        # Record the data
        self.writeFile(data, offset)

        # Send an ACK unreliably
        out = ReliableMessage(TYPE_ACK, msg.seqno, msg.timestamp)
        self.__sendUnreliably(host, port, out)

    def handleFin( self, host, port, msg ):
        """Receive a message of type FIN from the server.

        The message indicates that the server has sent the entire file.
        The output file is closed, and default_finack_count FINACKs are
        sent, spaced default_finack_separation seconds apart.  After the
        last one, L{Network.close} is called.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} of type FIN

        """
        self.__handlePostSynAckAck()

        # Ignore duplicate FIN messages
        if self.state != STATE_DATA:
            print("Warning: Stray FIN message %s" % ppm(msg))
            return

        # Check sequence number (logically, this should be <=)
        if msg.seqno < self.synackSeqno:
            print("Warning: Invalid FIN sequence number %s" % ppm(msg))
            return

        # Close the file
        self.fd.close()

        # Send a couple copies of the FINACK with a delay between them.  Each
        # send arms a timeout whose handler is resender itself, so the
        # timeout mechanism doubles as the inter-FINACK delay.
        out = ReliableMessage(TYPE_FINACK, msg.seqno, msg.timestamp)
        self.pendingFinacks = self.default_finack_count
        self.state = STATE_FIN
        def resender():
            self.__sendReliably(host, port, out,
                                self.default_finack_separation,
                                resender)
            self.pendingFinacks -= 1
            if self.pendingFinacks == 0:
                self.network.close()
        resender()

    def __sendReliably( self, host, port, msg,
                        timeout = default_timeout, func = None ):
        """Send a message reliably.

        This associates a timeout, and a function to call if that timeout
        expires, with the message being sent.  If
        L{ReliableClient.__cancelTimeout} isn't called before the timeout
        expires, func is called.  If timeout is not specified, it defaults
        to the default timeout.  If func is not specified, it defaults to
        retransmitting the packet with the same timeout.

        @param host: The host to send the message to
        @param port: The UDP port to send the message to
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} to send
        @param timeout: The time to wait before calling func
        @param func: The function to call when timeout expires

        """
        print("Sending reliably: %s" % ppm(msg))
        if func is None:
            func = lambda : self.__sendReliably(host, port, msg, timeout)
        self.network.sendMessage(host, port, msg, timeout)
        self.pendingTimeouts[msg.seqno] = PendingTimeout(host, port, func)

    def __sendUnreliably( self, host, port, msg ):
        """Send a message unreliably.

        This simply sends a message without involving any timeout
        mechanisms.

        @param host: The host to send the message to
        @param port: The UDP port to send the message to
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} to send

        """
        print("Sending unreliably: %s" % ppm(msg))
        self.network.sendMessage(host, port, msg)

    def __cancelTimeout( self, msg, ignorenetwork = False ):
        """Cancel the timeout associated with msg.

        This cancels the timeout that was set up by a previous call to
        __sendReliably for msg.  Generally, this should be called when
        any message indicating receipt of msg is received from the
        server.

        @param msg: The message that was previously sent
        @param ignorenetwork: A bug work-around.  This should only be
          set if this is called from L{handleTimeout} or a timeout
          handler.

        """
        if msg.seqno in self.pendingTimeouts:
            pt = self.pendingTimeouts[msg.seqno]
            if not ignorenetwork:
                # Bug work-around.  Set ignorenetwork if this is being
                # called from handleTimeout so the timeout isn't
                # cancelled from the timeout firing loop
                self.network.cancelTimeout(pt.host, pt.port, msg.seqno)
            del self.pendingTimeouts[msg.seqno]
        else:
            print("Bug: Cancelling unknown timeout %s" % ppm(msg))

    def handleTimeout( self, host, port, msg ):
        """A message has timed out.

        A sent message has not yet been acknowledged, or else someone
        forgot to call L{cancelTimeout <network.Network.cancelTimeout>}.
        The pending entry is dropped and its registered function is run
        (by default, a retransmission).

        @param host: The intended receiver of the timed out message
        @param port: The UDP port of the receiver
        @param msg: The L{ReliableMessage <msg.ReliableMessage>} that timed out

        """
        pt = self.pendingTimeouts.get(msg.seqno, None)
        if pt is None:
            print("Bug: Timeout for unknown message %s" % ppm(msg))
        else:
            print("Timeout for message %s" % ppm(msg))
            self.__cancelTimeout(msg, True)
            pt.func()

    #####################################################################
    ##  Message processing
    #####################################################################

    def receiveMessage( self, host, port, msg ):
        """A L{ReliableMessage <msg.ReliableMessage>} was received from
        the server; dispatch it to the matching handle* method.

        @param host: The host that sent this message
        @param port: The UDP port of the server that sent this message
        @param msg: The L{ReliableMessage <msg.ReliableMessage>}

        """

        print("Received message: %s" % ppm(msg))

        if msg.type == TYPE_SYNACK:
            self.handleSynAck( host, port, msg )

        elif msg.type == TYPE_SYNRETRY:
            self.handleSynRetry( host, port, msg )

        elif msg.type == TYPE_DATA:
            self.handleData( host, port, msg )

        elif msg.type == TYPE_FIN:
            self.handleFin( host, port, msg )

        else:
            # %s formatting: "str + msg.type" raised TypeError for int types.
            print("Unexpected message type: %s" % (msg.type,))

    def writeFile( self, data, offset ):
        """Write data to the output file at some offset.

        @param data: The data to write to the file (a string).
        @param offset: Seek this many bytes in the file before writing

        """

        self.fd.seek( offset, 0 )
        self.fd.write( data )

#####################################################################
##  Parse options / start network
#####################################################################

def makeClientParser():
    """Build the command-line option parser used by the client.

    The returned parser accepts the server host (defaulting to this
    machine's hostname), the server UDP port, the remote filename, the local
    output filename and the channel loss rate.  It does not parse anything
    itself; call C{parse_args} on the result.

    @return: A configured C{optparse.OptionParser}
    """
    server_default = gethostname()

    # (short flag, long flag, dest, type, default, help text)
    specs = (
        ("-s", "--server", "host", "string", server_default,
         "hostname or IP of server"),
        ("-p", "--port", "port", "int", 6829,
         "UDP port of server"),
        ("-f", "--filename", "filename", "string", "server.py",
         "name of file to request"),
        ("-o", "--output_filename", "output", "string", "downloaded_file",
         "where to store the received file"),
        ("-l", "--loss", "loss", "float", 0,
         "loss rate of channel between server and client"),
    )

    parser = OptionParser()
    for short_flag, long_flag, dest, kind, default, text in specs:
        parser.add_option(short_flag, long_flag, dest=dest, type=kind,
                          default=default, action="store", help=text)
    return parser

def main(argv=None):
    """Run a client from the command line.

    Parse options, send a SYN packet to the requested server, and enter the
    L{Network <network.Network>} event loop.

    @param argv: Full argument vector with the program name first, as in
                 C{sys.argv}.  Defaults to C{sys.argv}.
    """

    if argv is None:
        argv = sys.argv

    # Parse the supplied argv rather than letting parse_args() silently
    # re-read sys.argv, so callers can actually pass their own arguments.
    # argv[0] is the program name and is skipped.
    parser = makeClientParser()
    options, _ = parser.parse_args(argv[1:])

    # NOTE(review): Network is given this machine's hostname, presumably as
    # the client's own address; options.host is the server.  Confirm in
    # network.Network.
    network = Network( gethostname(), options.loss )

    client = ReliableClient( network, options.filename, options.output )
    client.sendSyn( options.host, options.port )

    network.loop()


if __name__ == "__main__":
    # Hand main()'s return value to the shell (None becomes exit status 0).
    exit_status = main()
    sys.exit(exit_status)



