package anastore.store;

import anastore.util.ForkedTester;
import anastore.util.IllegalTimestampException;
import anastore.util.Pair;
import anastore.util.TimeRange;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.BufferUnderflowException;
import java.util.*;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

import org.junit.*;
import static org.junit.Assert.*;

public class VersionStoreTest
{
    /**
     * A minimal Storable used as the payload in these tests.  Each block
     * carries an id and a "uniq" value; two blocks are equal only when
     * both match, so tests can tell apart different writes of the same id.
     */
    private static class TestBlock implements Storable
    {
        private final long _id;
        private final long _uniq;
        // Source of uniq values for the one-argument constructor.  The
        // increment is not synchronized, so construct such blocks from a
        // single thread only.
        private static long _uniqCounter = 0;

        /** Creates a block for id with a fresh, counter-assigned uniq. */
        public TestBlock(long id)
        {
            _id = id;
            _uniq = _uniqCounter++;
        }

        /** Creates a block with an explicit uniq (e.g. when decoding). */
        public TestBlock(long id, long uniq)
        {
            _id = id;
            _uniq = uniq;
        }

        public int getDataSize()
        {
            // int magic + long id + long uniq + int magic
            return 24;
        }

        /**
         * Serializes as: int 42, long id, long uniq, int 43, using
         * ByteBuffer's default byte order.
         */
        public byte[] getData()
        {
            ByteBuffer buf = ByteBuffer.allocate(getDataSize());
            buf.putInt(42);
            buf.putLong(_id);
            buf.putLong(_uniq);
            buf.putInt(43);
            return buf.array();
        }

        @Override
        public boolean equals(Object o)
        {
            if (!(o instanceof TestBlock))
                return false;
            TestBlock t2 = (TestBlock)o;
            return _id == t2._id && _uniq == t2._uniq;
        }

        /**
         * Hashes exactly the fields compared by equals, so equal blocks
         * always have equal hash codes (required for use in hash-based
         * collections).
         */
        @Override
        public int hashCode()
        {
            int h = (int)(_id ^ (_id >>> 32));
            return 31 * h + (int)(_uniq ^ (_uniq >>> 32));
        }

        @Override
        public String toString()
        {
            return "TestBlock<id=" + _id + ",uniq=" + _uniq + ">";
        }
    }

    /** Decodes TestBlocks from the byte layout written by TestBlock.getData(). */
    private static class TestBlockFactory implements StorableFactory<TestBlock>
    {
        /**
         * Parses int 42, long id, long uniq, int 43 from data.
         *
         * @throws IllegalArgumentException if a magic number is wrong, or
         *         if data is shorter or longer than that layout
         */
        public TestBlock fromData(byte[] data)
            throws IllegalArgumentException
        {
            final ByteBuffer in = ByteBuffer.wrap(data);
            final long decodedId;
            final long decodedUniq;
            try {
                if (in.getInt() != 42)
                    throw new IllegalArgumentException("Invalid magic number 1");
                decodedId = in.getLong();
                decodedUniq = in.getLong();
                if (in.getInt() != 43)
                    throw new IllegalArgumentException("Invalid magic number 2");
            } catch (BufferUnderflowException underflow) {
                throw new IllegalArgumentException("This data is too small", underflow);
            }
            // Trailing bytes mean this was not a TestBlock encoding
            if (in.hasRemaining())
                throw new IllegalArgumentException("This data is too large");
            return new TestBlock(decodedId, decodedUniq);
        }
    }

    /**
     * Fake miss handler standing in for a remote store.
     *
     * Only ids 0-4 exist.  A timestamp in [10, 100) maps to the ten-wide
     * version [t - t%10 .. t - t%10 + 9], whose block has uniq equal to
     * that start.  Timestamps of 100 or more map to the latest version,
     * built from TimeRange(100) with a block whose uniq is 100.  Every
     * successful lookup sleeps briefly to simulate network latency.
     */
    private static class Backing implements MissHandler<TestBlock>
    {
        /** Simulated per-lookup latency, in milliseconds. */
        private static final long LATENCY_MS = 10;

        public Pair<TimeRange, TestBlock> get(long id, long timestamp)
            throws NoSuchDatum
        {
            if (id >= 5 || timestamp < 10)
                throw new NoSuchDatum(id, timestamp);
            if (timestamp >= 100)
                return getLatest(id);
            long start = timestamp - (timestamp % 10);
            TimeRange r = new TimeRange(start, start + 9);
            TestBlock b = new TestBlock(id, start);
            // Latency matters for the benefit of concurrency tests
            simulateLatency();
            return new Pair<TimeRange, TestBlock>(r, b);
        }

        public Pair<TimeRange, TestBlock> getLatest(long id)
            throws NoSuchDatum
        {
            if (id >= 5)
                throw new NoSuchDatum(id);
            TimeRange r = new TimeRange(100);
            TestBlock b = new TestBlock(id, 100);
            simulateLatency();
            return new Pair<TimeRange, TestBlock>(r, b);
        }

        /**
         * Sleeps for LATENCY_MS.  An interrupt cuts the sleep short but is
         * re-asserted on the thread, so callers further up can still see it.
         */
        private static void simulateLatency()
        {
            try {
                Thread.sleep(LATENCY_MS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Builds a write set for ids startID .. startID+count-1, each mapped to
     * a new TestBlock with a freshly counted uniq value.  Blocks are created
     * in ascending id order.
     */
    private Map<Long, TestBlock> mkWriteSet(long startID, int count)
    {
        Map<Long, TestBlock> writeSet = new HashMap<Long, TestBlock>();
        long id = startID;
        for (int n = 0; n < count; n++, id++)
            writeSet.put(id, new TestBlock(id));
        return writeSet;
    }

    // Temporary directory holding the store's on-disk files for one test
    private File _dir;
    // The store under test; replaced each time load() runs
    private VersionStore<TestBlock> _vs;
    // Fake miss handler; only assigned when the store is loaded backed
    private Backing _backing;

    /** Creates an empty, unbacked store in a fresh temporary directory. */
    private void create()
        throws IOException
    {
        final boolean backed = false;
        create(backed);
    }

    /**
     * Creates a store in a fresh temporary directory.
     *
     * @param backed whether the store should use the fake Backing handler
     */
    private void create(boolean backed)
        throws IOException
    {
        // Reserve a unique name via a temp file, then remove the file so
        // the store can use that name for its own directory.
        File reserved = File.createTempFile("vstoretest-", "");
        _dir = reserved;
        reserved.delete();
        load(backed);
    }

    /** Reopens the store on _dir without a backing miss handler. */
    private void load()
        throws IOException
    {
        final boolean backed = false;
        load(backed);
    }

    /**
     * (Re)opens the store on _dir, closing the disk of any previously
     * opened store first.
     *
     * @param backed whether to attach a fresh Backing miss handler
     */
    private void load(boolean backed)
        throws IOException
    {
        if (_vs != null) {
            _vs.testGetDisk().close();
        }

        if (!backed) {
            _vs = new VersionStore<TestBlock>(_dir, new TestBlockFactory());
            return;
        }

        _backing = new Backing();
        _vs = new VersionStore<TestBlock>(_dir, new TestBlockFactory(),
                                            _backing);
        // Make sure we're allowed to read from what's in the
        // backing cache
        _vs.deprecate(0, 200, Collections.<Long>emptySet());
    }

    /**
     * Closes the store under test and deletes its temporary directory.
     * Tolerates tests that never got as far as creating a store.
     */
    @After
    public void cleanup()
        throws IOException
    {
        try {
            // Release the store's disk so its files can actually be
            // deleted; open handles can block deletion on some platforms.
            if (_vs != null) {
                _vs.testGetDisk().close();
                _vs = null;
            }
        } finally {
            deleteDir();
        }
    }

    /** Best-effort removal of _dir and its direct children. */
    private void deleteDir()
    {
        if (_dir == null || !_dir.exists())
            return;
        File[] files = _dir.listFiles();
        // listFiles() returns null if _dir is not a directory or on I/O error
        if (files != null) {
            for (File f : files)
                f.delete();
        }
        _dir.delete();
    }

    //////////////////////////////////////////////////////////////////
    // Basic in-memory reads and writes
    //

    @Test(expected=NoSuchDatum.class)
    public void readLatestFromEmpty()
        throws IOException
    {
        // A freshly created, unbacked store holds nothing at all
        create();
        final long missingId = 0;
        _vs.getLatest(missingId, null);
    }

    @Test(timeout=500)
    public void readLatestAfterOneWrite()
        throws IOException
    {
        final long blockId = 0;
        create();
        final Map<Long, TestBlock> written = mkWriteSet(blockId, 1);
        _vs.write(0, written, null, null);
        final TestBlock expected = written.get(blockId);
        assertEquals(expected, _vs.getLatest(blockId, null).get());
    }

    @Test(timeout=500)
    public void readLatestAfterTwoWrites()
        throws IOException
    {
        final long blockId = 0;
        create();
        final Map<Long, TestBlock> first = mkWriteSet(blockId, 1);
        _vs.write(0, first, null, null);
        final Map<Long, TestBlock> second = mkWriteSet(blockId, 1);
        _vs.write(1, second, null, null);

        // The newer write is the latest; the older one must not be returned
        final TestBlock newest = _vs.getLatest(blockId, null).get();
        assertEquals(second.get(blockId), newest);
        assertFalse(first.get(blockId).equals(newest));
    }

    @Test
    public void readFromPast()
        throws IOException
    {
        final long blockId = 0;
        create();
        final Map<Long, TestBlock> atZero = mkWriteSet(blockId, 1);
        _vs.write(0, atZero, null, null);
        final Map<Long, TestBlock> atOne = mkWriteSet(blockId, 1);
        _vs.write(1, atOne, null, null);

        // Timestamp 1 still sees the write made at 0...
        final TestBlock older = _vs.get(blockId, 1).get();
        assertEquals(atZero.get(blockId), older);
        assertFalse(atOne.get(blockId).equals(older));

        // ...while timestamp 2 sees the write made at 1
        final TestBlock newer = _vs.get(blockId, 2).get();
        assertEquals(atOne.get(blockId), newer);
        assertFalse(atZero.get(blockId).equals(newer));
    }

    @Test(expected=IllegalTimestampException.class)
    public void readFromFutureFails()
        throws IOException
    {
        final long blockId = 0;
        create();
        final Map<Long, TestBlock> writeSet = mkWriteSet(blockId, 1);
        // The write itself must succeed; only the read below may throw
        try {
            _vs.write(0, writeSet, null, null);
        } catch (IllegalTimestampException unexpected) {
            fail();
        }
        // After a single write at 0, timestamp 2 is in the future
        _vs.get(blockId, 2);
    }

    @Test(expected=NoSuchDatum.class)
    public void readFromBeforeWriteFails()
        throws IOException
    {
        final long blockId = 0;
        create();
        _vs.write(0, mkWriteSet(blockId, 1), null, null);
        // The write at 0 must not be visible at timestamp 0 itself
        _vs.get(blockId, 0);
    }

    @Test
    public void readMemoizes()
        throws IOException
    {
        create();
        _vs.write(0, mkWriteSet(0, 1), null, null);

        // Two reads of the same version must yield the very same object,
        // even with a garbage collection in between.
        final TestBlock before = _vs.get(0, 1).get();
        System.gc();
        final TestBlock after = _vs.get(0, 1).get();
        assertSame(before, after);
    }

    /**
     * Reading an id that was never written must fail even at a timestamp
     * where other writes are visible.
     */
    @Test(expected=NoSuchDatum.class)
    public void readNonexistentFails()
        throws IOException
    {
        create();
        Map<Long, TestBlock> wset1 = mkWriteSet(0, 1);
        _vs.write(0, wset1, null, null);
        // Read at timestamp 1, where the write of id 0 is visible, so the
        // failure can only come from id 1 being absent.  Timestamp 0 would
        // fail for every id (that case is readFromBeforeWriteFails).
        _vs.get(1, 1);
    }

    //////////////////////////////////////////////////////////////////
    // Write and deprecate ordering
    //

    @Test(timeout=500, expected=IllegalTimestampException.class)
    public void writeInPastFails()
        throws IOException
    {
        create();
        final Map<Long, TestBlock> writeSet = mkWriteSet(0, 1);
        // The first write at 0 is legal...
        try {
            _vs.write(0, writeSet, null, null);
        } catch (IllegalTimestampException unexpected) {
            fail();
        }
        // ...repeating it at the now-past timestamp 0 is not
        _vs.write(0, writeSet, null, null);
    }

    @Test(timeout=500, expected=IllegalStateException.class)
    public void writeSurpassedFails()
        throws IOException
    {
        create();
        final Map<Long, TestBlock> writeSet = mkWriteSet(0, 1);
        new ForkedTester()
        {
            // Write at 1 while nothing has been written at 0 yet
            public void parent()
                throws IOException
            {
                _vs.write(1, writeSet, null, null);
            }

            // Once that write is waiting, deprecate past it
            public void child()
                throws InterruptedException
            {
                Thread.sleep(100);  // give parent()'s write time to block
                _vs.deprecate(0, 5, Collections.<Long>emptySet());
            }
        };
    }

    /**
     * A write at timestamp 1 issued before the write at timestamp 0 must
     * wait for it, so the timestamp-1 data ends up as the latest version.
     */
    @Test(timeout=500)
    public void writeWriteOrdering()
        throws IOException
    {
        long id = 0;
        create();
        final Map<Long, TestBlock> wset1 = mkWriteSet(id, 1);
        final Map<Long, TestBlock> wset2 = mkWriteSet(id, 1);
        new ForkedTester()
        {
            public void parent()
                throws IOException
            {
                _vs.write(1, wset2, null, null);
            }

            public void child()
                throws IOException, InterruptedException
            {
                // Wait for the parent's write at 1 to block
                Thread.sleep(100);
                _vs.write(0, wset1, null, null);
            }
        };
        TestBlock latest = _vs.getLatest(id, null).get();
        assertEquals(wset2.get(id), latest);
        assertFalse(wset1.get(id).equals(latest));

        // Both writes landed, so the upper bound has advanced past them
        assertEquals(2, _vs.getLeastUpperBound());
    }

    @Test(timeout=500, expected=IllegalTimestampException.class)
    public void deprecateInPastFails()
        throws IOException
    {
        create();
        final Map<Long, TestBlock> writeSet = mkWriteSet(0, 1);
        try {
            _vs.write(0, writeSet, null, null);
        } catch (IllegalTimestampException unexpected) {
            fail();
        }
        // Timestamp 0 has already been consumed by the write above
        _vs.deprecate(0, 1, writeSet.keySet());
    }

    //////////////////////////////////////////////////////////////////
    // Deprecation notification
    //

    // XXX: no deprecation-notification tests have been written yet.

    //////////////////////////////////////////////////////////////////
    // Backed cache
    //

    @Test(expected=NoSuchDatum.class)
    public void backedGetTooEarlyFails()
        throws IOException
    {
        create(true);
        // The backing handler has nothing before timestamp 10
        final long tooEarly = 5;
        _vs.get(0, tooEarly);
    }

    @Test(expected=NoSuchDatum.class)
    public void backedGetMissingIdFails()
        throws IOException
    {
        create(true);
        // The backing handler only knows ids 0-4
        final long unknownId = 10;
        _vs.get(unknownId, 40);
    }

    @Test(expected=NoSuchDatum.class)
    public void backedGetLatestMissingIdFails()
        throws IOException
    {
        create(true);
        // The backing handler only knows ids 0-4
        final long unknownId = 10;
        _vs.getLatest(unknownId, null);
    }

    @Test
    public void backedGetLatest()
        throws IOException
    {
        create(true);
        final Version<TestBlock> fetched = _vs.getLatest(0, null);
        final Pair<TimeRange, TestBlock> expected = _backing.getLatest(0);
        checkVersions(expected, fetched);
    }

    @Test
    public void backedGetOldExact()
        throws IOException
    {
        create(true);
        // Timestamp 10 is exactly the start of a backing version
        final Version<TestBlock> fetched = _vs.get(0, 10);
        final Pair<TimeRange, TestBlock> expected = _backing.get(0, 10);
        checkVersions(expected, fetched);
    }

    @Test
    public void backedGetOldInexact()
        throws IOException
    {
        create(true);
        // Timestamp 15 falls inside the backing's [10..19] version
        final Version<TestBlock> fetched = _vs.get(0, 15);
        final Pair<TimeRange, TestBlock> expected = _backing.get(0, 10);
        checkVersions(expected, fetched);
    }

    @Test
    public void backedGetSameVersionDifferentTimestamps()
        throws IOException
    {
        create(true);
        // 15 and 18 both fall in the backing's [10..19] version...
        final Version<TestBlock> at15 = _vs.get(0, 15);
        checkVersions(_backing.get(0, 10), at15);
        final Version<TestBlock> at18 = _vs.get(0, 18);
        checkVersions(_backing.get(0, 10), at18);
        assertSame(at15, at18);

        // ...while 20 starts the next one
        final Version<TestBlock> at20 = _vs.get(0, 20);
        assertNotSame(at15, at20);
    }

    @Test
    public void backedConcurrentGets()
        throws IOException
    {
        create(true);
        final CyclicBarrier lockstep = new CyclicBarrier(2);
        // Each map is written by only one of the two forked sides
        final Map<Integer, Version<TestBlock>> fromParent =
            new HashMap<Integer, Version<TestBlock>>();
        final Map<Integer, Version<TestBlock>> fromChild =
            new HashMap<Integer, Version<TestBlock>>();

        new ForkedTester()
        {
            public void parent()
                throws InterruptedException, BrokenBarrierException, IOException
            {
                for (int ts = 10; ts < 100; ts += 10) {
                    lockstep.await();
                    fromParent.put(ts, _vs.get(0, ts));
                }
            }

            public void child()
                throws InterruptedException, BrokenBarrierException, IOException
            {
                // Same versions, requested through a different timestamp
                for (int ts = 10; ts < 100; ts += 10) {
                    lockstep.await();
                    fromChild.put(ts, _vs.get(0, ts + 1));
                }
            }
        };

        // Both sides must end up sharing one Version object per range
        for (int ts = 10; ts < 100; ts += 10) {
            final Version<TestBlock> canonical = _vs.get(0, ts);
            assertSame(canonical, fromParent.get(ts));
            assertSame(canonical, fromChild.get(ts));
        }
    }

    /**
     * Asserts that actual matches expected.  The range must agree on the
     * lower bound, on whether an upper bound exists, and on the upper
     * bound itself when one exists.  The data must also be equal.
     */
    private static void checkVersions(Pair<TimeRange, TestBlock> expected,
                                      Version<TestBlock> actual)
        throws IOException
    {
        final TimeRange want = expected.first;
        final TimeRange got = actual.getRange();
        assertEquals(want.getLowerBound(), got.getLowerBound());
        final boolean bounded = want.hasUpperBound();
        assertEquals(bounded, got.hasUpperBound());
        if (bounded) {
            assertEquals(want.getUpperBound(), got.getUpperBound());
        }
        assertEquals(expected.second, actual.get());
    }

    //////////////////////////////////////////////////////////////////
    // Restoring from disk
    //

    @Test
    public void reloadFlushed()
        throws IOException
    {
        create();
        final Map<Long, TestBlock> first = mkWriteSet(0, 2);
        final Map<Long, TestBlock> second = mkWriteSet(0, 2);
        _vs.write(0, first, null, null);
        _vs.write(1, second, null, null);

        for (long id : first.keySet()) {
            final Version<TestBlock> older = _vs.get(id, 1);
            final Version<TestBlock> newer = _vs.get(id, 2);
            assertNotSame(older, newer);

            // Pass 0 reads from memory; pass 1 re-reads after the
            // in-memory data has been dropped.
            for (int pass = 0; pass < 2; pass++) {
                assertEquals(first.get(id), older.get());
                assertEquals(second.get(id), newer.get());
                if (pass == 0) {
                    older.testForgetData();
                    newer.testForgetData();
                }
            }
        }
    }

    @Test
    public void unboundedSummary()
        throws IOException
    {
        final boolean deprecate = false;
        final boolean reload = false;
        unboundedSummaryHelper(deprecate, reload);
    }

    @Test
    public void unboundedSummaryWithDeprecations()
        throws IOException
    {
        final boolean deprecate = true;
        final boolean reload = false;
        unboundedSummaryHelper(deprecate, reload);
    }

    @Test
    public void loadFromDiskBasic()
        throws IOException
    {
        create();
        final Map<Long, TestBlock> written = mkWriteSet(0, 2);
        _vs.write(0, written, null, null);

        // Reopen from disk; every written block must read back at
        // timestamp 1 with a range matching TimeRange(1).
        load();
        final TimeRange fromOne = new TimeRange(1);
        for (Map.Entry<Long, TestBlock> entry : written.entrySet()) {
            final Version<TestBlock> reread = _vs.get(entry.getKey(), 1);
            checkVersions(new Pair<TimeRange, TestBlock>(fromOne, entry.getValue()),
                          reread);
        }
    }

    @Test
    public void loadFromDisk()
        throws IOException
    {
        final boolean deprecate = false;
        final boolean reload = true;
        unboundedSummaryHelper(deprecate, reload);
    }

    @Test
    public void loadFromDiskWithDeprecations()
        throws IOException
    {
        final boolean deprecate = true;
        final boolean reload = true;
        unboundedSummaryHelper(deprecate, reload);
    }

    /**
     * Writes ids {0,1,2} at timestamp 0 and ids {1,2,3} at timestamp 1.
     * Optionally deprecates id 2 at timestamp 2, and optionally reloads
     * the store from disk.  It then checks the least upper bound and the
     * unbounded summary: a map from id to the lower-bound timestamp of
     * that id's newest version.  When reloading, it also checks that each
     * summarized block reads back with the expected range and data.
     *
     * @param dep      whether to deprecate id 2 before checking
     * @param loadTest whether to reload the store from disk first
     */
    private void unboundedSummaryHelper(boolean dep, boolean loadTest)
        throws IOException
    {
        create();
        Map<Long, TestBlock> set1 = mkWriteSet(0, 3);
        Map<Long, TestBlock> set2 = mkWriteSet(1, 3);
        _vs.write(0, set1, null, null);
        _vs.write(1, set2, null, null);
        if (dep)
            _vs.deprecate(2, 2, Collections.<Long>singleton((long)2));

        if (loadTest) {
            load();
            // We won't observe the deprecation because we don't flush
            // upper bound changes
            dep = false;
        }

        // Check the summary
        assertEquals(dep ? 3 : 2, _vs.getLeastUpperBound());
        Map<Long, Long> summary = _vs.getUnboundedSummary();
        assertEquals(dep ? 3 : 4, summary.size());
        // Expected values are boxed Longs on purpose.  assertEquals(int, Long)
        // is ambiguous between the (long,long) and (Object,Object)
        // overloads, and where it resolves to (Object,Object) an Integer
        // never equals a Long.  Comparing boxed values also reports a
        // missing id as an assertion failure instead of an NPE.
        assertEquals(Long.valueOf(1), summary.get((long)0));
        assertEquals(Long.valueOf(2), summary.get((long)1));
        if (!dep)
            assertEquals(Long.valueOf(2), summary.get((long)2));
        assertEquals(Long.valueOf(2), summary.get((long)3));

        if (loadTest) {
            // Check that it loaded the blocks correctly.  Note that
            // tr1 is [1..?], not [1..1] because we don't flush upper
            // bound changes.
            TimeRange tr1 = new TimeRange(1);
            TimeRange tr2 = new TimeRange(2);
            for (Map.Entry<Long, Long> entry : summary.entrySet()) {
                long id = entry.getKey();
                long timestamp = entry.getValue();
                Version<TestBlock> v = _vs.get(id, timestamp);
                if (timestamp == 1)
                    checkVersions(new Pair<TimeRange, TestBlock>(tr1, set1.get(id)),
                                  v);
                else if (timestamp == 2)
                    checkVersions(new Pair<TimeRange, TestBlock>(tr2, set2.get(id)),
                                  v);
                else
                    fail("Unexpected timestamp " + timestamp);
            }
        }
    }

    //////////////////////////////////////////////////////////////////
    // Restoring cache (sparse) data from disk
    //

    @Test
    public void cachingONB() throws IOException
    {
        // Old (ts 10), then newer (ts 30), then the one between (ts 20)
        final char[] order = {'O', 'N', 'B'};
        checkCaching(order);
    }

    @Test
    public void cachingOBN() throws IOException
    {
        // Old (ts 10), then between (ts 20), then newer (ts 30)
        final char[] order = {'O', 'B', 'N'};
        checkCaching(order);
    }

    @Test
    public void cachingOLB() throws IOException
    {
        // Old (ts 10), then latest, then between (ts 20)
        final char[] order = {'O', 'L', 'B'};
        checkCaching(order);
    }

    @Test
    public void cachingNOB() throws IOException
    {
        // Newer (ts 30), then old (ts 10), then between (ts 20)
        final char[] order = {'N', 'O', 'B'};
        checkCaching(order);
    }

    @Test
    public void cachingNBO() throws IOException
    {
        // Newer (ts 30), then between (ts 20), then old (ts 10)
        final char[] order = {'N', 'B', 'O'};
        checkCaching(order);
    }

    @Test
    public void cachingLOB() throws IOException
    {
        // Latest, then old (ts 10), then between (ts 20)
        final char[] order = {'L', 'O', 'B'};
        checkCaching(order);
    }

    @Test
    public void cachingLBO() throws IOException
    {
        // Latest, then between (ts 20), then old (ts 10)
        final char[] order = {'L', 'B', 'O'};
        checkCaching(order);
    }

    /**
     * A version fetched from the store, paired with the range and data it
     * is expected to have.
     */
    private static class GotVersion
    {
        public final Pair<TimeRange, TestBlock> _expected;
        public final Version<TestBlock> _actual;

        public GotVersion(Pair<TimeRange, TestBlock> expected,
                          Version<TestBlock> actual)
        {
            this._expected = expected;
            this._actual = actual;
        }

        /** Asserts (via checkVersions) that the fetched version still matches. */
        public void check()
            throws IOException
        {
            VersionStoreTest.checkVersions(this._expected, this._actual);
        }
    }

    /**
     * Fetches id 0 from a backed store at the versions named by vnames.
     * The names are 'O' = timestamp 10, 'B' = 20, 'N' = 30, 'L' = latest.
     *
     * Each fetched version is verified first from memory and again after
     * its in-memory data is forgotten.  The store is then reopened without
     * the backing handler, and the method checks that only the youngest
     * fetched version remains readable.
     *
     * @param vnames the versions to fetch, in fetch order
     * @throws IllegalArgumentException on an unknown version name
     */
    private void checkCaching(char... vnames)
        throws IOException
    {
        boolean gotLatest = false;
        long youngestVer = 0;
        TestBlock youngest = null;
        List<GotVersion> versions = new ArrayList<GotVersion>();
        create(true);

        // Get the versions in the order requested
        for (char vname : vnames) {
            Version<TestBlock> v;
            Pair<TimeRange, TestBlock> expect;
            switch (vname) {
            case 'O':
                v = _vs.get(0, 10);
                expect = _backing.get(0, 10);
                break;
            case 'B':
                v = _vs.get(0, 20);
                expect = _backing.get(0, 20);
                break;
            case 'N':
                v = _vs.get(0, 30);
                expect = _backing.get(0, 30);
                break;
            case 'L':
                v = _vs.getLatest(0, null);
                expect = _backing.getLatest(0);
                gotLatest = true;
                break;
            default:
                throw new IllegalArgumentException
                    ("Unknown version name " + vname);
            }
            GotVersion gv = new GotVersion(expect, v);
            gv.check();
            versions.add(gv);
            // Track the fetched version with the highest lower bound
            if (youngestVer < expect.first.getLowerBound()) {
                youngestVer = expect.first.getLowerBound();
                youngest = expect.second;
            }
        }

        // Check that the versions are still right
        for (GotVersion gv : versions)
            gv.check();

        // Clear the in-memory data
        for (GotVersion gv : versions)
            gv._actual.testForgetData();

        // Check that all of the versions reload from disk correctly
        for (GotVersion gv : versions)
            gv.check();

        // Reload the store, this time without the backing cache
        load();

        // Check the summary
        Map<Long, Long> summary = _vs.getUnboundedSummary();
        if (gotLatest) {
            assertEquals(1, summary.size());
            assertTrue(summary.containsKey((long)0));
            // Compare as Long objects: assertEquals(int, Long) is ambiguous
            // between the (long,long) and (Object,Object) overloads, and an
            // Integer never equals a Long under (Object,Object).
            assertEquals(Long.valueOf(100), summary.get((long)0));
        } else {
            assertEquals(0, summary.size());
        }

        // Check that the youngest block we read is present
        Version<TestBlock> v;
        if (gotLatest) {
            v = _vs.getLatest(0, null);
        } else {
            v = _vs.get(0, youngestVer);
        }
        assertEquals(youngestVer, v.getRange().getLowerBound());
        assertEquals(youngest, v.get());

        // Check that other versions are not present
        for (long ts = 10; ts <= 30; ++ts) {
            if (ts == youngestVer)
                continue;
            try {
                _vs.get(0, ts);
                fail("Version " + ts + " should not be present");
            } catch (NoSuchDatum e) {
                // Good
            }
        }
    }

    //////////////////////////////////////////////////////////////////
    // Expiration
    //

    // XXX: no expiration tests have been written yet.
}
