import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
import edu.mit.six825.bn.inputs.*;

import java.util.*;

/**
 * Driver for part 2: small demonstrations of the Bayes-net structure
 * operators, plus a batch runner that performs greedy structure search on
 * each sample network from several initial structures.
 *
 * @author nocturne
 */
public class part2 {

    /** Networks used by {@link #runExperiments()}; net "x" reads data/dataX.dat. */
    private static final String[] EXPERIMENT_NETS = {"a", "b"};

    /**
     * Initial structures tried by {@link #runExperiments()}. Every "randomN"
     * entry uses the same generator setting; the suffix only distinguishes
     * repeated runs (and their output file names).
     */
    private static final String[] INITIAL_CONFIGS = {
        "given", "unconnected", "connected", "random1", "random2", "random3"
    };

    /**
     * Entry point. Runs the batch experiments; uncomment a demo call to run
     * that demonstration instead.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // demoOps();
        // demoOpUse();
        // demoSearchStep();
        // demoInitialNetworks();
        // demoGreedySearch();
        runExperiments();
    }

    /**
     * Builds a data set over the variables of {@code bns} and fills it from
     * the data file for network {@code netName}.
     *
     * @param bns     structure whose node variables define the data columns
     * @param netName network name as passed to {@code BNUtils.getWipedNet},
     *                e.g. "b" (which reads data/dataB.dat)
     * @return the loaded data set
     */
    private static BNData loadData(BNStructure bns, String netName) {
        BNData data = new BNData(bns.bn.nodes.getFunctionVariableSet());
        data.readFile("data/data" + netName.toUpperCase(Locale.ROOT) + ".dat");
        return data;
    }

    /**
     * Performs a single search step on net b. It prints all legal ops,
     * applies the op returned by {@code getNextOp}, and reports the score
     * before and after. Dot files are written for the start and end
     * structures.
     */
    public static void demoSearchStep() {
        BNStructure bns = new BNStructure(BNUtils.getWipedNet("b"));
        BNData data = loadData(bns, "b");

        bns.printDotFile(data, "demoSearchStep-start",
                         "Given structure");

        double origscore = bns.score(data);
        BNUtils.wipeCPTs(bns.bn); // reset CPTs after scoring; possibly unnecessary

        System.out.println("These are all the legal ops at this point:");
        bns.showAllOps();

        printLine();
        BNOp nextop = bns.getNextOp(data, 0, false);
        bns.applyOp(nextop);
        double newscore = bns.score(data);
        System.out.println("Best possible op is " + nextop);
        System.out.println("The score improves from " + origscore + " to "
                           + newscore);

        bns.printDotFile(data, "demoSearchStep-end",
                         "After one op");
    }

    /**
     * Exercises applying and unrolling ops on a structure over nodes A-D of
     * net b, then applies two random ops and prints them.
     */
    public static void demoOpUse() {
        BayesNet tmpbn = BNUtils.getWipedNet("b");

        BayesNetNode nodeA = tmpbn.nodes.getNode("A");
        BayesNetNode nodeB = tmpbn.nodes.getNode("B");
        BayesNetNode nodeC = tmpbn.nodes.getNode("C");
        BayesNetNode nodeD = tmpbn.nodes.getNode("D");

        BayesNetNodeSet tmpset =
            new BayesNetNodeSet(new BayesNetNode[] {nodeA, nodeB,
                                                    nodeC, nodeD});

        BNStructure bns = new BNStructure(tmpset);

        System.out.println("Testing structure moves...");
        BNOp op = new BNOp(MoveType.REVERSE, nodeA, nodeB);
        BNOp op2 = new BNOp(MoveType.REVERSE, nodeA, nodeB);

        // Two separately constructed ops are never the same reference, so
        // '==' could never be true here. equals() asks whether BNOp treats
        // them as the same move; it falls back to identity if BNOp does not
        // override equals.
        if (op.equals(op2)) {
            System.out.println("Look ==");
        }

        bns.applyOp(op);
        bns.applyOp(op2);
        bns.unrollOp(op2);
        bns.unrollOp(op);

        // BNUtils.printNet(bns.bn.nodes);
        op = bns.getRandomOp();
        System.out.println("Got random op " + op);
        bns.applyOp(op);
        // BNUtils.printNet(bns.bn.nodes);

        op2 = bns.getRandomOp();
        System.out.println("Got random op " + op2);
        bns.applyOp(op2);
        // BNUtils.printNet(bns.bn.nodes);
    }

    /** Prints a horizontal separator line to standard output. */
    public static void printLine() {
        System.out.println("======================================");
    }

    /**
     * Constructs several ops on net b and reports whether each is valid.
     * It then makes B a parent of C and re-checks the last op (reverse B-D).
     */
    public static void demoOps() {
        BayesNet bn = BNUtils.getWipedNet("b");

        BayesNetNode nodeA = bn.nodes.getNode("A");
        BayesNetNode nodeB = bn.nodes.getNode("B");
        BayesNetNode nodeC = bn.nodes.getNode("C");
        BayesNetNode nodeD = bn.nodes.getNode("D");

        BNOp op = new BNOp(MoveType.ADD, nodeA, nodeB);
        showOp(op);
        op = new BNOp(MoveType.DELETE, nodeA, nodeB);
        showOp(op);
        op = new BNOp(MoveType.REVERSE, nodeA, nodeB);
        showOp(op);
        op = new BNOp(MoveType.REVERSE, nodeB, nodeD);
        showOp(op);
        System.out.println("Making B a parent of C");
        nodeC.parents = new BayesNetNodeSet(nodeC.parents, nodeB);
        showOp(op);
    }

    /**
     * Prints {@code op} together with whether {@link BNOp#isValid()} accepts it.
     *
     * @param op the op to describe
     */
    public static void showOp(BNOp op) {
        if (op.isValid()) {
            System.out.println("Got valid move " + op);
        } else {
            System.out.println("Got INvalid move " + op);
        }
    }

    /**
     * Writes dot files for net b in its given structure and after each of
     * the initial-structure generators (unconnected, fully connected, and
     * random with n edges).
     */
    public static void demoInitialNetworks() {
        BNStructure bns = new BNStructure(BNUtils.getWipedNet("b"));
        BNData data = loadData(bns, "b");

        bns.printDotFile(data, "demoInitialNetworks-start",
                         "Given structure");

        bns.makeUnconnected();
        bns.printDotFile(data, "demoInitialNetworks-unconnected",
                         "Unconnected");

        bns.makeFullyConnected();
        bns.printDotFile(data, "demoInitialNetworks-connected",
                         "Fully connected");

        bns.makeRandomlyConnected(bns.bn.nodes.size());
        bns.printDotFile(data, "demoInitialNetworks-random-n",
                         "Random (n edges)");
    }

    /**
     * Runs greedy structure search on net b, starting from the fully
     * connected structure. Dot files are written before and after the search.
     */
    public static void demoGreedySearch() {
        BNStructure bns = new BNStructure(BNUtils.getWipedNet("b"));
        BNData data = loadData(bns, "b");

        bns.makeFullyConnected();
        bns.printDotFile(data, "demoGreedySearch-start",
                         "Starting configuration");

        bns.greedyStructureSearch(data, 0);
        bns.printDotFile(data, "demoGreedySearch-end",
                         "Output network");
    }

    /**
     * Rewires {@code bns} in place according to a named initial configuration.
     *
     * @param bns   structure to modify
     * @param icStr "given" (leave as loaded), "unconnected", "connected", or
     *              any name starting with "random" (random, k = n edges)
     * @throws IllegalArgumentException if {@code icStr} is none of the above
     */
    private static void applyInitialConfig(BNStructure bns, String icStr) {
        if (icStr.equals("given")) {
            // keep the structure loaded from the network definition
        } else if (icStr.equals("unconnected")) {
            bns.makeUnconnected();
        } else if (icStr.equals("connected")) {
            bns.makeFullyConnected();
        } else if (icStr.startsWith("random")) {
            bns.makeRandomlyConnected(bns.bn.nodes.size());
        } else {
            throw new IllegalArgumentException(
                "Unknown initial configuration: " + icStr);
        }
    }

    /**
     * Runs greedy structure search for every network in
     * {@link #EXPERIMENT_NETS} from every initial structure in
     * {@link #INITIAL_CONFIGS}.
     *
     * <p>For each pair it:
     * <ul>
     *   <li>prints the step count and elapsed wall-clock time of the search;</li>
     *   <li>writes dot files for the initial and final structures under figures/;</li>
     *   <li>writes a LaTeX rendering of the final structure.</li>
     * </ul>
     */
    public static void runExperiments() {
        for (int i = 0; i < EXPERIMENT_NETS.length; i++) {
            String netStr = EXPERIMENT_NETS[i];
            // Locale.ROOT keeps file names/titles stable regardless of the
            // JVM's default locale.
            String netName = netStr.toUpperCase(Locale.ROOT);

            for (int j = 0; j < INITIAL_CONFIGS.length; j++) {
                String icStr = INITIAL_CONFIGS[j];

                // Load a fresh net for every run so runs don't share structure.
                BNStructure bns =
                    new BNStructure(BNUtils.getWipedNet(netStr));
                BNData data = loadData(bns, netStr);
                applyInitialConfig(bns, icStr);

                String filename = "figures/part2-net" + netStr + "-" + icStr;

                bns.printDotFile(data, filename + "-initial",
                                 "Net " + netName +
                                 " Initial structure (" + icStr + ")");

                // Time only the search itself.
                long startTime = System.currentTimeMillis();
                BNStructure.SearchResult r = bns.greedyStructureSearch(data, 0);
                long endTime = System.currentTimeMillis();
                System.out.println(filename + ": required " + r.steps +
                                   " steps, " + (endTime - startTime) + "ms");

                bns.printDotFile(data, filename,
                                 "Net " + netName +
                                 " Solution (initial structure " + icStr + ")");
                BNUtils.DumpToFile(BNUtils.latexNet(bns.bn.nodes),
                                   filename + ".tex");
            }
        }
    }
}
