import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
import edu.mit.six825.bn.inputs.*;
import techniques.utils.*;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;

/**
 * An implementation of the Gibbs sampling algorithm for Bayes nets.
 *
 * @author drkp
 */
public class GibbsSolver extends Solver {
    private static final boolean DEBUG = false;

    private int discardPrefixLen = 5000;
    private int requiredSamples = 10000;
    private boolean randomFlipChoice = false;

    private Map/*<BayesNetNode, List<BayesNetNode>>*/ childMap;

    public GibbsSolver(BayesNet _bn) {
	super(_bn);
    }
        
    public GibbsSolver() {
	super();
    }
    
    /**
     * Set the number of samples that are discarded while the Markov chain
     * converges. Note that the total number of samples performed is still
     * that specified by setRequiredSamples, i.e. the total number that are
     * actually used is requiredSamples - discardPrefixLen.
     */
    public void setDiscardPrefixLen(int newLen) {
	discardPrefixLen = newLen;
    }

    /**
     * Set the number of samples performed
     */
    public void setRequiredSamples(int newSamples) {
	requiredSamples = newSamples;
    }

    /**
     * If set, the variable to be flipped in each sample will be chosen
     * uniformly at random rather than in systematic passes.
     */
    public void randomFlipChoice(boolean newVal) {
	randomFlipChoice = newVal;
    }
    

    /**
     * Generate an initial sample with all non-evidence variables selected
     * randomly.
     */
    private Assignment initialSample(BayesNetNodeSet bnset,
                                     Assignment evidence) {
	Assignment r = evidence;

	for (Iterator i =
		 bnset.getNodesWithTopologicalOrdering().iterator();
	     i.hasNext();) {
	    BayesNetNode node = (BayesNetNode) i.next();
	    if (!evidence.contains(node.var)) {
		Comparable val = SampleUtils.sampleNode(node, r, true, null);
		r = new Assignment(r, node.var, val);
	    }
	}

	return r;
    }

    /**
     * Resample the given node in the network, choosing a new node at random
     * with the probability distribution given by the current assignment of
     * the rest of the network.
     */
    private Assignment flipNode(Assignment last, BayesNetNode node) {
	Comparable val = SampleUtils.sampleNode(node, last, false, childMap);
	if (DEBUG) {
	    System.out.println("Flipping node " + node.var + " to " + val);
	}

	return new Assignment(last, node.var, val);
    }

    public Function query (BayesNetNode node) {
        System.out.println("Query for " + node);
	System.out.println("Samples = " + requiredSamples);
        System.out.println("Discard = " + discardPrefixLen);
        

        FunctionVariable [] queryvar = new FunctionVariable[] {node.var};
        BayesNetNodeSet bnset = VESolver.GrabRelevantNodes(_bn.nodes,
                                                           queryvar,
                                                           _evidence);

        childMap = SampleUtils.buildChildMap(bnset);

	double [] prob = new double[node.var.domain.size()];

	Assignment x = initialSample(bnset, _evidence);

	int sampleCount = 0;


	loop:
	while (true) {
	    for (int i = 0; i < bnset.size(); i++) {
		if (sampleCount >= requiredSamples) {
		    break loop;
		}

		BayesNetNode nodeToBeFlipped;
		// Node to be flipped is either sequential (given by the for
		// loop) or random
		if (randomFlipChoice) {
		    nodeToBeFlipped =
			bnset.getNode(Random.random(bnset.size()));
		} else {
		    nodeToBeFlipped = bnset.getNode(i);
		}
		
		// ...but it better not be an evidence variable.
		if (_evidence.contains(nodeToBeFlipped.var)) {
		    continue;
		}

		x = flipNode(x, nodeToBeFlipped);

		sampleCount += 1;
		if (sampleCount > discardPrefixLen) {
                    Comparable val = x.getAssignedValue(node.var);
                    int ind = node.var.domain.getIndex(val).i;
                    prob[ind] += 1;
		}
	    }
	}

	return Compute.normalize(new Function(node.var, prob));
    }
    
    public String toString() {
	return "GibbsSolver";
    }

    public static void main(String[] args) {
	System.out.println("Prob(Burglary|JohnCalls=true, MaryCalls=true)");
	System.out.println("Burglary=TRUE AIMA: " + 0.284);

	final BayesNet bn = edu.mit.six825.bn.inputs.Nets.getBurglary();
	final Solver solver = new GibbsSolver();
	solver.setBayesNet(bn);
	//Solver solver = new EnumerationSolver(bn);
	// ...GibbsSamplerSolver(bn);
	// ...LikelihoodWeightingSolver(bn);
	// ...VariableEliminationSolver(bn);

	final FunctionVariable[] vars = new FunctionVariable[2];
	vars[0] = new FunctionVariable("JohnCalls");
	vars[1] = new FunctionVariable("MaryCalls");
	final Comparable[] vals = new Comparable[2];
	vals[0] = ComparableBoolean.TRUE;
	vals[1] = ComparableBoolean.TRUE;
	final Assignment evidence = new Assignment(new FunctionVariableSet(vars), vals);
	solver.setEvidence(evidence);

	final BayesNetNode burgVar = bn.nodes.getNode("Burglary");
	final Function burgProb = solver.query(burgVar);
	System.out.println("Burglary=TRUE Calc.: " + burgProb);
    }

}
