import java.util.Iterator;

import java.util.Locale;
import java.text.NumberFormat;
import java.text.DecimalFormat;

import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
import edu.mit.six825.bn.inputs.*;

// Needed by DumpToFile
import java.io.*;


public class BNUtils {

    // LaTeX tabular fragments shared by the latex* methods.
    /** Terminates a table row: a space, "\\", and a newline. */
    private static final String TBL_END_ROW = " \\\\\n";
    /** Separates cells within a table row. */
    private static final String TBL_COL_SEP = " & ";
    /** Horizontal rule placed between table rows. */
    private static final String TBL_ROW_SEP = "\\hline\n";

    /** Natural log of 2; dividing a natural log by this yields a base-2 log. */
    private static final double LN_2 = Math.log(2);

    /** Static utility class; not instantiable. */
    private BNUtils() {
    }

    /**
     * Given a BayesNet object, mutate all its nodes so their CPTs are
     * completely zeroed. Each node's CPT keeps its variable set; only the
     * table values are replaced by a fresh all-zero array.
     *
     * @param bn the network whose nodes' {@code cpt} fields are replaced
     */
    public static void wipeCPTs (BayesNet bn) {
        for (Iterator i = bn.nodes.iterator(); i.hasNext(); ) {
            BayesNetNode n = (BayesNetNode) i.next();

            FunctionVariableSet vars = n.cpt.variables;
            // Java zero-fills newly allocated double arrays.
            double[] zerovals = new double[vars.cartesianProductSize()];

            n.cpt = new Function(vars, zerovals);
        }
    }

    /**
     * Print a crude, human-readable dump of a set of nodes to standard
     * output: for each node, the node, " <== ", its parents and its CPT,
     * followed by a blank line.
     *
     * @param nodes the nodes to print, in iteration order
     */
    public static void printNet (BayesNetNodeSet nodes) {
        for (Iterator i = nodes.iterator(); i.hasNext(); ) {
            BayesNetNode n = (BayesNetNode) i.next();
            System.out.println(n + " <== " + n.parents + n.cpt);
            System.out.println("");
        }
    }

    /**
     * Compute the base-2 log-likelihood of a data set given the CPTs in
     * the Bayes net. This is the sum, over every assignment in
     * {@code data} and every node in {@code bn}, of log2 of the node's CPT
     * evaluated at that assignment.
     *
     * <p>If any evaluated CPT entry is 0, the result is negative infinity
     * (since {@code Math.log(0)} is {@code -Infinity}).
     *
     * @param bn   network supplying the CPTs
     * @param data assignments to score
     * @return the base-2 log-likelihood (0.0 for empty data)
     */
    public static double logLikelihood (BayesNet bn, BNData data) {
        double ll = 0.0;

        // If we wanted to do more optimization, this would be the
        // place to do it. Basically, invert these loops, and cache
        // the total log-likelihood (over all the data) of nodes whose
        // CPTs have not changed since the last time their
        // log-likelihood was calculated. This would, of course,
        // require us to figure out which nodes actually needed MLE
        // CPT-regeneration after a BNOp, and only regenerate *those*
        // CPTs. It's enough work that it's probably not worth it,
        // given the scope of this project.
        for (Iterator it = data.iterator(); it.hasNext(); ) {
            Assignment a = (Assignment) it.next();
            for (Iterator it2 = bn.nodes.iterator(); it2.hasNext(); ) {
                BayesNetNode node = (BayesNetNode) it2.next();
                ll += Math.log(node.cpt.evaluate(a)) / LN_2;
            }
        }

        return ll;
    }

    /**
     * Compute the dimension of a Bayes net: the total number of free
     * parameters across all CPTs. For each node this is the product of
     * its parents' domain sizes times (its own domain size - 1). One value
     * per CPT row is not counted, because it can be recomputed from the
     * fact that the row sums to 1.
     *
     * @param bn the network to measure
     * @return the number of independent CPT parameters
     */
    public static int dimension (BayesNet bn) {
        int d = 0;

        for (Iterator it = bn.nodes.iterator(); it.hasNext(); ) {
            BayesNetNode node = (BayesNetNode) it.next();
            int nodeDimension = 1;
            for (Iterator it2 = node.parents.iterator(); it2.hasNext(); ) {
                BayesNetNode parent = (BayesNetNode) it2.next();
                nodeDimension *= parent.var.domain.size();
            }
            nodeDimension *= (node.var.domain.size() - 1);
            d += nodeDimension;
        }

        return d;
    }

    /**
     * Compute the structural penalty of the BIC score:
     * (log_2 M) / 2 * Dim[G], where M is the number of data points.
     *
     * @param bn   the network whose {@link #dimension} is penalized
     * @param data the data set; only its size is used
     * @return the penalty, or 0.0 when {@code data} is empty
     */
    public static double structuralPenalty (BayesNet bn, BNData data) {
        // log2(0) is -Infinity. Without this guard, an empty data set would
        // produce a -Infinity (or NaN) penalty and a +Infinity BIC score.
        if (data.size() == 0) {
            return 0.0;
        }
        return (Math.log(data.size()) / LN_2) / 2.0 * dimension(bn);
    }

    /**
     * Compute the BIC score (log likelihood - structural penalty) of
     * a Bayes net given data.
     */
    public static double BICScore (BayesNet bn, BNData data) {
        return logLikelihood(bn, data) - structuralPenalty(bn, data);
    }

    /**
     * Compute the AIC score (log likelihood - dimension) of
     * a Bayes net given data.
     */
    public static double AICScore (BayesNet bn, BNData data) {
        return logLikelihood(bn, data) - dimension(bn);
    }

    /**
     * Return a Bayes net with the assumed structure from task 1 and
     * all-zero CPTs.
     *
     * @param netname "a" selects the structure from {@code Nets.getNetA()}
     *                (with its CPTs wiped); ANY other value selects the
     *                hard-coded structure from {@link #getWipedNetB()}
     * @return a freshly built network with zeroed CPTs
     */
    public static BayesNet getWipedNet (String netname) {
        if (netname.equals("a")) {
            BayesNet bn = Nets.getNetA();
            wipeCPTs(bn);
            return bn;
        }
        return getWipedNetB();
    }

    /**
     * Return the node in {@code nodes} whose variable's string form equals
     * {@code name}. If several match, the first (lowest index) is returned.
     *
     * @throws IllegalArgumentException if no node matches
     */
    public static BayesNetNode getNodeByName (BayesNetNodeSet nodes,
                                              String name) {
        for (int i = 0; i < nodes.size(); i++) {
            BayesNetNode candidate = nodes.getNode(i);
            if (name.equals(candidate.var.toString())) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Unrecognized var /"
                                           + name + "/");
    }

    /**
     * Return the parent set of {@code var} in the fixed structure of net B:
     * A and E are roots; A -> B, A -> C, {B, C} -> D, E -> F, F -> G, G -> H.
     *
     * <p>The parents are looked up by name in {@code nodes}. This requires
     * that nodes are being constructed in topological order, i.e.
     * {@code nodes} must already contain the node objects of {@code var}'s
     * parents.
     *
     * @throws IllegalArgumentException if {@code var} is not one of A-H,
     *         or if a required parent is missing from {@code nodes}
     */
    public static BayesNetNodeSet NetB_VarParents(BayesNetNodeSet nodes,
                                                  FunctionVariable var) {
        String v = var.toString();

        BayesNetNode[] parents = null;

        if (v.equals("A") || v.equals("E")) {
            return BayesNetNodeSet.EMPTY_BAYES_NET_VARIABLE_SET;
        } else if (v.equals("B") || v.equals("C")) {
            parents = new BayesNetNode[] {getNodeByName(nodes, "A")};
        } else if (v.equals("D")) {
            parents = new BayesNetNode[] {getNodeByName(nodes, "B"),
                                          getNodeByName(nodes, "C")};
        } else if (v.equals("F")) {
            parents = new BayesNetNode[] {getNodeByName(nodes, "E")};
        } else if (v.equals("G")) {
            parents = new BayesNetNode[] {getNodeByName(nodes, "F")};
        } else if (v.equals("H")) {
            parents = new BayesNetNode[] {getNodeByName(nodes, "G")};
        }

        if (parents == null) {
            throw new IllegalArgumentException("Unrecognized var /"
                                               + var.toString() + "/");
        }
        return new BayesNetNodeSet(parents);
    }

    /**
     * Build net B: eight boolean variables A-H wired as described in
     * {@link #NetB_VarParents}, each with an all-zero CPT over
     * (parents + itself).
     */
    public static BayesNet getWipedNetB() {
        // Important: this array is in topological order, with all of
        // a given node's parents listed EARLIER than the node itself.
        String[] varorder = new String[] {"A", "E", "B", "C",
                                          "F", "D", "G", "H"};

        BayesNetNodeSet nodes = BayesNetNodeSet.EMPTY_BAYES_NET_VARIABLE_SET;

        for (int i = 0; i < varorder.length; i++) {
            FunctionVariable var =
                new FunctionVariable(varorder[i], Domain.BOOLEAN_DOMAIN);
            BayesNetNodeSet parents = NetB_VarParents(nodes, var);
            FunctionVariableSet cptvars =
                FunctionVariableSet.union(parents.getFunctionVariableSet(),
                                          var);

            // Zero-filled by Java array initialization.
            double[] cptvals = new double[cptvars.cartesianProductSize()];

            Function cpt = new Function(cptvars, cptvals);

            BayesNetNode newnode = new BayesNetNode(var, parents, cpt);

            nodes = new BayesNetNodeSet(nodes, newnode);
        }

        return new BayesNet(nodes);
    }

    /**
     * Render the CPT of every node as a centered LaTeX {@code tabular}.
     * Each table has a title row, a bold header row (one column per
     * parent, then one P(var=value) column per value of the node's
     * variable), and one row per assignment of the parents. A parentless
     * node gets a single row.
     *
     * @param nodes the nodes to render, in iteration order
     * @return the concatenated LaTeX source of all tables
     */
    public static String latexNet (BayesNetNodeSet nodes) {
        StringBuffer answer = new StringBuffer();

        for (Iterator i = nodes.iterator(); i.hasNext(); ) {
            BayesNetNode n = (BayesNetNode) i.next();
            FunctionVariableSet parentvars =
                n.parents.getFunctionVariableSet();

            // Table format declaration
            answer.append("\\begin{center}\n");
            answer.append("\\begin{tabular}{")
                  .append(latexColumnSpec(n, parentvars))
                  .append("}\n");

            // Title row spanning every column.
            answer.append(TBL_ROW_SEP);
            answer.append("\\multicolumn{"
                          + (parentvars.size() + n.var.domain.size())
                          + "}{|c|}{"
                          + "Estimated CPT for \\textbf{" + n.var + "}}"
                          + TBL_END_ROW);

            // Header line
            answer.append(TBL_ROW_SEP);
            answer.append(latexHeaderRow(n, parentvars)).append(TBL_END_ROW);
            answer.append(TBL_ROW_SEP).append(TBL_ROW_SEP);

            // Body: one row per assignment of the parents.
            if (parentvars.size() > 0) {
                for (Iterator ipar = parentvars.assignmentIterator();
                     ipar.hasNext(); ) {
                    Assignment parAss = (Assignment) ipar.next();
                    answer.append(latexCPTRow(n, parAss)).append(TBL_ROW_SEP);
                }
            } else {
                // Build an empty assignment: subtracting an assignment
                // from itself leaves no variables assigned.
                Assignment parAss =
                    new Assignment(new FunctionVariableSet(n.var));
                parAss = parAss.subtract(parAss);

                answer.append(latexCPTRow(n, parAss)).append(TBL_ROW_SEP);
            }

            answer.append("\\end{tabular}\n\\end{center}\n");
        }
        return answer.toString();
    }

    /**
     * Build the {@code tabular} column spec for a node's CPT table. It has
     * one centered column per parent, a double rule, then one centered
     * column per value of the node's variable.
     */
    private static String latexColumnSpec (BayesNetNode n,
                                           FunctionVariableSet parentvars) {
        StringBuffer spec = new StringBuffer();
        for (int j = 0; j < parentvars.size(); j++) {
            spec.append(j == 0 ? "|c|" : "c|");
        }
        spec.append("|");
        for (int j = 0; j < n.var.domain.size(); j++) {
            spec.append("c|");
        }
        return spec.toString();
    }

    /**
     * Build the bold header row (without the row terminator). It contains
     * the parent variable names, then "P(var=value)" for each value of the
     * node's variable, all joined by the column separator.
     */
    private static String latexHeaderRow (BayesNetNode n,
                                          FunctionVariableSet parentvars) {
        StringBuffer line = new StringBuffer();
        boolean needSeparator = false;
        for (int j = 0; j < parentvars.size(); j++) {
            if (needSeparator) {
                line.append(TBL_COL_SEP);
            }
            needSeparator = true;
            line.append("\\textbf{" + parentvars.getVariable(j) + "}");
        }
        for (int j = 0; j < n.var.domain.size(); j++) {
            if (needSeparator) {
                line.append(TBL_COL_SEP);
            }
            needSeparator = true;
            line.append("\\textbf{P("
                        + n.var + "="
                        + latexComparable(n.var.domain.getValue(j))
                        + ")}");
        }
        return line.toString();
    }

    /**
     * Render one CPT row: the parent values (see {@link #latexAssignment}),
     * followed by the node's CPT value for each value of its variable.
     * Probabilities use the US-locale pattern "0.00####". The row ends
     * with the LaTeX row terminator.
     *
     * @param n       node whose CPT is evaluated
     * @param parents assignment of the node's parents (may be empty)
     */
    public static String latexCPTRow (BayesNetNode n,
                                      Assignment parents) {
        // Created per call: NumberFormat instances are not thread-safe.
        NumberFormat form = NumberFormat.getInstance(Locale.US);
        if (form instanceof DecimalFormat) {
            ((DecimalFormat) form).applyPattern("0.00####");
        }

        String line = latexAssignment(parents);
        for (int j = 0; j < n.var.domain.size(); j++) {
            double prob =
                n.cpt.evaluate(new Assignment(parents,
                                              n.var,
                                              n.var.domain.getValue(j)));
            // No leading separator when there are no parent cells yet.
            if (!line.equals("")) {
                line += TBL_COL_SEP;
            }
            line += form.format(prob);
        }

        return line + TBL_END_ROW;
    }

    /**
     * Render the values of an assignment as bold LaTeX cells, in the
     * order of {@code a.variables}, joined by the column separator.
     * Returns "" for an empty assignment.
     */
    public static String latexAssignment (Assignment a) {
        String tex = "";
        for (int i = 0; i < a.variables.size(); i++) {
            if (i != 0) {
                tex += TBL_COL_SEP;
            }
            FunctionVariable v = a.variables.getVariable(i);
            tex += "\\textbf{" + latexComparable(a.getAssignedValue(v)) + "}";
        }
        return tex;
    }

    /**
     * Render a domain value for LaTeX. {@code ComparableBoolean.TRUE}
     * becomes "1", {@code ComparableBoolean.FALSE} becomes "0", and
     * anything else uses {@code toString()}.
     */
    public static String latexComparable (Comparable b) {
        if (b.equals(ComparableBoolean.TRUE)) {
            return "1";
        } else if (b.equals(ComparableBoolean.FALSE)) {
            return "0";
        } else {
            return b.toString();
        }
    }

    // This doesn't belong in this class, but it's not worth starting
    // a new utils class (yet)
    /**
     * Write {@code s} to {@code filename}, replacing any existing
     * contents, using {@link FileWriter} (platform default charset). On
     * success, print a confirmation to standard output. On any
     * {@link IOException}, print the error to standard output and
     * terminate the JVM with exit status 1.
     */
    public static void DumpToFile(String s, String filename) {
        try {
            File outputFile = new File(filename);
            FileWriter out = new FileWriter(outputFile);
            // Close in finally so the file handle is released even if
            // write() throws.
            try {
                out.write(s);
            } finally {
                out.close();
            }
            System.out.println("Data dumped to file " + filename);
        }
        catch (IOException e) {
            System.out.println("ERROR: " + e);
            System.exit(1);
        }
    }

}
