import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
import edu.mit.six825.bn.inputs.*;
import techniques.utils.*;
import java.util.HashMap;

/**
 * An implementation of the likelihood weighting sampling algorithm
 * for Bayes nets.
 *
 * @author drkp
 */
public class LWSolver extends Solver {

    private static final double DEFAULT_REQUIRED_WEIGHT = 100.0;
    private double requiredWeight = DEFAULT_REQUIRED_WEIGHT;
    
    public LWSolver(BayesNet _bn) {
	super(_bn);
    }
        
    public LWSolver() {
	super();
    }

    public void setRequiredWeight(double newWeight) {
	requiredWeight = newWeight;
    }
    
    /**
     * A weighted sample -- just an assignment with a weight
     */
    public class WeightedParticle {
	public double weight;
	public Assignment assignment;
		
	public WeightedParticle(double _weight, Assignment _assignment) {
	    assignment = _assignment;
	    weight = _weight;
	}
    }

    /**
     * Generate a random sample satisfying the evidence, along with the
     * appropriate weight. This is implemented per procedure "LW-sample"
     * (Koller & Friedman p.330)
     */
    private WeightedParticle generateParticle(BayesNetNodeSet bnset) {
	double weight = 1.0;
	Assignment x = _evidence;	// sampled assignment so far
	List ordering = bnset.getNodesWithTopologicalOrdering();

//	System.out.println("Generating particle\n");
	for (Iterator i = ordering.iterator(); i.hasNext();) {
	    BayesNetNode node = (BayesNetNode) i.next();

	    if (_evidence.contains(node.var)) {
		double[] dist = SampleUtils.evaluateNodeGivenParents(node, x);
                Comparable val = _evidence.getAssignedValue(node.var);
                x = new Assignment(x, node.var, val);
                int ind = node.var.domain.getIndex(val).i;
                weight *= dist[ind];
// 		System.out.println("Forcing " + node.var + " to " +
// 				   x.getAssignedValue(node.var));
// 		System.out.println("Weight -> " + weight);
	    } else {
		x = new Assignment(x, node.var,
                                   SampleUtils.sampleNode(node, x,
                                                          true, null));
// 		System.out.println("Sampling " + node.var + " as " +
// 				   x.getAssignedValue(node.var));
	    }
	}

	return new WeightedParticle(weight, x);
    }

    public Function query (BayesNetNode node) {
	System.out.println("Query for " + node);
	System.out.println("Required weight: " + requiredWeight);

        FunctionVariable [] queryvar = new FunctionVariable[] {node.var};
        BayesNetNodeSet bnset = VESolver.GrabRelevantNodes(_bn.nodes,
                                                           queryvar,
                                                           _evidence);

	double totalWeight = 0.0;
        int numSamples = 0;
	double [] prob = new double[node.var.domain.size()];
	while (numSamples < requiredWeight) {
	    WeightedParticle particle = generateParticle(bnset);
	    totalWeight += particle.weight;

	    Comparable val = particle.assignment.getAssignedValue(node.var);
	    int i = node.var.domain.getIndex(val).i;
	    prob[i] += particle.weight;
            numSamples++;
	}

	// Normalize.
	for (int i = 0; i < node.var.domain.size(); i++) {
	    prob[i] /= totalWeight;
	}
	
	return new Function(node.var, prob);
    }
    
    public String toString() {
	return "LWSolver";
    }

    public static void main(String[] args) {
	System.out.println("Prob(Burglary|JohnCalls=true, MaryCalls=true)");
	System.out.println("Burglary=TRUE AIMA: " + 0.284);

	final BayesNet bn = edu.mit.six825.bn.inputs.Nets.getBurglary();
	final Solver solver = new LWSolver();
	solver.setBayesNet(bn);
	//Solver solver = new EnumerationSolver(bn);
	// ...GibbsSamplerSolver(bn);
	// ...LikelihoodWeightingSolver(bn);
	// ...VariableEliminationSolver(bn);

	final FunctionVariable[] vars = new FunctionVariable[2];
	vars[0] = new FunctionVariable("JohnCalls");
	vars[1] = new FunctionVariable("MaryCalls");
	final Comparable[] vals = new Comparable[2];
	vals[0] = ComparableBoolean.TRUE;
	vals[1] = ComparableBoolean.TRUE;
	final Assignment evidence = new Assignment(new FunctionVariableSet(vars), vals);
	solver.setEvidence(evidence);

	final BayesNetNode burgVar = bn.nodes.getNode("Burglary");
	final Function burgProb = solver.query(burgVar);
	System.out.println("Burglary=TRUE Calc.: " + burgProb);
    }

}
