/*
 * HomologHMM 1.05
 * (c) Lukas Käll
 * Distributable under GPL license.
 * See terms of license at gnu.org.
 */
package se.ki.cgb.labeledhmm;

import java.util.*;
import java.io.*;
import java.text.*;
import org.biojava.bio.dp.*;
import org.biojava.bio.dp.onehead.*;
import org.biojava.bio.symbol.*;

/** This class holds and calculates the posterior label probabilities of an sequence or alignment.
 * @author Lukas.Kall@cgb.ki.se
 * @version $Revision: 1.6 $
 */
public class LabelProbMatrix extends SingleDPMatrix {

protected HashMap label2states;
protected int[] state2labelIx;
protected Set labels;
//protected State[] states;
//protected SymbolList[] symList;

//public double[][] scores;

public static final long serialVersionUID = 9080834;

  public LabelProbMatrix(DP dp, SymbolList seq) {
	super(dp,seq);
    mapLabels(dp.getStates());
	scores = new double[seq.length()][labels.size()];
  }
  
  public void setValues(FBMatrix fb) {
	int seqlen = symList[0].length();
	// Calculate p(label =lab | sequence) = sum_j(with label =lab) (p( state=j | sequence))
	// = sum(forward*backward/p(sequence)
	for (int i = 0;i<seqlen;i++) {
		int type=0;
		double [] gamma = fb.getGamma(i);
		for (Iterator it = labels.iterator(); it.hasNext(); type++) {
			String lab = (String) it.next();
			if (gamma == null) {
				scores[i][type] = 0.0;					
//				scores[i][type] = 1.0/labels.size();					
			} else {
				Iterator indexes = ((HashSet) label2states.get(lab)).iterator();
				double sum = 0.0;
				while (indexes.hasNext()) {
					int j = ((Integer)indexes.next()).intValue();
					sum += gamma[j];
				}
				scores[i][type] = sum;
			}
			if (scores[i][type]>1.0) {
				System.err.println("Error detected at pos=" + i +", label=" + lab + ", Symbol is " + symList[0].symbolAt(i+1).getName()+", prob=" + scores[i][type]);
				scores[i][type] = 1.0;											
			}
			if (scores[i][type]<0) {
				System.err.println("Error detected at " + i +", " + type + ", Symbol is " + symList[0].symbolAt(i+1).getName());
				scores[i][type] = 0;											
			}
		}
	}  	
  }
  
	public LabelProbMatrix(DP dp, SingleDPMatrix forward,SingleDPMatrix backward) {
		super(dp,forward.symList()[0]);
		initiateMatrix(forward,backward);
	}

	public void initiateMatrix(SingleDPMatrix forward,SingleDPMatrix backward) {
		mapLabels(forward.states());
		
//		states = forward.states();
//		symList = forward.symList();
		int seqlen = symList[0].length();
		double Px = forward.getScore();
		// Calculate p(label =lab | sequence) = sum_j(with label =lab) (p( state=j | sequence))
		// = sum(forward*backward/p(sequence)
		scores = new double[seqlen][labels.size()];
		int firstAA=0,lastAA=seqlen-1;
		while (symList[0].symbolAt(firstAA+1).getName().equals("[]")) firstAA++;
		while (symList[0].symbolAt(lastAA+1).getName().equals("[]")) lastAA--;
		for (int i = 0;i<seqlen;i++) {
			int type=0;
			for (Iterator it = labels.iterator(); it.hasNext(); type++) {
				String lab = (String) it.next();
				if (i<firstAA || i>lastAA ||  symList[0].symbolAt(i+1).getName().equals("[]")) {
					scores[i][type] = - Math.log(labels.size());					
				} else if (i>=1 && i<=seqlen-1) {
					Iterator indexes = ((HashSet) label2states.get(lab)).iterator();
					double unlog =0.0;
					while (indexes.hasNext()) {
						int j = ((Integer)indexes.next()).intValue();
						unlog += Math.exp(forward.scores[i][j] + backward.scores[i][j] - Px);
						if (Double.isInfinite(unlog) && unlog >0) {
							unlog=0;
						}
					}
					scores[i][type] = Math.log(unlog);
				}
				if (Double.isNaN(scores[i][type])) {
					System.err.println("NaN detected at " + i +", " + type + ", Symbol is " + symList[0].symbolAt(i).getName());
					scores[i][type] = Double.NEGATIVE_INFINITY;											
				}
			}
		}
/*
		// Interpolate gaps in matrix
		int lastBefore = firstAA;
		int pos = lastBefore+1,firstAfter=0;
		while (pos<lastAA) {
			if (!symList[0].symbolAt(pos).getName().equals("[]")) {
				if (pos != lastBefore +1) {
					firstAfter =pos;
					// Here we have a gap that neads to be filled in
					for (int i = lastBefore+1;i<firstAfter;i++) {
						for (int type =0; type<labels.size(); type++) {
							scores[i][type] = Math.log((Math.exp(scores[lastBefore][type]) * (i - lastBefore)
									+ Math.exp(scores[firstAfter][type])* (firstAfter - i))
									/ (firstAfter - lastBefore));
						}
					}					
				}
				lastBefore = pos;
			}
			pos++;
		}
	*/
	}
	
	public void mapLabels(State[] states) {
		String lab = "";
		label2states = new HashMap();
		for (int j=0; j<states.length; j++) {
			try { 
				lab = (String) states[j].getAnnotation().getProperty("Label"); 
				if (!label2states.containsKey(lab)) {
					HashSet hs = new HashSet();
					hs.add(new Integer(j));
					label2states.put(lab,hs);
				} else {
					HashSet hs = (HashSet) label2states.get(lab);
					hs.add(new Integer(j));
				}
			}
			catch (Exception e) { ; }
		}
		labels = label2states.keySet();
		state2labelIx = new int[states.length];
		int ix = 0, stateIx = 0;
		for (Iterator it = labels.iterator(); it.hasNext(); ix++) {
			lab = (String) it.next();
			HashSet hs = (HashSet) label2states.get(lab);
			for (Iterator it2 = hs.iterator(); it2.hasNext();) {
				stateIx = ((Integer) it2.next()).intValue();
				state2labelIx[stateIx] = ix;
			}
		}
	}

    public void printPLP(PrintStream os,boolean forFirstSeq) {
    	DecimalFormat form = new DecimalFormat("0.00E00");
    	String fieldSpace ="          ";
    	String lab;
 
    	os.print("#" + fieldSpace.substring(1));
		for (Iterator it = labels.iterator(); it.hasNext(); ) {
			lab = (String) it.next();
			os.print(lab + fieldSpace.substring(1));
		}
		os.println();
		String out;
		int pos = 0;
		for (int sym=0;sym<symList[0].length(); sym++) {
			if (forFirstSeq && symList[0].symbolAt(sym+1).getName().length()<3) continue;
			pos++;
			out = Integer.toString(pos);
			os.print(out + fieldSpace.substring(out.length()));
			for (int labIx=0;labIx<labels.size();labIx++){
				out = form.format(scores[sym][labIx]);
				os.print(out + fieldSpace.substring(out.length()));
			}
			os.println();
		}		
    }

	public void weight(double weight) {
		for (int sym=0;sym<symList[0].length(); sym++) {
			for (int lab=0;lab<labels.size();lab++){
				scores[sym][lab] *=weight;
			}
		}
	}
	
	public void join(LabelProbMatrix other, double weight) {
		for (int sym=0;sym<symList[0].length(); sym++) {
			for (int lab=0;lab<labels.size();lab++){
				scores[sym][lab] += other.scores[sym][lab]*weight;
				if (Double.isNaN(scores[sym][lab])) {
					System.err.println("NaN detected at " + sym +", " + lab + ", Symbol is " + symList[0].symbolAt(sym+1).getName());
					scores[sym][lab] = 0.0;					
				}							
			}
		}
	}

	public void normalize() {
		for (int sym=0;sym<symList[0].length(); sym++) {
			double norm = 0.0; 
			for (int lab=0;lab<labels.size();lab++){
				norm += scores[sym][lab];
			}
			if (norm>0.0) {
				for (int lab=0;lab<labels.size();lab++){
					scores[sym][lab] /= norm;
				}
			}
		}
	}

	/**
	 * @return
	 */
	public HashMap getLabel2states() {
		return label2states;
	}
	
	/**
	 * @return
	 */
	public Set getLabels() {
		return labels;
	}
	
	/**
	 * @return
	 */
	public State[] getStates() {
		return states;
	}
	
	/**
	 * @return
	 */
	public SymbolList[] getSymList() {
		return symList;
	}
}
