package anastore.util;

import java.lang.reflect.Method;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A utility for implementing multi-threaded unit tests.  To split
 * part of a test into multiple threads, subclass ForkedTester
 * (typically using an anonymous class), and implement the 'parent'
 * method, as well as any number of public methods prefixed with the
 * string 'child'.  When the class is instantiated, it will spawn off
 * a thread for each child method and then call the parent method in
 * the main thread.  If the parent method returns, the main thread
 * will join with the child threads until they have all terminated.
 * <p>
 * If any thread throws, the remaining threads are interrupted and the
 * failure is propagated out of the constructor in the main thread.
 * Stopping is cooperative: a thread that never checks its interrupt
 * status (or is blocked in a non-interruptible call) keeps running
 * after the failure has been reported.  Child threads are daemon
 * threads so that such stragglers cannot keep the JVM alive.
 * <p>
 * Because everything runs from within the ForkedTester constructor,
 * instance-field initializers of the subclass have not yet run when
 * 'parent' and the child methods execute; do not rely on them.
 */
public abstract class ForkedTester
{
    /** The thread that invoked the constructor and runs parent(). */
    private final Thread _parent;
    /** One thread per public 'child' method, created before any is started. */
    private final List<Thread> _threads = new ArrayList<Thread>();

    /** Guards _childFailure and _finished. */
    private final Object _lock = new Object();
    /** First exception thrown by a child thread, or null.  Guarded by _lock. */
    private Throwable _childFailure;
    /**
     * Set once the constructor starts winding down; from then on child
     * failures must no longer interrupt the parent thread.  Guarded by _lock.
     */
    private boolean _finished;

    /**
     * Test body run in the constructing thread while the child threads run
     * concurrently.  Blocking calls should be interruptible so that a
     * failing child can cut the test short.
     *
     * @throws Exception any failure; checked exceptions are wrapped in a
     *         WrappedException by the constructor
     */
    public abstract void parent() throws Exception;

    /**
     * Create and spawn threads for each public method whose name
     * begins with 'child'.  Then invoke parent.  When parent returns,
     * join with all of the threads and return.
     * <p>
     * If a child throws, the other threads (including this one) are
     * interrupted and a WrappedException whose cause is the child's
     * exception is thrown from the constructor; an exception thrown by
     * parent in the meantime is attached to it as a suppressed exception.
     * Otherwise, if parent throws an unchecked exception or an error it
     * is rethrown unchanged, and a checked exception is wrapped in a
     * WrappedException.  In every case the child threads are interrupted
     * before the constructor exits.
     *
     * @throws IllegalArgumentException if there are no child methods
     * @throws RuntimeException if a public method of the subclass is
     *         neither 'parent' nor prefixed with 'child', or if joining
     *         the child threads is interrupted from outside
     */
    public ForkedTester()
    {
        _parent = Thread.currentThread();
        createChildThreads();
        if (_threads.isEmpty())
            throw new IllegalArgumentException("No test threads");

        Throwable parentFailure = null;
        try {
            for (Thread thr : _threads)
                thr.start();
            parent();
            joinChildren();
        } catch (Throwable t) {
            parentFailure = t;
        }

        Throwable childFailure = finish();
        if (childFailure != null) {
            // The child failed first (or the parent's failure is most
            // likely the interrupt it caused); report the child's cause.
            WrappedException wrapped = new WrappedException(childFailure);
            if (parentFailure != null)
                wrapped.addSuppressed(parentFailure);
            throw wrapped;
        }
        if (parentFailure instanceof RuntimeException)
            throw (RuntimeException) parentFailure;
        if (parentFailure instanceof Error)
            throw (Error) parentFailure;
        if (parentFailure != null)
            throw new WrappedException(parentFailure);
    }

    /**
     * Build (but do not start) one daemon thread per public 'child'
     * method of this object's class and add it to _threads.
     *
     * @throws RuntimeException if a public ForkedTester method is
     *         neither 'parent' nor prefixed with 'child'
     */
    private void createChildThreads()
    {
        for (Method m : getClass().getMethods()) {
            // Ignore public methods inherited from Object.
            if (!ForkedTester.class.isAssignableFrom(m.getDeclaringClass()))
                continue;
            if (m.getName().equals("parent"))
                continue;
            if (!m.getName().startsWith("child"))
                throw new RuntimeException
                    ("Public method " + m.getName() + " not prefixed with 'child'");
            Thread thr = new Thread(childRunner(m), m.getName());
            // Threads that ignore interrupts must not keep the JVM alive.
            thr.setDaemon(true);
            _threads.add(thr);
        }
    }

    /**
     * Wrap a child method in a Runnable that invokes it on this object
     * and reports anything it throws through childFailed.
     */
    private Runnable childRunner(final Method method)
    {
        return new Runnable() {
            @Override
            public void run()
            {
                try {
                    // The subclass is usually anonymous and therefore not
                    // public, so reflective calls from this package fail
                    // the access check unless it is suppressed.
                    method.setAccessible(true);
                    method.invoke(ForkedTester.this);
                } catch (InvocationTargetException e) {
                    childFailed(e.getCause());
                } catch (Throwable t) {
                    childFailed(t);
                }
            }
        };
    }

    /**
     * Wait for every child thread to terminate.  An interrupt while
     * waiting (from a failing child or from outside) aborts the wait; the
     * interrupt status is restored before throwing.
     *
     * @throws RuntimeException if the wait is interrupted
     */
    private void joinChildren()
    {
        for (Thread thr : _threads) {
            try {
                thr.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Thread join interrupted", e);
            }
        }
    }

    /**
     * Record the first child failure and interrupt every other thread,
     * including the parent, so the test can wind down.  Later failures,
     * and failures reported once the constructor is finishing, are ignored.
     */
    private void childFailed(Throwable cause)
    {
        synchronized (_lock) {
            if (_childFailure != null || _finished)
                return;
            _childFailure = cause;
            for (Thread thr : _threads) {
                if (thr != Thread.currentThread())
                    thr.interrupt();
            }
            _parent.interrupt();
        }
    }

    /**
     * Mark the run as finished, interrupt any child threads still running
     * and return the first child failure, or null if there was none.  When
     * a child failure was recorded, this (parent) thread was interrupted on
     * its behalf, so its interrupt status is cleared to keep that interrupt
     * from leaking into the caller.
     */
    private Throwable finish()
    {
        Throwable childFailure;
        synchronized (_lock) {
            _finished = true;
            childFailure = _childFailure;
        }
        // No effect on threads that have already terminated.
        for (Thread thr : _threads)
            thr.interrupt();
        if (childFailure != null)
            Thread.interrupted();
        return childFailure;
    }

    /**
     * Unchecked wrapper carrying a child's failure, or a checked exception
     * thrown by parent, out of the ForkedTester constructor.
     */
    public static class WrappedException extends RuntimeException
    {
        public WrappedException(Throwable cause)
        {
            super(cause);
        }
    }
}
