package anastore.util;

import java.io.*;

import org.junit.*;
import static org.junit.Assert.*;

/**
 * Round-trip tests for {@link FastObjectOutputStream} and
 * {@link FastObjectInputStream}.
 *
 * <p>Each test writes one or more objects into an in-memory buffer,
 * reads them back and checks the values survive the round trip. Each test
 * also checks that the writer and reader report equal expectations. Some
 * tests additionally inspect the raw encoded bytes to check what was (or
 * was not) written.
 */
public class FastObjectStreamTest
{
    //////////////////////////////////////////////////////////////////
    // equals/hashCode helpers for the value classes below
    //

    /** Null-safe equality: two nulls are equal; null never equals non-null. */
    private static boolean nullSafeEquals(Object a, Object b)
    {
        return (a == null) ? (b == null) : a.equals(b);
    }

    /** Null-safe hash code: 0 for null. */
    private static int nullSafeHash(Object x)
    {
        return (x == null) ? 0 : x.hashCode();
    }

    /** Same value as {@code Long.valueOf(l).hashCode()}. */
    private static int longHash(long l)
    {
        return (int)(l ^ (l >>> 32));
    }

    /** Simple value type with one primitive and one boxed field. */
    private static class Thing implements Serializable
    {
        long _l;
        Boolean _b;

        public Thing(long l, Boolean b)
        {
            _l = l;
            _b = b;
        }

        @Override
        public boolean equals(Object o)
        {
            if (!(o instanceof Thing))
                return false;
            Thing t2 = (Thing)o;
            // Null-safe so a Thing with a null _b compares without throwing
            return (t2._l == _l && nullSafeEquals(t2._b, _b));
        }

        // equals() is overridden, so hashCode() must agree with it
        @Override
        public int hashCode()
        {
            return 31 * longHash(_l) + nullSafeHash(_b);
        }
    }

    private ByteArrayOutputStream bo;
    private FastObjectOutputStream o;

    private ByteArrayInputStream bi;
    private FastObjectInputStream i;

    /** Creates a fresh in-memory buffer and writer for each test. */
    @Before
    public void setUp()
        throws IOException
    {
        bo = new ByteArrayOutputStream();
        o = new FastObjectOutputStream(bo);
    }

    /**
     * Flushes the writer and opens a reader over everything written so
     * far.
     */
    private void startRead()
        throws IOException
    {
        o.flush();
        // NOTE(review): presumably dumps the encoded bytes for debugging
        HexDump.dump(bo);
        bi = new ByteArrayInputStream(bo.toByteArray());
        i = new FastObjectInputStream(bi);
    }

    /** Asserts that writer and reader ended with equal expectations. */
    private void assertSameExpectations()
    {
        assertEquals(o.getExpectations(), i.getExpectations());
    }

    /**
     * Counts the (possibly overlapping) occurrences of substring in s.
     *
     * @throws IllegalArgumentException if substring is empty, which would
     *         otherwise loop forever
     */
    private int count(String s, String substring)
    {
        if (substring.length() == 0)
            throw new IllegalArgumentException("substring must be non-empty");
        int c = 0;
        int index = -1;
        while ((index = s.indexOf(substring, index+1)) != -1)
            ++c;
        return c;
    }

    //////////////////////////////////////////////////////////////////
    // Tests
    //

    /** Writes and reads back a single object; no ".String" in the bytes. */
    @Test
    public void basicWriteRead()
        throws IOException, ClassNotFoundException
    {
        Thing t = new Thing(42, true);
        o.writeObject(t);

        startRead();
        assertEquals(t, i.readObject());

        assertSameExpectations();

        // Check encoding
        assertFalse(bo.toString("ascii").contains(".String"));
    }

    /** Writes two objects of the same class and reads them back in order. */
    @Test
    public void multiWriteRead()
        throws IOException, ClassNotFoundException
    {
        Thing t1 = new Thing(42, true);
        Thing t2 = new Thing(21, false);
        o.writeObject(t1);
        o.writeObject(t2);

        startRead();
        assertEquals(t1, i.readObject());
        assertEquals(t2, i.readObject());

        assertSameExpectations();
    }

    /**
     * Writes the same object twice with a reset() in between; the reader
     * must produce two equal but distinct instances.
     */
    @Test
    public void reset()
        throws IOException, ClassNotFoundException
    {
        Thing t = new Thing(42, true);
        o.writeObject(t);
        o.reset();
        o.writeObject(t);

        startRead();
        Thing tr1 = (Thing)i.readObject();
        assertEquals(t, tr1);
        Thing tr2 = (Thing)i.readObject();
        assertEquals(t, tr2);
        assertNotSame(tr1, tr2);

        assertSameExpectations();
    }

    /** Singly-linked list cell whose class refers to itself. */
    private static class Cons implements Serializable
    {
        long _car;
        Cons _cdr;

        public Cons(long car, Cons cdr)
        {
            _car = car;
            _cdr = cdr;
        }

        @Override
        public boolean equals(Object o)
        {
            if (!(o instanceof Cons))
                return false;
            Cons c2 = (Cons)o;
            // Null-safe on _cdr: a null tail must compare unequal to a
            // non-null tail rather than throwing
            return (_car == c2._car && nullSafeEquals(_cdr, c2._cdr));
        }

        // equals() is overridden, so hashCode() must agree with it
        @Override
        public int hashCode()
        {
            return 31 * longHash(_car) + nullSafeHash(_cdr);
        }
    }

    /** Round-trips a two-element list of a self-referential class. */
    @Test
    public void recursiveClass()
        throws IOException, ClassNotFoundException
    {
        Cons c2 = new Cons(2, null);
        Cons c1 = new Cons(1, c2);
        o.writeObject(c1);

        startRead();
        assertEquals(c1, i.readObject());

        assertSameExpectations();
    }

    /** Holder with a Thing field followed by another field. */
    private static class Skip implements Serializable
    {
        Thing _t;
        // We need an expected field to force the skip record out
        Integer _z;

        public Skip(Thing t, Integer z)
        {
            _t = t;
            _z = z;
        }

        @Override
        public boolean equals(Object o)
        {
            if (!(o instanceof Skip))
                return false;
            Skip s2 = (Skip)o;
            return (nullSafeEquals(_t, s2._t) && nullSafeEquals(_z, s2._z));
        }

        // equals() is overridden, so hashCode() must agree with it
        @Override
        public int hashCode()
        {
            return 31 * nullSafeHash(_t) + nullSafeHash(_z);
        }
    }

    /**
     * Writes a Thing, then an object referencing that same Thing; the
     * "$Thing" class name must appear only once in the encoding.
     */
    @Test
    public void writeEmbeddedBackRef()
        throws IOException, ClassNotFoundException
    {
        // This is a tad convoluted.  First, we write out something of
        // type A, then we write out something that has a field of
        // type A.  This should produce a skip record for that field,
        // as long as there's another field following the A field.
        Thing t = new Thing(42, true);
        o.writeObject(t);
        Skip s = new Skip(t, 21);
        o.writeObject(s);

        startRead();
        assertEquals(t, i.readObject());
        assertEquals(s, i.readObject());

        assertSameExpectations();

        // Check encoding
        String data = bo.toString("ascii");
        assertEquals(1, count(data, "$Thing"));
    }

    /** Holder whose field's static type is Foo. */
    private static class ExpectFoo implements Serializable
    {
        Foo _f;
    }

    private static class Foo implements Serializable
    {
    }

    private static class Bar extends Foo
    {
    }

    private static class Baz extends Bar
    {
    }

    /**
     * A Foo-typed field holding a Baz must come back as a Baz, with only
     * "$Baz" (not "$Bar" or "$Foo") in the encoding.
     */
    @Test
    public void writeDynamicSubclass()
        throws IOException, ClassNotFoundException
    {
        // Test writing a field whose dynamic value is a subclass of
        // the field's static type
        ExpectFoo f = new ExpectFoo();
        f._f = new Baz();
        o.writeObject(f);

        startRead();
        ExpectFoo r = (ExpectFoo)i.readObject();
        assertTrue(r._f instanceof Baz);

        assertSameExpectations();

        // Check encoding
        String data = bo.toString("ascii");
        assertEquals(1, count(data, "$Baz"));
        assertFalse(data.contains("$Bar"));
        assertFalse(data.contains("$Foo"));
    }

    /** Holder with a String field followed by an Integer field. */
    private static class Stringy implements Serializable
    {
        String _a;
        Integer _b;
    }

    /**
     * Round-trips String and Integer fields; the encoding must contain
     * neither ".String" nor any CODE_SKIP byte.
     */
    @Test
    public void writeStringFields()
        throws IOException, ClassNotFoundException
    {
        Stringy s = new Stringy();
        s._a = "Foo";
        s._b = 42;
        o.writeObject(s);

        startRead();
        Stringy r = (Stringy)i.readObject();
        assertEquals("Foo", r._a);
        // Box explicitly: assertEquals(int, Integer) is ambiguous between
        // assertEquals(long, long) and assertEquals(Object, Object)
        assertEquals(Integer.valueOf(42), r._b);

        assertSameExpectations();

        // Check encoding
        String data = bo.toString("ascii");
        assertFalse(data.contains(".String"));
        // This is a little sketchy, since bytes equal to CODE_SKIP
        // could appear for other reasons, but it works with the
        // encoding at the time this was written.
        for (byte b : bo.toByteArray())
            assertFalse(b == StreamExpectations.CODE_SKIP);
    }
}
