import java.util.*;
import java.io.*;

import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
/**
 * Computes conditional probability tables (CPTs) for a Bayes net from
 * observed data.
 *
 * <p>Each CPT entry is a Laplace-smoothed frequency estimate: the number of
 * data rows matching a full assignment of the CPT's variables, plus one,
 * divided by the number of rows matching that assignment with the node's own
 * variable removed, plus two. Despite the class name, this is therefore not
 * the raw maximum-likelihood ratio {@code count / parentCount}. The smoothing
 * keeps events that were never observed from getting probability zero. It
 * also avoids dividing by zero when the conditioning assignment never occurs
 * in the data.
 *
 * @author drkp
 */
public final class MLEstimate {

    /** Pseudocount added to every numerator (Laplace correction). */
    private static final double LAPLACE_NUMERATOR_PSEUDOCOUNT = 1.0;

    /**
     * Pseudocount added to every denominator.
     *
     * <p>NOTE(review): 2.0 is the Laplace correction only for a two-valued
     * variable. For a variable with k values the denominator should be
     * {@code parentCount + k}; otherwise the entries of a CPT row do not sum
     * to one. TODO confirm that every variable estimated here is binary, or
     * derive k from the node's variable.
     */
    private static final double LAPLACE_DENOMINATOR_PSEUDOCOUNT = 2.0;

    /** Static utility class; not instantiable. */
    private MLEstimate() {
    }

    /**
     * Replaces the CPT of every node in {@code bn} with one estimated from
     * {@code data}.
     *
     * <p>First calls {@code BNUtils.wipeCPTs(bn)}. Then it visits the nodes in
     * the order returned by {@code getNodesWithTopologicalOrdering()}. For
     * each node it enumerates every assignment of the node's CPT variables.
     * For each assignment it stores
     * {@code (count + 1) / (parentCount + 2)}, where:
     * <ul>
     *   <li>{@code count} is the number of data rows matching the assignment;
     *   <li>{@code parentCount} is the number matching the assignment with the
     *       node's own variable subtracted.
     * </ul>
     * The node's CPT is then replaced by a new {@code Function} over the same
     * variable set.
     *
     * @param bn   the network whose nodes' {@code cpt} fields are overwritten
     * @param data the observations that are counted
     */
    public static void computeCPTs(BayesNet bn, BNData data) {
        BNUtils.wipeCPTs(bn);

        for (Iterator it =
                 bn.nodes.getNodesWithTopologicalOrdering().iterator();
             it.hasNext();) {
            BayesNetNode node = (BayesNetNode) it.next();

            // The variable set is read after wipeCPTs(), so wipeCPTs
            // presumably leaves each CPT's variables intact (TODO confirm).
            // Only the numeric entries are recomputed here.
            FunctionVariableSet vars = node.cpt.variables;
            double[] entries = new double[vars.cartesianProductSize()];

            for (Iterator it2 = vars.assignmentIterator(); it2.hasNext();) {
                Assignment a = (Assignment) it2.next();
                // Remove the node's own variable; what remains is presumably
                // the assignment to the node's parents.
                Assignment parentAssignment = a.subtract(node.var);

                // Held as double rather than float: float represents integers
                // exactly only up to 2^24, so large counts would be rounded.
                double parentCount = data.countMatching(parentAssignment);
                double count = data.countMatching(a);

                // Laplace correction: no entry is zero, and an unobserved
                // parent configuration yields a finite value, not 0/0.
                entries[a.computePosition()] =
                    (count + LAPLACE_NUMERATOR_PSEUDOCOUNT)
                    / (parentCount + LAPLACE_DENOMINATOR_PSEUDOCOUNT);
            }
            node.cpt = new Function(vars, entries);
        }
    }

}
