package anastore.util;

import java.io.*;
import java.util.*;

import org.junit.runner.*;
import org.junit.runner.notification.Failure;
import org.junit.runner.notification.RunListener;

public class JUnitRunner
{
    private static final boolean REDIRECT = true;
    private static final boolean OUT_ON_OK = false;

    private static final String RED    = "\u001b[31;1m";
    private static final String GREEN  = "\u001b[32;1m";
    private static final String YELLOW = "\u001b[33;1m";
    private static final String NONE   = "\u001b[0m";

    private static final PrintStream out = System.out;
    private static final PrintStream err = System.err;

    private static class RedirectableOutputStream extends FilterOutputStream
    {
        public RedirectableOutputStream(OutputStream init)
        {
            super(init);
        }

        public void redirect(boolean close, OutputStream n)
            throws IOException
        {
            try {
                if (close)
                    out.close();
                else
                    out.flush();
            } finally {
                out = n;
            }
        }
    }

    private static class Listener extends RunListener
    {
        private String className = null;
        private boolean running = false;
        private File redirectTo = null;
        private RedirectableOutputStream redirectable =
            new RedirectableOutputStream(System.out);
        private PrintStream redirect =
            new PrintStream(redirectable);

        public void testStarted(Description description)
        {
            if (!description.isTest())
                return;

            String name = description.getDisplayName();
            String testName = name.substring(0, name.indexOf('('));
            String newClassName = name.substring(name.indexOf('(')+1,
                                                 name.indexOf(')'));
            if (!newClassName.equals(className)) {
                className = newClassName;
                out.println(className);
            }
            out.print("  " + testName + "...");
            out.flush();
            running = true;
            if (REDIRECT) {
                redirectTo = new File("/tmp/" + className + "-" + testName + ".log");
                try {
                    redirectable.redirect(false, new FileOutputStream(redirectTo));
                    System.setOut(redirect);
                    System.setErr(redirect);
                } catch (IOException e) {
                    System.setOut(out);
                    System.setErr(err);
                    e.printStackTrace();
                    // Ignore
                }
            }
        }

        public void testFailure(Failure failure)
        {
            out.println(" " + RED + "FAILED" + NONE);
            unredirect();
            running = false;
            if (REDIRECT)
                showCapture();
        }

        public void testFinished(Description description)
        {
            unredirect();
            if (running)
                out.println(" " + GREEN + "OK" + NONE);
            if (REDIRECT && OUT_ON_OK && running)
                showCapture();
            running = false;
        }

        private void unredirect()
        {
            if (running && REDIRECT) {
                System.setOut(out);
                System.setErr(err);
                redirect.flush();
                try {
                    redirectable.redirect(true, out);
                } catch (IOException e) {
                    e.printStackTrace();
                    // Ignore
                }
                if (redirectTo.length() == 0)
                    redirectTo.delete();
            }
        }

        private void showCapture()
        {
            if (!redirectTo.exists())
                return;
            try {
                FileInputStream is = new FileInputStream(redirectTo);
                byte[] data = new byte[1024];
                while (true) {
                    int count = is.read(data);
                    if (count == -1)
                        break;
                    out.write(data, 0, count);
                }
                out.flush();
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
                // Ignore
            }
        }

        public void testIgnored(Description description)
        {
            testStarted(description);
            out.println(" " + YELLOW + "IGNORED" + NONE);
            running = false;
        }

        public void testRunFinished(Result result)
        {
            StringBuilder msg =
                new StringBuilder("Done. " +
                                  result.getRunCount() + " tests run.");
            if (result.getIgnoreCount() > 0) {
                msg.append(" " + YELLOW + result.getIgnoreCount() +
                           " ignored." + NONE);
            }
            if (result.getFailureCount() > 0) {
                msg.append(" " + RED + result.getFailureCount() +
                           " failed." + NONE);
            }
            out.println(msg);
        }

        public void testRunStarted(Description description)
        {
        }
    }

    public static void main(String args[])
    {
        JUnitRunner inst = new JUnitRunner();
        Result res = inst.run(args);
        inst.showResult(res);
        if (!res.wasSuccessful())
            System.exit(1);
    }

    private Result run(String args[])
    {
        JUnitCore core = new JUnitCore();

        core.addListener(new Listener());
        List<Class<?>> classes = new ArrayList<Class<?>>();
        List<Failure> missing = new ArrayList<Failure>();
        for (String name : args) {
            try {
                classes.add(Class.forName(name));
            } catch (ClassNotFoundException e) {
                Description mdesc = Description.createSuiteDescription(name);
                Failure failure = new Failure(mdesc, e);
                missing.add(failure);
            }
        }
        Result res = core.run(classes.toArray(new Class[0]));
        for (Failure mfail : missing)
            res.getFailures().add(mfail);
        return res;
    }

    private void showResult(Result res)
    {
        for (Failure f : res.getFailures()) {
            out.println("");
            out.println(RED + f.getTestHeader() + NONE);
            Throwable exc = f.getException();
            trimException(exc);
            exc.printStackTrace(out);
        }
    }

    private void trimException(Throwable t)
    {
        StackTraceElement[] st = t.getStackTrace();
        int pos;
        for (pos = st.length-1; pos >= 0; --pos) {
            String cn = st[pos].getClassName();
            if (cn.startsWith("org.junit."))
                break;
        }
        for (; pos >= 0; --pos) {
            String cn = st[pos].getClassName();
            if (!(cn.startsWith("org.junit") ||
                  cn.startsWith("java.lang.reflect") ||
                  cn.startsWith("sun.reflect")))
                break;
        }
        if (pos == -1)
            pos = st.length-1;
        StackTraceElement[] newst = new StackTraceElement[pos+1];
        System.arraycopy(st, 0, newst, 0, newst.length);
        t.setStackTrace(newst);
        if (t.getCause() != null)
            trimException(t.getCause());
    }
}
