import java.util.*;
import java.io.*;

import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;

/**
 * A mutable Bayes-net structure used for structure search.
 *
 * <p>Wraps a {@link BayesNet} and exposes edge operations (add, delete,
 * reverse) as {@link BNOp}s. Applied ops are recorded on an undo stack so the
 * most recent one can be rolled back with {@link #unrollOp}. Every structural
 * change installs a blank (all-zero) CPT on the affected node; CPTs are
 * re-estimated from data by {@link #score}.
 *
 * <p>No synchronization is performed, and {@link #scoreOp} temporarily
 * mutates the wrapped network, so an instance should not be shared between
 * threads.
 *
 * @author nocturne
 */
public class BNStructure {
    /** The network whose structure is manipulated (shared, not copied). */
    public  BayesNet bn;
    /** Score with BIC when true, with AIC when false. */
    public  boolean useBIC = true;

    /** Undo stack of applied BNOps; the most recent op is last. */
    private LinkedList opHistory = new LinkedList();

    private boolean DEBUG = false;
    private Random rng = new Random();


    /*
     * Constructors.
     */

    /** Wraps an existing network; structural changes are made in place. */
    public BNStructure(BayesNet _bn) {
        bn = _bn;
    }

    /** Creates a new network over the given nodes. */
    public BNStructure(BayesNetNodeSet nodes) {
        bn = new BayesNet(nodes);
    }

    public String toString() {
        // Perhaps call BNUtils.printNet(bn) instead?
        return bn.nodes.toString();
    }


    /*
     * Build initial configurations
     */

    /**
     * Remove all edges, giving an unconnected network. Every node gets a
     * blank CPT.
     */
    public void makeUnconnected() {
        for (Iterator it = bn.nodes.iterator(); it.hasNext();) {
            BayesNetNode node = (BayesNetNode) it.next();
            GiveNodeNewParents(node, new BayesNetNodeSet());
        }
    }

    /**
     * Make a fully connected network: every node becomes a parent of all
     * nodes that follow it in the iteration order of {@code bn.nodes}.
     * Because edges only point "forward" in that order, the result is
     * acyclic.
     */
    public void makeFullyConnected() {
        BayesNetNodeSet parents = new BayesNetNodeSet();

        for (Iterator it = bn.nodes.iterator(); it.hasNext();) {
            BayesNetNode node = (BayesNetNode) it.next();
            GiveNodeNewParents(node, parents);
            parents = new BayesNetNodeSet(parents, node);
        }
    }

    /**
     * Make the network randomly connected: remove all existing edges, then
     * add k edges between random nodes (that don't create directed cycles).
     * k can vary to specify the density or sparsity of edges in the graph.
     *
     * @param k number of edges to add
     * @return the seed used, so the same structure can be rebuilt later with
     *         {@link #makeRandomConnWithSeed(int, long)}
     * @throws IllegalArgumentException if k exceeds the maximum number of
     *         edges a DAG over these nodes can have
     */
    public long makeRandomlyConnected(int k) {
        long randomseed = rng.nextLong();
        makeRandomConnWithSeed(k, randomseed);
        return randomseed;
    }


    /**
     * Make the network connected in a pseudorandom deterministic fashion, as
     * determined by k and randseed.
     *
     * <p>For a given (constant) value of 'k', multiple invocations of this
     * method will reset the network structure to identically "randomized"
     * states. (Behavior is undefined if you vary k between calls.)
     *
     * @param k        number of edges to add
     * @param randseed seed for the private RNG driving edge selection
     * @throws IllegalArgumentException if k exceeds n(n-1)/2 for n nodes; the
     *         network is left unmodified in that case
     */
    public void makeRandomConnWithSeed(int k, long randseed) {
        int n = bn.nodes.size();

        // A DAG on n nodes has at most n(n-1)/2 edges. Below that bound an
        // addable edge always exists (any forward pair in a topological
        // order), so the rejection loop terminates; above it, it never would.
        long maxEdges = (long) n * (n - 1) / 2;
        if (k > maxEdges) {
            throw new IllegalArgumentException(
                "Cannot add " + k + " edges to a DAG with " + n
                + " nodes (maximum " + maxEdges + ")");
        }

        Random rewindableRNG = new Random(randseed);

        makeUnconnected();

        for (int i = 0; i < k; i++) {
            boolean found = false;
            while (!found) {
                int indexFrom = rewindableRNG.nextInt(n);
                int indexTo   = indexFrom;
                while (indexTo == indexFrom) {
                    indexTo = rewindableRNG.nextInt(n);
                }

                BayesNetNode from = bn.nodes.getNode(indexFrom);
                BayesNetNode to   = bn.nodes.getNode(indexTo);

                if (!BNOp.EdgeFromTo(from, to) &&
                    !BNOp.IsAncestorOf(to, from)) {
                    // OK to add this edge; do so.
                    AddEdge(from, to);
                    found = true;
                }
                // Otherwise the edge already exists or would close a
                // directed cycle; draw another pair.
            }
        }
    }

    /*************** Public mutators ****************/

    /**
     * Apply a BNOp to this structure and push it onto the undo stack.
     *
     * @throws IllegalArgumentException if the op is invalid or has an
     *         unknown move type
     */
    public void applyOp(BNOp op) {
        if (!op.isValid()) {
            throw new IllegalArgumentException("Invalid BNOp " + op);
        }

        if (op.move == MoveType.ADD) {
            this.AddEdge(op.from, op.to);
        } else if (op.move == MoveType.DELETE) {
            this.DeleteEdge(op.from, op.to);
        } else if (op.move == MoveType.REVERSE) {
            this.ReverseEdge(op.from, op.to);
        } else {
            throw new IllegalArgumentException("Unknown move type "
                                               + op.move);
        }

        // Record only after the move succeeded, so a rejected op never
        // lingers on the undo stack.
        opHistory.add(op);
    }

    /**
     * Undo a BNOp. Only the most recently applied op (compared by identity)
     * may be unrolled.
     *
     * @throws IllegalArgumentException if no op has been applied, if op is
     *         not the most recent one, or if its move type is unknown
     */
    public void unrollOp(BNOp op) {
        if (opHistory.isEmpty()) {
            throw new IllegalArgumentException(
                "Cannot unroll BNOp " + op + ", because no BNOp has been "
                + "applied");
        }
        if (opHistory.getLast() != op) {
            throw new IllegalArgumentException(
                "Cannot unroll BNOp " + op + ", because the most recent "
                + "BNOp was " + (BNOp) opHistory.getLast());
        }

        if (op.move == MoveType.ADD) {
            this.DeleteEdge(op.from, op.to);
        } else if (op.move == MoveType.DELETE) {
            this.AddEdge(op.from, op.to);
        } else if (op.move == MoveType.REVERSE) {
            this.ReverseEdge(op.to, op.from);
        } else {
            throw new IllegalArgumentException("Unknown move type "
                                               + op.move);
        }
        opHistory.removeLast();
    }

    /////////////////////////////////////////////////////////////////////
    // AddEdge, DeleteEdge, ReverseEdge, and GiveNodeNewParents are
    // private utility functions. They do not validate acyclicity or
    // other search-level constraints; that checking is the caller's job.
    // DeleteEdge does refuse to remove a non-existent edge, rather than
    // silently corrupting the parent set.

    // Make From a parent of To
    private void AddEdge(BayesNetNode from, BayesNetNode to) {
        this.GiveNodeNewParents(to, new BayesNetNodeSet(to.parents, from));
    }

    // Make From cease to be a parent of To.
    // Throws IllegalArgumentException if From is not a parent of To.
    private void DeleteEdge(BayesNetNode from, BayesNetNode to) {
        int oldCount = to.parents.size();
        BayesNetNode[] kept = new BayesNetNode[oldCount];
        int keptCount = 0;

        for (int i = 0; i < oldCount; i++) {
            BayesNetNode parent = to.parents.getNode(i);
            if (!parent.equals(from)) {
                kept[keptCount++] = parent;
            }
        }

        if (keptCount == oldCount) {
            throw new IllegalArgumentException(
                "Cannot delete edge: " + from.var.name
                + " is not a parent of " + to.var.name);
        }

        BayesNetNode[] newparents = new BayesNetNode[keptCount];
        System.arraycopy(kept, 0, newparents, 0, keptCount);
        this.GiveNodeNewParents(to, new BayesNetNodeSet(newparents));
    }

    // Reverse parental relationship between From and To
    private void ReverseEdge(BayesNetNode from, BayesNetNode to) {
        BayesNetNode oldParent;
        BayesNetNode oldChild;
        if (to.parents.contains(from)) {
            // This is the only case that isValid()
            oldParent = from;
            oldChild  = to;
        } else {
            // but allow this also for now...
            System.out.println("WARNING: allowing semi-valid reverse-edge");
            oldParent = to;
            oldChild  = from;
        }

        this.DeleteEdge(oldParent, oldChild);
        this.AddEdge   (oldChild, oldParent);
    }

    // Modify Node so its parents are as requested, and install a blank
    // (all-zero) CPT over the parents' variables plus the node's own.
    private void GiveNodeNewParents(BayesNetNode node,
                                    BayesNetNodeSet parents) {
        FunctionVariableSet cptvars = parents.getFunctionVariableSet();
        cptvars = FunctionVariableSet.union(cptvars, node.var);

        double[] cptvals = new double[cptvars.cartesianProductSize()];

        Function newcpt = new Function(cptvars, cptvals);

        node.parents = parents;
        node.cpt     = newcpt;
    }

    /*************** Accessors ****************/

    /**
     * Score this network structure given the data: BIC when
     * {@link #useBIC} is true, AIC otherwise. MLEstimate is invoked first to
     * regenerate the CPTs, so this overwrites every node's CPT.
     *
     * <p>NOTE 1: This is very inefficient; code could be written to reuse
     * those previously-computed CPTs which were still valid, and only
     * recalculate the ones which changed as a result of the BNOp(s) applied
     * since the last time the CPTs were calculated. However, I don't think
     * we need the speed boost; it's probably fine if this takes a while to
     * run. Similarly, we should be able to cache portions of the BIC score.
     *
     * <p>NOTE 2: MLEstimate.computeCPTs does not require the CPTs to have
     * been wiped clean beforehand (confirmed by drkp), so no wipe is done.
     */
    public double score(BNData data) {
        MLEstimate.computeCPTs(bn, data);
        if (useBIC) {
            return BNUtils.BICScore(bn, data);
        } else {
            return BNUtils.AICScore(bn, data);
        }
    }

    /**
     * Print every op produced by BNOp.BNOpRandIterator for this network, one
     * per line; a demonstration/testing aid.
     */
    public void showAllOps() {
        Iterator opIt = (Iterator) new BNOp.BNOpRandIterator(bn, rng);

        while (opIt.hasNext()) {
            System.out.println(opIt.next());
        }
    }

    /**
     * Choose the next op for local search.
     *
     * <p>If FirstRatherThanBest is true (first-ascent): try valid BNOps in
     * random order and return the first one whose score beats
     * curStateScore. If none does, the least-bad op (highest score) is
     * returned.
     *
     * <p>If FirstRatherThanBest is false (greedy): among all valid BNOps,
     * return the one which yields the best score given data.
     *
     * @return the chosen op, or null if the iterator yields no ops at all
     */
    public BNOp getNextOp(BNData data, double curStateScore,
                          boolean FirstRatherThanBest) {
        BNOp bestop = null;
        double bestScore = java.lang.Double.NEGATIVE_INFINITY;

        Iterator opIt = (Iterator) new BNOp.BNOpRandIterator(bn, rng);

        while (opIt.hasNext()) {
            BNOp op = (BNOp) opIt.next();

            double opScore = this.scoreOp(op, data);
            if (opScore > bestScore) {
                if (DEBUG) {
                    System.out.println("Improved score from " + bestScore +
                                       " to " + opScore + " with op " + op);
                }
                bestop = op;
                bestScore = opScore;
                if (FirstRatherThanBest && (bestScore > curStateScore)) {
                    return bestop;
                }
            }
        }

        // At a local maximum no op beats curStateScore, for either search
        // strategy; first-ascent then tries every op and falls through here
        // to return the best (non-improving) move.
        return bestop;
    }

    /**
     * Evaluate a BNOp given the data by temporarily applying it, scoring,
     * and unrolling it. Because a given BNOp object is only ever applied to
     * a single structure-state, the score is cached on the op so it is
     * calculated at most once.
     *
     * <p>The op is unrolled even if scoring throws, so the structure is
     * never left modified.
     */
    public double scoreOp(BNOp op, BNData data) {
        if (op.hasScoreCache()) {
            return op.getScoreCache();
        }

        this.applyOp(op);
        double opScore;
        try {
            opScore = this.score(data);
        } finally {
            this.unrollOp(op);
        }

        op.setScoreCache(opScore);
        return opScore;
    }



    /**
     * Choose a random valid BNOp by rejection sampling.
     *
     * @throws IllegalStateException if the network has fewer than two
     *         nodes, since no op can then be valid
     */
    public BNOp getRandomOp() {
        if (bn.nodes.size() < 2) {
            throw new IllegalStateException(
                "No valid BNOp exists for a network with "
                + bn.nodes.size() + " node(s)");
        }

        MoveType move;
        BayesNetNode from, to;
        while (true) {
            move = MoveType.random(rng);
            int indexFrom = rng.nextInt(bn.nodes.size());
            int indexTo   = indexFrom;
            while (indexTo == indexFrom) {
                // Because no move is valid when its source and target are
                // equal.
                indexTo = rng.nextInt(bn.nodes.size());
            }

            from = bn.nodes.getNode(indexFrom);
            to   = bn.nodes.getNode(indexTo);

            BNOp candidateOp = new BNOp(move, from, to);
            if (candidateOp.isValid()) {
                return candidateOp;
            }

            if (DEBUG) {
                System.out.println("Ignored invalid randomly-generated op "
                                   + candidateOp);
            }
        }
    }


    /**
     * Save a representation of the Bayes net structure in dot/graphviz
     * format to {@code filename + ".dot"}, then start
     * {@code dot -Tps <file>.dot -o <file>.ps} to convert it to PostScript.
     * The title is printed in the graph, and the score is appended.
     *
     * <p>The dot process is started but not waited for. Scoring re-estimates
     * the CPTs. I/O failures, including a missing dot executable, are
     * reported on stdout rather than thrown.
     */
    public void printDotFile(BNData data, String filename, String title) {
        String dotFile = filename + ".dot";
        String psFile  = filename + ".ps";
        try {
            PrintWriter pw = new PrintWriter(new FileWriter(dotFile));
            try {
                pw.println(this.toDotString(score(data), title));
            } finally {
                pw.close();
            }
            // PrintWriter swallows I/O errors; surface them explicitly.
            if (pw.checkError()) {
                throw new IOException("error writing " + dotFile);
            }

            // Pass arguments directly (no shell) so file names containing
            // spaces or shell metacharacters are not interpreted; dot's -o
            // option replaces the old shell redirection.
            String[] cmd = { "dot", "-Tps", dotFile, "-o", psFile };
            Runtime.getRuntime().exec(cmd);
        } catch (IOException e) {
            System.out.println("IO exception: " + e);
        }
    }

    /**
     * Render the structure as a Graphviz digraph. The graph label is the
     * title followed by the score name (BIC or AIC) and value.
     *
     * <p>Node names that are plain DOT identifiers are written as-is; any
     * other name (or a DOT keyword) is double-quoted. DOT treats quoted and
     * unquoted forms of an ID as the same node.
     */
    public String toDotString(double score, String title) {
        StringBuffer s = new StringBuffer();

        String scorename = (useBIC) ? "BIC" : "AIC";

        s.append("digraph bn {\n");
        s.append("label = \"").append(escapeDotQuotes(title))
         .append(" -- ").append(scorename).append("=").append(score)
         .append("\"\n");

        for (Iterator it = bn.nodes.iterator(); it.hasNext();) {
            BayesNetNode node = (BayesNetNode) it.next();
            String nodeId = dotId(node.var.name);
            s.append(nodeId).append(";\n");
            for (Iterator it2 = node.parents.iterator(); it2.hasNext();) {
                BayesNetNode parent = (BayesNetNode) it2.next();
                s.append(dotId(parent.var.name)).append(" -> ")
                 .append(nodeId).append(";\n");
            }
        }

        s.append("};");

        return s.toString();
    }

    // Return name unchanged if it is a plain DOT identifier (ASCII letters,
    // digits, underscores, not starting with a digit, not a keyword);
    // otherwise return it as a quoted DOT string.
    private static String dotId(String name) {
        boolean plain = name.length() > 0
            && !Character.isDigit(name.charAt(0))
            && !isDotKeyword(name);
        for (int i = 0; plain && i < name.length(); i++) {
            char c = name.charAt(i);
            plain = (c == '_')
                || (c >= 'a' && c <= 'z')
                || (c >= 'A' && c <= 'Z')
                || (c >= '0' && c <= '9');
        }
        return plain ? name : "\"" + escapeDotQuotes(name) + "\"";
    }

    // DOT keywords are case-insensitive and cannot be used as bare IDs.
    private static boolean isDotKeyword(String name) {
        String[] keywords =
            { "node", "edge", "graph", "digraph", "subgraph", "strict" };
        for (int i = 0; i < keywords.length; i++) {
            if (keywords[i].equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    // Escape double quotes for use inside a DOT quoted string.
    private static String escapeDotQuotes(String text) {
        String str = String.valueOf(text);
        StringBuffer out = new StringBuffer(str.length());
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            if (c == '"') {
                out.append('\\');
            }
            out.append(c);
        }
        return out.toString();
    }


    /**
     * Perform greedy structure search, as per K&amp;F Figure 14.7, p.598:
     * each step applies the best-scoring op.
     *
     * <p>Terminates after maxIterations steps, or continues until no further
     * progress can be made if maxIterations = 0.
     *
     * @return the search outcome; {@code steps} is the number of improving
     *         ops applied
     */
    public SearchResult greedyStructureSearch(BNData data, int maxIterations) {
        return structureSearch(data, maxIterations, true);
    }

    /**
     * Perform first-ascent hill-climbing structure search: each step applies
     * the first randomly-ordered op that improves the score (see
     * {@link #getNextOp}).
     *
     * <p>Terminates after maxIterations steps, or continues until no further
     * progress can be made if maxIterations = 0.
     *
     * @return the search outcome; {@code steps} is the number of improving
     *         ops applied
     */
    public SearchResult firstAscentStructureSearch(BNData data, int maxIterations) {
        return structureSearch(data, maxIterations, false);
    }

    /**
     * Hill-climb over structures until no op improves the score, no op is
     * available, or maxIterations steps have been taken (0 means
     * unlimited). A final non-improving op is unrolled before returning.
     *
     * @param greedyRatherThanFirstAscent true for greedy (best-op) search,
     *        false for first-ascent search
     */
    public SearchResult structureSearch(BNData data, int maxIterations,
                                        boolean greedyRatherThanFirstAscent) {
        if (maxIterations == 0) {
            maxIterations = Integer.MAX_VALUE;
        }

        double curScore = score(data);
        int iteration;

        for (iteration = 0; iteration < maxIterations; iteration++) {

            BNOp bestOp = getNextOp(data, curScore,
                                    !greedyRatherThanFirstAscent);
            if (DEBUG) {
                System.out.println("ITERATION " + iteration +
                                   ": best possible op is " + bestOp);
            }

            // No op exists at all (e.g. fewer than two nodes): nothing to do.
            if (bestOp == null) {
                break;
            }

            applyOp(bestOp);
            double opScore;

            if (bestOp.hasScoreCache()) {
                opScore = bestOp.getScoreCache();
            } else {
                opScore = score(data);
            }

            if (opScore > curScore) {
                if (DEBUG) {
                    System.out.println("The score improves from " + curScore +
                                       " to " + opScore);
                }
                curScore = opScore;
            } else {
                if (DEBUG) {
                    System.out.println("Score is " + opScore +
                                       " which is worse than current " +
                                       curScore);
                }
                unrollOp(bestOp);
                break;
            }
        }

        SearchResult result = new SearchResult();
        result.steps = iteration;
        result.score = curScore;
        result.dimension = BNUtils.dimension(bn);
        result.structuralPenalty = BNUtils.structuralPenalty(bn, data);

        return result;
    }

    /** Outcome of a structure search. */
    public static class SearchResult {
        /** Number of improving ops applied. */
        public int    steps;
        /** Score of the final structure. */
        public double score;
        /** Value of BNUtils.dimension for the final network. */
        public int    dimension;
        /** Value of BNUtils.structuralPenalty for the final network. */
        public double structuralPenalty;

        public SearchResult() { }
    }
}
