from __future__ import with_statement

from config import *

import os, sys
import subprocess
import threading
from logging import debug, info, warning, error, critical, exception
from pipes import quote

class Progress(object):
    """A context manager that logs (at INFO level) a message before an
    action and a completion or failure message after it.

    MSG is the action description; DONE is the word appended on
    success.  If the body raises, "FAILED (<exception>)" is logged
    instead and the exception propagates unchanged -- this manager
    never suppresses exceptions."""

    def __init__(self, msg, done = "done"):
        self.__msg = msg
        self.__done = done

    def __enter__(self):
        info("%s...", self.__msg)

    def __exit__(self, typ, value, traceback):
        # TYP is None exactly when the body completed without raising.
        if typ is None:
            info("%s... %s", self.__msg, self.__done)
        else:
            info("%s... FAILED (%s)", self.__msg, value)
        # Implicitly returning None (false) lets any exception propagate.

class Host(object):
    """A machine on which commands can be run and files transferred.

    A Host with no (or an empty) hostname is the local machine and runs
    commands directly.  A Host with a hostname runs commands through
    ssh, sharing one master ssh connection per Host object via a
    control socket."""

    # Per-process counter that makes each Host's ssh control socket
    # path unique within this process.
    _nonce = 0

    def __init__(self, hostname = None):
        self.__hostname = hostname

        # Start a master ssh for connection sharing.  We run cat in
        # this ssh so that, when the subprocess object is garbage
        # collected, it will close the stdin pipe and thus terminate
        # the ssh.
        if hostname:
            # The PID keeps concurrent processes on the same machine from
            # colliding on the control socket path; the nonce does the
            # same for several Hosts within this process.
            self.__mpath = "/tmp/benchssh-%s-%d-%d" % (
                hostname, os.getpid(), Host._nonce)
            Host._nonce += 1
            self.__master = \
                subprocess.Popen(["ssh", "-M", "-S", self.__mpath, "-x",
                                  hostname, "cat"],
                                 stdin = subprocess.PIPE)

    def __str__(self):
        if self.__hostname:
            return self.__hostname
        return "localhost"

    def __repr__(self):
        return "Host(%r)" % (self.__hostname,)

    def _sh(self, cmd):
        """Return an argument list that executes the argument list CMD
        on this host.  For a remote host, each token is shell-quoted
        because ssh hands the command to the remote shell as a single
        string; for the local host CMD is returned unchanged."""

        if self.__hostname:
            return (["ssh", "-tqxS", self.__mpath, self.__hostname] +
                    [quote(str(t)) for t in cmd])
        return cmd

    def sendSrcFiles(self, *files, **kwargs):
        """Send FILES from SRCDIR on the current host to DESTDIR on
        this host.  Each listed file or directory will appear in the
        root of DESTDIR (this does not reconstruct the source
        directory structure).  Pass delete=True to add rsync's
        --delete."""

        self.__sendSrcFiles(files, **kwargs)

    def __sendSrcFiles(self, files, delete = False):
        with Progress("Sending %s to %s" % (" ".join(files), self)):
            srcArgs = [os.path.join(SRCDIR, d) for d in files]
            if delete:
                extraArgs = ["--delete"]
            else:
                extraArgs = []
            if self.__hostname:
                # Reuse the master connection through its control socket.
                args = (["rsync"] + extraArgs +
                        ["-rltzqe", "ssh -qxS %s" % quote(self.__mpath)] +
                        srcArgs + ["%s:%s" % (self, quote(DESTDIR))])
            else:
                args = (["rsync"] + extraArgs +
                        ["-rltq"] + srcArgs + [DESTDIR])
            Cmd(*args).run()

    def getFile(self, rpath, lpath):
        """Copy RPATH on this host to LPATH on the current host using
        rsync.  Raises CmdError if rsync fails."""

        with Progress("Retrieving %s from %s" % (rpath, self)):
            if self.__hostname:
                Cmd("rsync", "-rltzqe", "ssh -qxS %s" % quote(self.__mpath),
                    "%s:%s" % (self, quote(rpath)),
                    lpath).run()
            else:
                Cmd("rsync", "-rltq", rpath, lpath).run()

    def writeFile(self, path, data):
        """Write DATA to PATH on this host.  Raises CmdError if the
        writing command exits non-zero."""

        with Progress("Writing %s on %s" % (path, self)):
            cmd = (Cmd("cat", ">", path)
                   .sh().ssh(self).redirect(stdin = subprocess.PIPE))
            p = cmd.start()
            # communicate() writes DATA, closes stdin so cat sees EOF,
            # and always reaps the child -- unlike a bare write/close,
            # which leaves the process un-waited if the pipe breaks.
            p.communicate(data)
            code = p.returncode
            if code:
                raise CmdError(cmd, code)

    def readFile(self, path):
        """Read PATH from this host, returning a string of its
        contents.  Raises CmdError if cat exits non-zero."""

        with Progress("Reading %s from %s" % (path, self)):
            return Cmd("cat", path).ssh(self).runRead()

    def pathExists(self, path):
        """Returns true iff PATH exists on this host.  Raises CmdError
        if `test -e` exits with anything other than 0 or 1 (for
        example, when ssh itself fails)."""

        cmd = Cmd("test", "-e", path).ssh(self)
        p = cmd.start()
        code = p.wait()
        if code == 0:
            return True
        elif code == 1:
            return False
        else:
            raise CmdError(cmd, code)

    def flushBC(self):
        """Flush the buffer cache on this host."""

        with Progress("Flushing buffer cache on %s" % self):
            # sh -c is wrapped by sudo so that the redirection itself
            # runs as root.
            Cmd("echo", "3", ">", "/proc/sys/vm/drop_caches").sh().sudo().ssh(self).run()

class CmdError(Exception):
    """A CmdError is raised when a command returns an unexpected error
    code.  It tracks the original Cmd object (self.cmd) and the error
    code (self.code)."""

    def __init__(self, cmd, code):
        Exception.__init__(self, cmd, code)
        self.cmd, self.code = cmd, code

    def __str__(self):
        return "Command %s exited with code %s" % (self.cmd, self.code)

    def detailError(self):
        """Log a detailed report about the error.  If the command's
        stdout went to a seekable, readable file-like object, also log
        its last few lines (at most 6, taken from its final 1024
        characters)."""

        error("%s", self)
        log = self.cmd._getStdout()
        # stdout may be None, an integer (subprocess.PIPE or a raw fd),
        # or a file-like object.  Duck-type instead of requiring the
        # Python 2 `file` type so that wrappers such as
        # tempfile.NamedTemporaryFile objects are tailed too.
        if not (hasattr(log, "seek") and hasattr(log, "tell")):
            return
        try:
            log.seek(0, 2)
            end = log.tell()
            log.seek(max(end - 1024, 0))
            if log.tell() != 0:
                # Toss the partial line
                log.readline()
            # Keep only the last few lines
            lines = [l.rstrip("\n") for l in log][-6:]
            if hasattr(log, "name"):
                error("Log tail (%s)", log.name)
            else:
                error("Log tail")
            for l in lines:
                error("  %s", l)
        except (IOError, OSError):
            # Too bad, we tried our best.  (sys.exc_info() rather than
            # "except ..., e" keeps this valid on every Python version.)
            warning("Failed to tail log (%s)", sys.exc_info()[1])

class Cmd(object):
    """An executable command: an argument list plus optional stdin,
    stdout, and stderr redirections.

    The builder methods (redirect, ssh, sudo, nonohup, sh, envs) never
    modify the receiver; each returns a new Cmd, so calls can be
    chained."""

    def __init__(self, *cmd):
        """Create an executable command, where the arguments are the
        command arguments.  The first argument is the command name.
        Arguments are converted with str() when the command is started
        or printed."""

        self.__cmd = cmd
        self.__stdin = self.__stdout = self.__stderr = None

    def __dup(self, *newcmd):
        """Return a new Cmd for NEWCMD carrying this command's
        redirections."""

        sh = Cmd(*newcmd)
        sh.__stdin  = self.__stdin
        sh.__stdout = self.__stdout
        sh.__stderr = self.__stderr
        return sh

    def __str__(self):
        return " ".join(quote(str(t)) for t in self.__cmd)

    def _getStdout(self):
        return self.__stdout

    def redirect(self, stdin = None, stdout = None, stderr = None):
        """Return a copy of this command where stdin, stdout, and/or
        stderr are redirected.  See the description for
        subprocess.Popen for the possible values of the arguments.
        An argument left as None keeps the current setting."""

        sh = self.__dup(*self.__cmd)
        # Compare against None, not truthiness: file descriptor 0 is a
        # valid Popen redirection target.
        if stdin is not None:
            sh.__stdin = stdin
        if stdout is not None:
            sh.__stdout = stdout
        if stderr is not None:
            sh.__stderr = stderr
        return sh

    def ssh(self, host):
        """Return a copy of this command that will execute the command
        on the given host.  Raises ValueError if any redirection is
        already set."""

        if (self.__stdin is not None or self.__stdout is not None or
                self.__stderr is not None):
            raise ValueError("Cannot execute a remote command with redirections")
        return self.__dup(*host._sh(self.__cmd))

    def sudo(self):
        """Return a copy of this command that will execute the command
        as root."""

        return self.__dup("sudo", *self.__cmd)

    def nonohup(self):
        """Return a copy of this command that will execute the command
        using nonohup (DESTDIR/nonohup).  The caller is responsible for
        redirecting stdin of the returned command to a pipe; raises
        ValueError if stdin is already redirected."""

        if self.__stdin is not None:
            raise ValueError("Cannot nonohup a command with stdin redirected")
        return self.__dup(os.path.join(DESTDIR, "nonohup"), *self.__cmd)

    def sh(self):
        """Return a copy of this command that will execute it via
        'sh'.  The arguments are quoted such that each becomes exactly
        one shell token.  The shell can interpret these tokens
        specially."""

        toks = []
        for t in self.__cmd:
            t = str(t)
            # Shell operators pass through unquoted so sh interprets them.
            if t in ["<", ">", "<<", ">>", "||", "&", "&&",
                     ";", ";;", "(", ")", "|", "|&", "\n"]:
                toks.append(t)
            else:
                toks.append(quote(t))
        return self.__dup("sh", "-c", " ".join(toks))

    def envs(self, **envs):
        """Return a copy of this command that will execute with the
        given environment variables set (via env(1))."""

        args = ["%s=%s" % (k,v) for k,v in envs.items()]
        args.extend(self.__cmd)
        return self.__dup("env", *args)

    def start(self):
        """Execute this command asynchronously, returning a
        subprocess.Popen object."""

        # Build a real list: on Python 3 map() would return an iterator.
        return subprocess.Popen([str(t) for t in self.__cmd],
                                stdin = self.__stdin,
                                stdout = self.__stdout,
                                stderr = self.__stderr)

    def run(self):
        """Execute this command synchronously, raising a CmdError if
        its exit code is non-zero.  Raises ValueError if stdin is
        redirected (nothing would feed it)."""

        if self.__stdin is not None:
            raise ValueError("Cannot run a command with stdin redirected")
        p = self.start()
        code = p.wait()
        if code != 0:
            raise CmdError(self, code)

    def attempt(self):
        """Execute this command synchronously, returning the exit
        code.  If the exit code is non-zero, logs a warning."""

        try:
            self.run()
        except CmdError:
            # sys.exc_info() keeps this valid on every Python version.
            e = sys.exc_info()[1]
            warning("%s", e)
            return e.code
        return 0

    def runRead(self):
        """Execute this command synchronously, reading and returning
        the data printed to stdout.  Raise CmdError if the commands
        returns a non-zero exit code."""

        if self.__stdout is not None:
            raise ValueError("Cannot runRead a command with stdout redirected")
        cmd = self.redirect(stdout = subprocess.PIPE)
        p = cmd.start()
        data = p.stdout.read()
        code = p.wait()
        if code:
            raise CmdError(cmd, code)
        return data

def pmap(fn, lst):
    """Like map, but fn is applied to each element of lst in parallel.
    If any applications throw an exception, the one applied to the
    earliest list element is re-raised in the calling thread."""

    return map(fn, lst)

    exceptions = []
    res = [None] * len(lst)
    def wrapper(n, arg):
        try:
            res[n] = fn(arg)
        except:
            exceptions.append(sys.exc_info())
            raise
    threads = []
    for n, l in enumerate(lst):
        thread = threading.Thread(target = wrapper, args = (n, l))
        thread.start()
        threads.append(thread)
    for t in threads:
        t.join()
    if len(exceptions):
        raise exceptions[0][1], None, exceptions[0][2]
    return res
