#ifndef BPLUSTREE_H
#define BPLUSTREE_H

#include "persifs.h"
#include "marshal.h"
#include "happyio.h"

#include <vec.h>

/*
 * branchT (sometimes called T in the comments) is the minimum
 * branching factor of the tree. We follow the CLRS definition (*not*
 * the Knuth definition): every internal non-root node contains
 * between T-1 and 2T-1 routing elements (inclusive), and hence
 * between T and 2T children.
 */

/**
 * Internal representation:
 *
 * Header:
 *  magic (BPLUSTREE_MAGIC_NUMBER)           4 bytes
 *  branchT                                  4 bytes
 *  root (offset)                            4 bytes
 *
 * Internal node:
 *  magic (BPLUS_INTERNAL_MAGIC_NUMBER)      4 bytes
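 *  n = # of elements (0 <= n <= 2*T - 1)    4 bytes
 *  marshalled keys (n strings)              variable  * n
 *  n+1 children (offsets)                   4 bytes   * n+1
 *  zero-padding
 *    Every internal node is padded to the same fixed length so that it
 *    can be rewritten in place: 2L + 4TL + 2T(M + L) bytes, where
 *    L = sizeof(unsigned long) and M = the key marshaller's maxLength().
 *    With 4-byte longs this is 8 + 16T + 2TK bytes, where K = 4 + M.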
 *
 * Leaf node:
 *  magic (BPLUS_LEAF_MAGIC_NUMBER)          4 bytes
 *  marshalled key (string)                  variable
 *  marshalled value (string)                variable
 */

// Set to 1 (e.g. with -DBPT_DEBUG=1 on the compiler command line) to
// trace tree loading, node writes and searches via warn.  Guarded so a
// build-time definition is honored instead of being silently redefined.
#ifndef BPT_DEBUG
#define BPT_DEBUG 0
#endif

// Magic numbers tagging the file header and each kind of on-disk node;
// a mismatch on read is treated as a corrupt or foreign file (fatal).
const unsigned long BPLUSTREE_MAGIC_NUMBER = 0xc3427dda;
const unsigned long BPLUS_INTERNAL_MAGIC_NUMBER = 0x8c936b58;
const unsigned long BPLUS_LEAF_MAGIC_NUMBER = 0x21318ba5;

struct BPlusInternalNode
{
  unsigned long numElements;
  str *keys;
  unsigned long *children;
  unsigned long branchT;
  unsigned long maxKeyLength;
  unsigned long pos;

  static ref<BPlusInternalNode>
  create(unsigned long branchT, unsigned long maxKeyLength,
                unsigned long pos) {
    ref <BPlusInternalNode> n = New refcounted<BPlusInternalNode>();
    n->branchT = branchT;
    n->maxKeyLength = maxKeyLength;
    n->pos = pos;
    n->numElements = 0;
    n->keys = New str[2*branchT-1];
    n->children = New unsigned long [2*branchT];
//    warn << "creating\n";
    
    return n;
  }
  
  ~BPlusInternalNode() {
//    warn << "destructing\n";
    delete [] keys;
    delete [] children;
  }
  
  static ref<BPlusInternalNode>
  load(streamReader & r, unsigned long branchT,
              unsigned long maxKeyLength, unsigned long pos) {
    ref <BPlusInternalNode> n = New refcounted<BPlusInternalNode>();
    n->branchT = branchT;
    n->maxKeyLength = maxKeyLength;
    n->pos = pos;
//    warn << "creating\n";

    if (r.readLong() != BPLUS_INTERNAL_MAGIC_NUMBER) {
      fatal << "BPlusInternalNode read: bad magic number\n";
    }

    n->numElements = r.readLong();
    if (r.hasError() || r.eos()) {
      fatal << "BPlusInternalNode read: unexpected eos or error\n";
    }
    if (n->numElements > 2*branchT-1) {
      // Should also be >= branchT - 1, but the root can violate
      // this.
      fatal << "BPlusInternalNode read: invalid number of elements "
            << n->numElements << "\n";
    }

    n->keys = New str[2*branchT-1];
    for (unsigned long i = 0; i < n->numElements; i++) {
      n->keys[i] = r.readStr();
      if (r.hasError() || r.eos()) {
        fatal << "BPlusInternalNode read: unexpected eos or error\n";
      }
    }

    n->children = New unsigned long [2*branchT];
    for (unsigned long i = 0; i < n->numElements + 1; i++) {
      n->children[i] = r.readLong();
      if (r.hasError() || r.eos()) {
        fatal << "BPlusInternalNode read: unexpected eos or error\n";
      }
    }
    
    return n; 
  }

  str marshal() {
    strWriter s;
    s.writeLong(BPLUS_INTERNAL_MAGIC_NUMBER);
    s.writeLong(numElements);
    for (unsigned long i = 0; i < numElements; i++) {
      s.writeStr(keys[i]);
    }
    for (unsigned long i = 0; i < numElements+1; i++) {
      s.writeLong(children[i]);
    }
    if (s.hasError()) {
      fatal << "BPlusInternalNode: marshalling error\n";
    }

    // Zero pad
    unsigned long padLength = 2 * (sizeof(unsigned long))
      + 4 * branchT * sizeof(unsigned long)
      + 2 * branchT * (maxKeyLength + sizeof(unsigned long))
      - s.getStr().len();
    for(unsigned long i = 0; i < padLength; i++) {
      s.writeByte(0);
    }
    
    return s.getStr();
  }

//private:
  BPlusInternalNode() {
  }
  
};

/**
 * In-memory image of one leaf record: a single marshalled key and its
 * marshalled value.  Leaves are not padded; each is written once at
 * the file offset held in pos.
 */
struct BPlusLeafNode
{
  str key;            // marshalled key
  str value;          // marshalled value
  unsigned long pos;  // file offset of this leaf

  /**
   * Deserialize a leaf from r, which must sit at the leaf's magic
   * number.  The tree-shape parameters are accepted only for symmetry
   * with BPlusInternalNode and are unused.  Any mismatch or read
   * error is fatal.
   */
  BPlusLeafNode(streamReader & r, unsigned long /* branchT */,
                unsigned long /* maxKeyLength */, unsigned long pos)
    : pos(pos) {
    const unsigned long magic = r.readLong();
    if (magic != BPLUS_LEAF_MAGIC_NUMBER)
      fatal << "BPlusLeafNode read: bad magic number\n";

    key = r.readStr();
    value = r.readStr();

    const bool readFailed = r.hasError() || r.eos();
    if (readFailed)
      fatal << "BPlusLeafNode read: unexpected eos or error\n";
  }

  /**
   * Build an empty leaf that will be stored at pos; the caller fills
   * in key and value.  The tree-shape parameters are unused.
   */
  BPlusLeafNode(unsigned long /* branchT */,
                unsigned long /* maxKeyLength */, unsigned long pos)
    : pos(pos) {
  }

  /** Serialize as magic, key, value; fatal on a marshalling error. */
  str marshal() {
    strWriter out;
    out.writeLong(BPLUS_LEAF_MAGIC_NUMBER);
    out.writeStr(key);
    out.writeStr(value);
    const bool writeFailed = out.hasError();
    if (writeFailed)
      fatal << "BPlusLeafNode: marshalling error\n";
    return out.getStr();
  }
};



template<class K, class V,
         class MK = marshaller<K>, class MV = marshaller<V>,
         class UMK = unmarshaller<K>, class UMV = unmarshaller<V> >
class BPlusTree
{
  const MK marshK;
  const MV marshV;
  const UMK unmarshK;
  const UMV unmarshV;
  
public:
  static void create(str filename, unsigned long branchT,
                     callback<void,
                     ref<BPlusTree<K,V,MK,MV,UMK,UMV> > >::ref cb,
                     cbe error)  {
    bool initialize;

    if (access(filename.cstr(), F_OK) == 0) {
      // File exists
      initialize = false;
    } else {
      initialize = true;
    }
  
    FILE *f = fopen(filename.cstr(),
                    initialize ? "w+b" : "r+b");
    if (!f) {
      warn << "BPlusTree: failed to open " << filename << "\n";
      error(PERR_IO);
      return;
    }
    
    if (initialize) {
      fileWriter w(f);
      w.writeLong(BPLUSTREE_MAGIC_NUMBER);
      w.writeLong(branchT);
      w.writeLong(ftell(f) + sizeof(unsigned long));

      // Create a new root
      // Need to have a marshaller in order to get maxLength
      MK mK;
      ref<BPlusInternalNode> root =
        BPlusInternalNode::create(branchT, mK.maxLength(), ftell(f));
      root->numElements = 0;
      root->children[0] = 0;
      w.writeStrWithoutLen(root->marshal());
      if (w.hasError() || fflush(f) || fseek(f, 0, SEEK_SET)) {
        warn << "BPlusTree: failed to initialize new tree file\n";
        error(PERR_IO);
        return;
      }
    }

    cb(New refcounted<BPlusTree<K,V,MK,MV,UMK,UMV> >(f));
  }
  
  void insert(const K & key, const V & val) {
    doInsert(key, val);
//    printTree();
  }
  
  void search(const K & key, callback<void, K, V>::ref cb, cbe error) {
    ptr<V> ignore = doSearch(rootPos, key, cb, error, false);
  }

  ptr<V> syncSearch(const K & key) {
    return doSearch(rootPos, key, NULL, NULL, false);
  }
  
  void predecessor(const K & key, callback<void, K, V>::ref cb, cbe error) {
    doSearch(rootPos, key, cb, error, true);
  }

  ptr<V> syncPredecessor(const K & key) {
    return doSearch(rootPos, key, NULL, NULL, true);
  }

  str testMarshal(const K & key, const V & val) {
    return marshK(key);
  }
  
  K testUnmarshal(const str & s) {
    K r = *(unmarshK(s));
    return r;
  }

  void printTree() {
    printSubtree(rootPos);
  }
  
protected:
  unsigned long branchT;        // branching factor
  FILE *f;
  fileReader fr;
  fileWriter fw;
  unsigned long rootPos;

  BPlusTree (FILE * f)
    : marshK(MK()), marshV(MV()),
      unmarshK(UMK()), unmarshV(UMV()),
      f(f), fr(f), fw(f) {
    // Read header
    fseek(f, 0, SEEK_SET);
    
    if (fr.readLong() != BPLUSTREE_MAGIC_NUMBER) {
      fatal << "BPlusTree file invalid -- bad header magic\n";
    }
    branchT = fr.readLong();
    rootPos = fr.readLong();

#if BPT_DEBUG
    warn << "Loaded BPlusTree header: branchT=" << branchT
         << " rootPos=" << rootPos << "\n";
#endif
    
    ref<BPlusInternalNode> root = getInternalNode(rootPos);
#if BPT_DEBUG
    warn << "Loaded BPlusTree; root@" << rootPos
         << " with " << root->numElements << " children\n";
#endif
  }

  bool nodeIsInternal(unsigned long pos) {
    fseek(f, pos, SEEK_SET);
    unsigned long magic = fr.readLong();
    if (fr.hasError() || fr.eos()) {
      fatal << "nodeIsInternal: read error/eos\n";
    }
    if (magic == BPLUS_INTERNAL_MAGIC_NUMBER) {
      return true;
    } else if (magic == BPLUS_LEAF_MAGIC_NUMBER) {
      return false;
    } else {
      fatal << "nodeIsInternal: magic is neither internal nor leaf\n";
    }
  }

  ref<BPlusInternalNode> getInternalNode(unsigned long pos) {
    fseek(f, pos, SEEK_SET);
    return BPlusInternalNode::load(fr, branchT, marshK.maxLength(), pos);
  }

  void putInternalNode(ref<BPlusInternalNode> n) {
    fseek(f, n->pos, SEEK_SET);
#if BPT_DEBUG
    warn << "writing internal node at " << ftell(f)
         << "  " << n->numElements << " elements\n";
#endif
    fw.writeStrWithoutLen(n->marshal());
    if (fw.hasError() || fflush(f) || fseek(f, 0, SEEK_SET)) {
      fatal << "BPlusTree: failed to put internal node\n";
      return;
    }
  }
  
  BPlusLeafNode getLeafNode(unsigned long pos) {
    fseek(f, pos, SEEK_SET);
    return BPlusLeafNode(fr, branchT, marshK.maxLength(), pos);
  }

  void putLeafNode(BPlusLeafNode & n) {
    fseek(f, n.pos, SEEK_SET);
#if BPT_DEBUG
    warn << "writing leaf node at " << ftell(f)
         << " key=" << n.key << " value= " << n.value << "\n";
#endif
    fw.writeStrWithoutLen(n.marshal());
    if (fw.hasError() || fflush(f) || fseek(f, 0, SEEK_SET)) {
      fatal << "BPlusTree: failed to put leaf node\n";
      return;
    }
  }

  ref<BPlusInternalNode> allocateInternalNode() {
    fseek(f, 0, SEEK_END);
    ref<BPlusInternalNode> r = BPlusInternalNode::create(
      branchT, marshK.maxLength(), ftell(f));
    putInternalNode(r);
    return r;
  }

  BPlusLeafNode createLeafNode(K key, V val) {
    fseek(f, 0, SEEK_END);
    BPlusLeafNode r(branchT, marshK.maxLength(), ftell(f));
    r.key = marshK(key);
    r.value = marshV(val);
#if BPT_DEBUG
    warn << "created leaf node " << key << ":"  << val
         << " @ " << r.pos << "\n";
#endif
    putLeafNode(r);
    return r;
  }

  void setRoot(ref<BPlusInternalNode> r) {
    rootPos = r->pos;
    fseek(f,2*sizeof(unsigned long), SEEK_SET);
    fw.writeLong(r->pos);
    if (fw.hasError() || fflush(f) || fseek(f, 0, SEEK_SET)) {
      fatal << "BPlusTree: failed to set root\n";
      return;
    }
  }
  
  /**
   * Split a full child of a non-full node.
   * Preconditions:
   *   y is the ith child of x
   *   y is full (2T-1) elements; x is not full
   *   z is a newly-allocated empty node
   * Postcondition:
   *   z adopts the t largest children of y, and becomes the new i+1th
   *    child of x
   *
   * Ref: CLRS B-Tree-Split-Child, p. 444
   */
  void splitChild(ref<BPlusInternalNode> x, unsigned long i,
                  ref<BPlusInternalNode> y, ref<BPlusInternalNode> z)
    {
      if (x->numElements == 2 * branchT-1 ||
          y->numElements != 2 * branchT-1 ||
          z->numElements != 0) {
            fatal << "splitChild: numElements precondition violated\n";
          }
      if (x->children[i] != y->pos) {
        fatal<< "splitChild: x[i]=y precondition violated\n";
      }

      z->numElements = branchT-1;
      
      for (unsigned long j = 0; j < branchT - 1; j++) {
        z->keys[j] = y->keys[branchT + j];
      }
      
      for (unsigned long j = 0; j < branchT; j++) {
        z->children[j] = y->children[branchT + j];
      }
      
      y->numElements = branchT-1;

      for (unsigned long j = x->numElements; j > i; j--) {
        x->children[j+1] = x->children[j];
      }
      
      x->children[i+1] = z->pos;
      
      if (x->numElements != 0) {
        for (long j = x->numElements-1; j >= (long) i; j--) {
          x->keys[j+1] = x->keys[j];
        }
      }
      
      
      x->keys[i] = y->keys[branchT-1];

      x->numElements++;

      putInternalNode(x);
      putInternalNode(y);
      putInternalNode(z);
    }

  /**
   * Insert key:val into the tree.
   *
   * Ref: CLRS B-Tree-Insert, p. 445
   */
  void doInsert(K key, V val) {
    //TRACE();
    ref<BPlusInternalNode> r = getInternalNode(rootPos);
    if (r->numElements == 2*branchT-1) {
      //TRACE();
      // Root is full, split it
      ref<BPlusInternalNode> s = allocateInternalNode();
      s->children[0] = r->pos;
      ref<BPlusInternalNode> z = allocateInternalNode();
      setRoot(s);
      splitChild(s, 0, r, z);
      insertNonFull(s, key, val);
    } else {
      insertNonFull(r, key, val);
    }
  }

  /**
   * Insert key:val into the tree rooted at x. Assumes that x is
   * non-full.
   */
  void insertNonFull(ref<BPlusInternalNode> x, K key, V val) {
    //TRACE();
    if (x->numElements == 0 && x->children[0] == 0) {
      //TRACE();
      // Empty subtree (should only happen on root of empty tree)
      BPlusLeafNode l = createLeafNode(key, val);
      x->children[0] = l.pos;
      putInternalNode(x);
    } else if (nodeIsInternal(x->children[0])) {
      //TRACE();
      // Children aren't leaves
      long i;
      for (i = x->numElements-1; i >= 0; i--) {
        if (unmarshK(x->keys[i]) <= key)
          break;
      }
      i++;
      // x.keys[i] is the smallest elt > key

      ref<BPlusInternalNode> c = getInternalNode(x->children[i]);

      if (c->numElements == 2*branchT-1) {
        //TRACE();
        ref<BPlusInternalNode> z = allocateInternalNode();
        splitChild(x, i, c, z);
        if (key >= unmarshK(x->keys[i])) {
          // Move to right side of the new split
          i++;
          insertNonFull(z, key, val);
        } else {
          insertNonFull(c, key, val);
        } 
      } else {
        insertNonFull(c, key, val);
      }
    } else {
      
      //TRACE();
      // Children are leaves, and at least one already exists
      long i;
      // Unless we're inserting at the far left, we have to
      // special-case this

      
      K farLeftChildKey = unmarshK(getLeafNode(x->children[0]).key);
      if (key == farLeftChildKey) {
        BPlusLeafNode l = createLeafNode(key, val);
        x->children[0] = l.pos;
      }
      
      if (key < farLeftChildKey) {
        for (i = x->numElements-1; i >= 0; i--) {
          x->keys[i+1] = x->keys[i];
          x->children[i+2] = x->children[i+1];
        }
        //TRACE();
        BPlusLeafNode ol = getLeafNode(x->children[0]);
        BPlusLeafNode nl = createLeafNode(key, val);
        x->keys[0] = ol.key;
        x->children[1] = ol.pos;
        x->children[0] = nl.pos;
        x->numElements++;
        putInternalNode(x);
      } else {
        // Check whether this key already exists, and if so change its
        // value
        for (unsigned long j = 0; j < x->numElements; j++) {
          if (unmarshK(x->keys[j]) == key) {
            BPlusLeafNode l = createLeafNode(key, val);
            x->children[j+1] = l.pos;
            putInternalNode(x);
            return;
          }
        }

        // Otherwise, insert it
        
        for (i = x->numElements-1; i >= 0; i--) {
          if (unmarshK(x->keys[i]) < key) {
            break;
          }
          x->keys[i+1] = x->keys[i];
          x->children[i+2] = x->children[i+1];
        }
        //TRACE();
        x->keys[i+1] = marshK(key);
        BPlusLeafNode l = createLeafNode(key, val);
        x->children[i+2] = l.pos;
        x->numElements++;
        putInternalNode(x);
      }
    }
  }

  ptr<V>
  doSearch(unsigned int pos, K k, callback<void,K,V>::ptr cb,
           callback<void,pStat>::ptr error,
           bool predecessorOK) {
#if BPT_DEBUG
    warn << "doSearch: pos=" << pos << "\n";
#endif
    if (nodeIsInternal(pos)) {
      ref<BPlusInternalNode> x = getInternalNode(pos);
      
      unsigned long i;
      if (x->numElements == 0) {
        if (x->children[0] == 0) {
          // empty tree case
          if (error) {
            error(PERR_NOENT);
          }
          return NULL;
        } else {
          i = 0;
        }
      } else if (k < unmarshK(x->keys[0])) {
        i = 0;
      } else {
        for (i = 0; i < x->numElements; i++) {
          if (k < unmarshK(x->keys[i])) {
//            warn << "k=" << k << " <= x.keys[i]=" << unmarshK(x.keys[i]) << "\n";
            break;
          }
        }
      }
//      warn << "i = " << i << "\n";
      return doSearch(x->children[i], k, cb, error, predecessorOK);
    } else {
      // Leaf node
      BPlusLeafNode x = getLeafNode(pos);
#if BPT_DEBUG
      warn << "doSearch: reached leaf key=" << unmarshK(x.key) << "\n";
#endif
      if (unmarshK(x.key) == k || (predecessorOK && (unmarshK(x.key) <= k)))
      {
        K rKey = unmarshK(x.key);
        V rVal = unmarshV(x.value);
        if (cb) {
          cb(rKey, rVal);
        }
        return New refcounted<V>(rVal);
      } else {
        if (error) {
            error(PERR_NOENT);
        }
        return NULL;
      }
    } 
  }


  void printSubtree(unsigned int pos) {
    warn << "----------------------------------------\n";
    if (nodeIsInternal(pos)) {
      warn << "Internal node at " << pos << "\n";
      ref<BPlusInternalNode> n = getInternalNode(pos);
      warn << "Contains " << n->numElements << " elements\n";
      for (unsigned long i = 0; i < n->numElements; i++) {
        warn << "  " << i << ": " << unmarshK(n->keys[i]) << "\n";
      }
      for (unsigned long i = 0; i <= n->numElements; i++) {
        warn << "Node " << pos << " child #" << i << ": " << n->children[i]
             << "\n";
        printSubtree(n->children[i]);
      }
    } else {
      warn << "LEAF node at " << pos << "\n";
      BPlusLeafNode l = getLeafNode(pos);
      warn << "key=" << unmarshK(l.key) << "\n";
      warn << "val=" << unmarshV(l.value) << "\n";
    }
    warn << "----------------------------------------\n";
  }
};


#endif
