import edu.mit.six825.bn.functiontable.*;
import edu.mit.six825.bn.bayesnet.*;
import edu.mit.six825.bn.inputs.*;
import java.io.*;

import java.util.Locale;
import java.text.NumberFormat;
import java.text.DecimalFormat;

import java.util.*;

/**
 * Driver for the part 3 Bayes-net structure-learning experiments.
 *
 * <p>Runs structure searches over combinations of:
 * <ul>
 *   <li>net: "a", plus "b" unless {@link #FASTDEBUGGING} is set;</li>
 *   <li>initial structure: "given", "unconnected", "connected", or random;</li>
 *   <li>search type: greedy or first-ascent;</li>
 *   <li>score type: BIC or AIC.</li>
 * </ul>
 * Each search reads {@code data/data<NET>.dat}. Each run produces one CSV line,
 * and these lines are appended to a log file.
 *
 * @author nocturne
 */
public class part3 {
    /** When true, only net "a" is used, so debugging runs finish quicker. */
    public static boolean FASTDEBUGGING = false;

    // Encodings used in the search-type and score-type lists passed to
    // doExperimentBatch / compareMethods.
    private static final Boolean GREEDY       = Boolean.TRUE;
    private static final Boolean FIRST_ASCENT = Boolean.FALSE;
    private static final Boolean BIC          = Boolean.TRUE;
    private static final Boolean AIC          = Boolean.FALSE;

    public static void main(String[] args) {
        runExperiments();
        // Alternative entry point: repeat the batch forever from fresh random
        // initial structures, appending to part3-random.log.
        // loopRandomExperiments();
    }

    /**
     * Runs every combination of net, initial structure, search type and
     * score type once, writing .dot/.tex figure files for each run. The
     * results are appended to {@code part3-basic.log}.
     */
    public static void runExperiments() {
        boolean writeFiles = true;

        List initialConfigs = new ArrayList();
        initialConfigs.add("given");
        initialConfigs.add("unconnected");
        initialConfigs.add("connected");
        // Each "randomN" entry falls into the random branch of compareMethods
        // and therefore gets its own freshly generated random structure.
        initialConfigs.add("random1");
        initialConfigs.add("random2");
        initialConfigs.add("random3");

        List results = doExperimentBatch(experimentNets(), initialConfigs,
                                         allSearchTypes(), allScoreTypes(),
                                         writeFiles);

        System.out.println("");
        logResults(results, "part3-basic.log");
    }

    /**
     * Repeats the experiment batch forever. Each iteration starts from a new
     * random initial structure and appends its results to
     * {@code part3-random.log}. No figure files are written. This method
     * never returns normally; logResults may terminate the JVM on I/O errors.
     */
    public static void loopRandomExperiments() {
        // The batch only iterates over these lists, never modifies them, so
        // they can be built once outside the loop.
        List nets = experimentNets();
        List initialConfigs = new ArrayList();
        initialConfigs.add("random");
        List searchTypes = allSearchTypes();
        List scoreTypes = allScoreTypes();

        while (true) {
            List results = doExperimentBatch(nets, initialConfigs, searchTypes,
                                             scoreTypes, false /* writeFiles */);
            logResults(results, "part3-random.log");
        }
    }

    /** Returns a new list of net names: "a", plus "b" unless FASTDEBUGGING. */
    private static List experimentNets() {
        List nets = new ArrayList();
        nets.add("a");
        if (!FASTDEBUGGING) {
            nets.add("b");
        }
        return nets;
    }

    /** Returns a new list holding both search types: greedy, then first-ascent. */
    private static List allSearchTypes() {
        List searchTypes = new ArrayList();
        searchTypes.add(GREEDY);
        searchTypes.add(FIRST_ASCENT);
        return searchTypes;
    }

    /** Returns a new list holding both score types: BIC, then AIC. */
    private static List allScoreTypes() {
        List scoreTypes = new ArrayList();
        scoreTypes.add(BIC);
        scoreTypes.add(AIC);
        return scoreTypes;
    }

    /**
     * Appends the results to the given log file.
     *
     * <p>The output is a comment line with the current date and a comment line
     * with the host name, followed by one {@link ExpResult#toCSV()} line per
     * result. If opening or writing the file fails, the error is printed and
     * the JVM exits with status 1.
     *
     * @param results  list of {@link ExpResult}
     * @param filename log file to append to (created if missing)
     */
    public static void logResults(List results, String filename) {
        PrintWriter pw = null;
        try {
            pw = new PrintWriter(new FileWriter(filename, true /* append */));

            pw.println("# the following was logged at "
                       + (new Date()).toString());
            pw.println("# on host " + localHostName());

            for (Iterator rit = results.iterator(); rit.hasNext();) {
                pw.println(((ExpResult) rit.next()).toCSV());
            }
            pw.flush();
            // PrintWriter swallows IOExceptions; surface them the same way
            // as a failure to open the file.
            if (pw.checkError()) {
                throw new IOException("error writing to " + filename);
            }
            System.out.println("Appended exp results to " + filename);
        } catch (IOException e) {
            System.out.println("IO exception: " + e);
            System.exit(1);
        } finally {
            if (pw != null) {
                pw.close();
            }
        }
    }

    /**
     * Returns the local host name, or "unknown" if it cannot be resolved.
     * A failed lookup must not cause already-computed results to be discarded.
     */
    private static String localHostName() {
        try {
            return java.net.InetAddress.getLocalHost().getHostName();
        } catch (java.net.UnknownHostException e) {
            return "unknown";
        }
    }

    /**
     * Runs {@link #compareMethods} for every combination of net and initial
     * structure. Each compareMethods call covers every combination of the
     * given search types and score types.
     *
     * @param nets           net names (e.g. "a", "b")
     * @param initialConfigs initial structure names ("given", "unconnected",
     *                       "connected"; anything else means random)
     * @param searchTypes    Booleans: TRUE == greedy, FALSE == first-ascent
     * @param scoreTypes     Booleans: TRUE == BIC, FALSE == AIC
     * @param createFiles    whether to write .dot/.tex figure files
     * @return list of {@link ExpResult}, in iteration order
     */
    public static List /* [ExpResult] */ doExperimentBatch(List nets,
                                                           List initialConfigs,
                                                           List searchTypes,
                                                           List scoreTypes,
                                                           boolean createFiles) {
        List allresults = new ArrayList();

        for (Iterator netIt = nets.iterator(); netIt.hasNext();) {
            String netStr = (String) netIt.next();

            for (Iterator icIt = initialConfigs.iterator(); icIt.hasNext();) {
                String icStr = (String) icIt.next();

                /* If we wanted to compare how our search methods
                 * perform on a given initial random structure, this
                 * would be the place to do it.
                 */
                allresults.addAll(compareMethods(netStr, icStr,
                                                 searchTypes, scoreTypes,
                                                 createFiles));
            }
        }

        return allresults;
    }

    /**
     * Runs one experiment for every combination of search type and score
     * type. Every run for one net uses the same initial structure type.
     *
     * <p>For a random initial structure (any icStr other than "given",
     * "unconnected" or "connected"), the first run generates a random
     * structure with k = number of nodes. Later runs regenerate it from the
     * same seed, so all combinations start from the same network.
     *
     * @param netStr      net name; also selects {@code data/data<NET>.dat}
     * @param icStr       initial structure name
     * @param searches    Booleans: TRUE == greedy, FALSE == first-ascent
     * @param scores      Booleans: TRUE == BIC, FALSE == AIC
     * @param createFiles whether to write .dot/.tex figure files
     * @return list of {@link ExpResult}
     */
    public static List /* [ExpResult] */ compareMethods(String netStr,
                                                        String icStr,
                                                        List searches,
                                                        List scores,
                                                        boolean createFiles) {
        List results = new ArrayList();
        // Seed of the first random structure, reused by later combinations.
        boolean haveSeed = false;
        long structSeed = 0L;

        for (Iterator searchIt = searches.iterator(); searchIt.hasNext();) {
            boolean doGreedy = ((Boolean) searchIt.next()).booleanValue();

            for (Iterator scoreIt = scores.iterator(); scoreIt.hasNext();) {
                boolean useBIC = ((Boolean) scoreIt.next()).booleanValue();

                // A fresh network and data set per run. BNData is built from
                // this network's variable set, so (presumably) it cannot be
                // shared across BNStructure instances.
                BNStructure bns =
                    new BNStructure(BNUtils.getWipedNet(netStr));
                BNData data =
                    new BNData(bns.bn.nodes.getFunctionVariableSet());
                // Locale.US: a file path must not depend on the default locale.
                data.readFile("data/data" + netStr.toUpperCase(Locale.US)
                              + ".dat");

                if (icStr.equals("given")) {
                    // do nothing, use the structure as is
                } else if (icStr.equals("unconnected")) {
                    bns.makeUnconnected();
                } else if (icStr.equals("connected")) {
                    bns.makeFullyConnected();
                } else {
                    // Assume this means random, use k=n
                    int k = bns.bn.nodes.size();
                    if (!haveSeed) {
                        // Whip up a brand new random network, and
                        // remember how to whip it up again later
                        structSeed = bns.makeRandomlyConnected(k);
                        haveSeed = true;
                    } else {
                        bns.makeRandomConnWithSeed(k, structSeed);
                    }
                }

                results.add(doExperiment(netStr, icStr, bns, data,
                                         doGreedy, useBIC, createFiles));
            }
        }
        return results;
    }

    /**
     * Runs one structure search and records its result.
     *
     * <p>The search modifies bns; its final structure is what gets written as
     * the solution. Only bns.useBIC is restored, and it is restored even if
     * the search throws. When createFiles is set, .dot files for the initial
     * and final structures and a .tex dump of the final net are written under
     * {@code figures/}. A progress line is always printed to stdout.
     *
     * @param netStr                      net name (for labels and file names)
     * @param icStr                       initial structure name (same)
     * @param bns                         structure to search from
     * @param data                        data set to score against
     * @param greedyRatherThanFirstAscent true for greedy, false for first-ascent
     * @param useBIC                      true for BIC scoring, false for AIC
     * @param createFiles                 whether to write figure files
     * @return the experiment's result
     */
    public static ExpResult doExperiment (String netStr, String icStr,
                                          BNStructure bns, BNData data,
                                          boolean greedyRatherThanFirstAscent,
                                          boolean useBIC,
                                          boolean createFiles) {
        boolean oldUseBIC = bns.useBIC;
        bns.useBIC = useBIC;
        try {
            ExpResult exp = new ExpResult();
            exp.greedy   = greedyRatherThanFirstAscent;
            exp.usedBIC  = useBIC;
            exp.dataname = netStr;
            exp.initname = icStr;

            String searchStr   = greedyRatherThanFirstAscent ? "gr" : "fa";
            String scoreStr    = useBIC ? "bic" : "aic";
            String searchWords =
                greedyRatherThanFirstAscent ? "Greedy" : "First-Ascent";
            String netTitle    = "Net " + netStr.toUpperCase(Locale.US);

            // NOTE(review): the "part4" prefix in a part3 driver looks like a
            // copy-paste leftover; confirm before renaming the output files.
            String basename = "figures/part4-net" + netStr + "-" + icStr;
            String filename = basename + "-" + searchStr + "-" + scoreStr;

            if (createFiles) {
                bns.printDotFile(data, basename + "-initial",
                                 netTitle + " Initial structure");
            }

            long startTime = System.currentTimeMillis();
            if (greedyRatherThanFirstAscent) {
                exp.r = bns.greedyStructureSearch(data, 0);
            } else {
                exp.r = bns.firstAscentStructureSearch(data, 0);
            }
            exp.ms = System.currentTimeMillis() - startTime;

            System.out.println(filename + ": required " + exp.r.steps +
                               " steps, " + exp.ms + "ms");

            if (createFiles) {
                bns.printDotFile(data, filename,
                                 netTitle + " " + searchWords +
                                 " Solution (initial structure " + icStr +
                                 ")");
                BNUtils.DumpToFile(BNUtils.latexNet(bns.bn.nodes),
                                   filename + ".tex");
            }

            return exp;
        } finally {
            bns.useBIC = oldUseBIC;
        }
    }

    /** Outcome of a single structure-search experiment. */
    public static class ExpResult {
        public boolean greedy;  // greedyRatherThanFirstAscent
        public boolean usedBIC; // was BIC scoring used?

        public BNStructure.SearchResult r; // search outcome; set by doExperiment

        public String initname; // initial structure name
        public String dataname; // net name
        public long   ms;       // milliseconds elapsed

        public ExpResult () {
        }

        /**
         * Returns the fields as one line of ",\t"-separated values, in this order:
         * net, initial structure, search type, score type, score, ms, steps,
         * dimension, structural penalty. Requires r to be set.
         */
        public String toCSV() {
            String searchtype = (greedy)  ? "greedy" : "first-ascent";
            String scoretype  = (usedBIC) ? "BIC"    : "AIC";

            String sep = ",\t";
            return (dataname
                    + sep + initname
                    + sep + searchtype
                    + sep + scoretype
                    + sep + r.score
                    + sep + ms
                    + sep + r.steps
                    + sep + r.dimension
                    + sep + r.structuralPenalty);
        }

        /** Human-readable one-line summary; the score uses the pattern 0.00####. */
        public String toString() {
            NumberFormat form = NumberFormat.getInstance(Locale.US);
            if (form instanceof DecimalFormat) {
                ((DecimalFormat) form).applyPattern("0.00####");
            }

            String searchtype = (greedy)  ? "greedy" : "first-ascent";
            String scoretype  = (usedBIC) ? "BIC"    : "AIC";

            return ("Data " + dataname
                    + ", initial structure " + initname
                    + " (" + searchtype + " search) "
                    + scoretype + " = " + form.format(r.score)
                    + ", in " + ms + " ms, "
                    + r.steps + " steps");
        }
    }
}
