/*
 * HomologHMM 1.05
 * (c) Lukas Käll
 * Distributable under GPL license.
 * See terms of license at gnu.org.
 */
package se.ki.cgb.labeledhmm;

import se.ki.cgb.anhmmfile.*;
import java.util.*;

import org.biojava.bio.*;
import org.biojava.bio.seq.*;
import org.biojava.bio.symbol.*;
import org.biojava.bio.dist.*;
import org.biojava.bio.dp.*;
import org.biojava.bio.dp.onehead.*;

/**
 * Top-level entry point for the different decoding algorithms
 * for labeled HMMs, such as N-best, Viterbi and optimal accuracy.
 * @author Lukas.Kall@cgb.ki.se
 * @version $Revision: 1.9 $
 */
public class AlignedDP extends SingleDP {
	// The gap symbol, fetched once; getRawEmission compares incoming symbols against it.
	private static final Symbol gap = AlphabetManager.getGapSymbol();
	// Cache used by getRawEmission: Symbol -> double[] of raw (non-log) emission weights,
	// one entry per state below getDotStatesIndex().
	protected final HashMap rawEmissions;
	// Cache filled by forwardTransitionRawScores on its first call; null until then.
	protected double [][] rawTransitionScores = null;
	// First character of each distinct state label, in sorted label order (set by mapLabels).
	char [] labels;
	// For each label index, the indices of the states carrying that label (set by mapLabels).
	int [][] labelIx2states;
	// For each state below getDotStatesIndex(), the index of its label in 'labels' (set by mapLabels).
	int [] state2labelIx;
	// For each state below getDotStatesIndex(), its label character (set by mapLabels).
	char [] state2label;
	// Optional constraints applied during decoding; null means no constraints are applied.
	Constraints constraints = null;

	public static final long serialVersionUID = 9439139;


	/**
	 * Builds the label lookup tables ({@code labels}, {@code labelIx2states},
	 * {@code state2labelIx} and {@code state2label}) from the "Label" annotation
	 * of every state below {@code getDotStatesIndex()}.
	 * <p>
	 * Labels are ordered by their natural String order; only the first character
	 * of each label is stored in {@code labels}. States whose label cannot be
	 * read (the annotation lookup throws, or yields a null or empty value) are
	 * left out of {@code labelIx2states}; their entries in {@code state2labelIx}
	 * and {@code state2label} keep the array defaults.
	 *
	 * @param states the states of the model, indexed as returned by getStates()
	 */
	protected void mapLabels(State[] states) {
		int emitting = this.getDotStatesIndex();
		HashMap label2states = new HashMap();
		for (int j = 0; j < emitting; j++) {
			String lab;
			try {
				lab = (String) states[j].getAnnotation().getProperty("Label");
			} catch (Exception ignored) {
				// Best effort: a state without a readable "Label" property is simply unlabeled.
				continue;
			}
			if (lab == null || lab.length() == 0) {
				// Cannot derive a label character from a missing or empty label.
				continue;
			}
			List members = (List) label2states.get(lab);
			if (members == null) {
				members = new ArrayList();
				label2states.put(lab, members);
			}
			// j only increases, so each member list is already sorted and duplicate free.
			members.add(new Integer(j));
		}

		List labList = new ArrayList(label2states.keySet());
		Collections.sort(labList);
		int labelCount = labList.size();
		labels = new char [labelCount];
		labelIx2states = new int [labelCount][];
		for (int ix = 0; ix < labelCount; ix++) {
			// Look the states up by the full label string, not by its first character,
			// so that multi-character labels resolve to their own states.
			String label = (String) labList.get(ix);
			labels[ix] = label.charAt(0);
			List members = (List) label2states.get(label);
			int [] stateIxs = new int [members.size()];
			for (int k = 0; k < stateIxs.length; k++) {
				stateIxs[k] = ((Integer) members.get(k)).intValue();
			}
			labelIx2states[ix] = stateIxs;
		}

		state2labelIx = new int[emitting];
		state2label = new char[emitting];
		for (int i = 0; i < labelCount; i++) {
			int [] stateIxs = labelIx2states[i];
			for (int k = 0; k < stateIxs.length; k++) {
				state2labelIx[stateIxs[k]] = i;
				state2label[stateIxs[k]] = labels[i];
			}
		}
	}


		/**
		 * Viterbi decoding over a set of aligned sequences.
		 * <p>
		 * For every alignment column the per-state emission score is obtained from
		 * getProfileEmission(column residues, sWeights); transition scores are the
		 * forward transition scores for ScoreType.PROBABILITY. The best path is traced
		 * back and the first character of each visited state's "Label" annotation is
		 * written over the non-gap positions of the first sequence's string.
		 *
		 * @param sequences the aligned sequences; sequences[0] supplies the output template
		 * @param sWeights per-sequence weights handed to getProfileEmission
		 * @return sequences[0].seqString() with non-gap positions replaced by label characters
		 * @throws Exception propagated from the cursor, emission or model calls
		 */
		public String doHomologyViterbi(Sequence[] sequences,double [] sWeights) throws Exception {
			// NOTE(review): unlockModel() is only reached on normal completion; an exception
			// thrown below leaves the model locked.
			lockModel();
	
			AlignmentCursor dpCursor = new AlignmentCursor(getStates(),sequences,1);
	    
			State [] states = getStates();
	
			int [][] transitions = getForwardTransitions();
			double [][] transitionScore = getForwardTransitionScores(ScoreType.PROBABILITY);
	//        for (int i = 0; i<transitionScore.length;i++) {
	//       	double [] tsv = transitionScore[i];
	//        	for (int j=0;j<tsv.length;j++)
	//        		tsv[j] *= noSeq;
	//        }
	
			int stateCount = states.length;
	
			// Back pointers for the previous and the current column; swapped after each column.
			BackPointer [] oldPointers = new BackPointer[stateCount];
			BackPointer [] newPointers = new BackPointer[stateCount];
	
			// initialize
			{
			  double [] vc = dpCursor.currentCol();
			  double [] vl = dpCursor.lastCol();
			  // States below getDotStatesIndex(): only the magical state starts with score 0.
			  for (int l = 0; l < getDotStatesIndex(); l++) {
				if(states[l] == getModel().magicalState()) {
				  //System.out.println("Initializing start state to 0.0");
				  vc[l] = vl[l] = 0.0;
				  oldPointers[l] = newPointers[l] = new BackPointer(states[l]);
				} else {
				  vc[l] = vl[l] = Double.NEGATIVE_INFINITY;
				}
			  }
			  // States from getDotStatesIndex() on: take the best-scoring predecessor in the same column.
			  for (int l = getDotStatesIndex(); l < stateCount; l++) {
				int [] tr = transitions[l];
				double [] trs = transitionScore[l];
				double transProb = Double.NEGATIVE_INFINITY;
				double trans = Double.NEGATIVE_INFINITY;
				int prev = -1;
				for (int kc = 0; kc < tr.length; kc++) {
				  int k = tr[kc];
				  double t = trs[kc];
				  double s = vc[k];
				  double p = t + s;
				  if (p > transProb) {
					transProb = p;
					prev = k;
					trans = t;
				  }
				}
				if(prev != -1) {
				  vc[l] = vl[l] = transProb;
				  oldPointers[l] = newPointers[l] = new BackPointer(
					states[l],
					newPointers[prev],
					trans
				  );
				} else {
				  vc [l] = vl[l] = Double.NEGATIVE_INFINITY;
				  oldPointers[l] = newPointers[l] = null;
				}
			  }          
			}
	
			// viterbi
			while (dpCursor.canAdvance()) { // symbol i
			  dpCursor.advance();
			  // Weighted profile emission scores for the current alignment column.
			  double [] emissions = getProfileEmission(dpCursor.currentResidues(),sWeights);
			  //System.out.println(sym.getName());
			  double [] currentCol = dpCursor.currentCol();
			  double [] lastCol = dpCursor.lastCol();
			  for (int l = 0; l < states.length; l++) { // don't move from magical state
				double emission;
				// States from getDotStatesIndex() on contribute no emission score.
				if(l < getDotStatesIndex()) {
				  emission = emissions[l];
				} else {
				  emission = 0.0;
				}
				int [] tr = transitions[l];
				//System.out.println("Considering " + tr.length + " alternatives");
				double [] trs = transitionScore[l];
				if (emission == Double.NEGATIVE_INFINITY) {
				  //System.out.println(states[l].getName() + ": impossible emission");
				  currentCol[l] = Double.NEGATIVE_INFINITY;
				  newPointers[l] = null;
				} else {
				  double transProb = Double.NEGATIVE_INFINITY;
				  double trans = Double.NEGATIVE_INFINITY;
				  int prev = -1;
				  for (int kc = 0; kc < tr.length; kc++) {
					int k = tr[kc];
					double t = trs[kc];
					// States below getDotStatesIndex() read the previous column, the others the current one.
					double s = (l < getDotStatesIndex()) ? lastCol[k] : currentCol[k];
					double p = t + s;
					if (p > transProb) {
					  transProb = p;
					  prev = k;
					  trans = t;
					}
				  }
				  if(prev != -1) {
					currentCol[l] = transProb + emission;
					newPointers[l] = new BackPointer(states[l],
					  (l < getDotStatesIndex()) ? oldPointers[prev] : newPointers[prev],
					  trans + emission);
				  } else {
					currentCol[l] = Double.NEGATIVE_INFINITY;
					newPointers[l] = null;
				  }
				}
			  }
	      
			  // Swap pointer arrays: what was just computed becomes "old" for the next column.
			  BackPointer [] bp = newPointers;
			  newPointers = oldPointers;
			  oldPointers = bp;
			}
	
			// find max in last row
	//		BackPointer best = oldPointers[0];
			BackPointer best = null;
			{
				double [] currentCol = dpCursor.currentCol();
				int [] tr = transitions[0];
	//			double [] trs = transitionScore[0];
				double transProb = Double.NEGATIVE_INFINITY;
				double trans = Double.NEGATIVE_INFINITY;
				int prev = -1;
				for (int kc = 0; kc < tr.length; kc++) {
					int k = tr[kc];
					// The transition score into state 0 is deliberately taken as 0.0 here.
					double t = 0.0; // trs[kc];
					double s = currentCol[k];
					double p = t + s;
					if (p > transProb) {
					  transProb = p;
					  prev = k;
					  trans = t;
					}
				 }
				 if(prev != -1) {
					best = new BackPointer(states[0],oldPointers[prev],trans);
				 } 	
			}		
			// Traceback: start from the first sequence's string and overwrite residues with labels.
			StringBuffer pred = new StringBuffer(sequences[0].seqString());
			int i = pred.length()-1;
			// NOTE(review): 'best' is still null here when no predecessor of state 0 had a
			// score above negative infinity; the next line would then throw a NullPointerException.
			while (best.state instanceof MagicalState) { best = best.back;}
			while(i>=0){
	// comment away next line if you want to include gaps in query sequence in prediction
			  while (i>=0 && pred.charAt(i)=='-') i--;
			  while (i>=0 && best.state instanceof DotState) {best=best.back;}
			  if (i>=0 && best.state instanceof SimpleEmissionState) {
				  String lab;
				  try {lab=(String)best.state.getAnnotation().getProperty("Label");}
				  // A state whose label cannot be read is reported as "X".
				  catch (Exception e) { lab = "X"; }
				  pred.setCharAt(i,lab.charAt(0));
			  }		
			  i--;
			  best = best.back;
			};
			unlockModel();
			return pred.toString();	
		}
	    /**
	     * Runs the prediction methods selected by the flags and prints the results
	     * through OutputHandler (or the PLP matrix to System.out when printPLP &gt; 0).
	     * <p>
	     * When {@code aligned} is false every sequence is decoded and printed on its
	     * own. When {@code aligned} is true the posterior label probabilities of all
	     * sequences are combined using the weights from WeightScheme, and a single
	     * result is attached to the first sequence.
	     *
	     * @param seqArr the sequences to decode; an empty array only prints header and tail
	     * @param aligned whether the sequences form one alignment
	     * @param doOptAcc store the optimalPLP labeling under key "?O"
	     * @param doMaxPLP store the maxPLP labeling under key "?M"
	     * @param doViterbi store the Viterbi labeling under key "?V"
	     * @param printPLP if &gt; 0, print the label probability matrix instead of the
	     *        labeled sequence; in aligned mode bit value 2 selects the second
	     *        argument passed to printPLP
	     * @param Nbest if &gt; 0, number of N-best labelings to store under "?0", "?1", ...
	     * @throws Exception propagated from the decoding and output calls
	     */
	public void go(Sequence [] seqArr, boolean aligned, boolean doOptAcc, boolean doMaxPLP, boolean doViterbi, int printPLP,int Nbest) throws Exception  {
		boolean plpCalc = doOptAcc || doMaxPLP || (printPLP>0);
		LabelProbMatrix tot = null;
		Sequence firstSeq = null;
		double [] weights = null;
		if (aligned)
			weights = WeightScheme.getInstance().getWeights(seqArr);
		// Constraints are laid out against the first sequence; nothing to map for empty input.
		if (constraints!=null && seqArr.length>0)
			constraints.mapConstraints(seqArr[0].length(),getDotStatesIndex(),labels,labelIx2states);

		OutputHandler.getInstance().printInit();

		for (int i=0; i<seqArr.length; i++) {
			Sequence seq = seqArr[i];
			if (constraints!=null) constraints.adjust(seq);
			if (firstSeq==null)
				firstSeq=seq;
			Sequence[] seqL = {seq};
			LabelProbMatrix plp = null;
			if (plpCalc) plp = calcPosteriorLabelProbability(seqL);
			if (!aligned) {
				if (plpCalc) plp.normalize();
				Hashtable preds = LabeledFastaFormat.getLabels(seq);
				if (preds==null) preds = new Hashtable();
				if (doMaxPLP)
					preds.put("?M",maxPLP(plp));
				if (doViterbi)
					preds.put("?V",viterbi(seqL));
				if (doOptAcc)
					preds.put("?O",optimalPLP(plp));
				if (Nbest>0)
					putNbest(preds,doNbest(seqL,null,Nbest,false));
				if (printPLP>0) {
					plp.printPLP(System.out,true);
				} else {
					LabeledFastaFormat.setLabels(seq,preds);
					OutputHandler.getInstance().printSequence(seq);
				}
			} else if (plpCalc) {
				// Accumulate the weighted label probabilities over the alignment.
				if (tot == null) {
					tot = plp;
					tot.weight(weights[i]);
				} else {
					tot.join(plp,weights[i]);
				}
			}
		}
		// With no sequences there is nothing to combine or print in aligned mode.
		if (aligned && seqArr.length>0) {
			if (plpCalc) tot.normalize();
			Hashtable preds = LabeledFastaFormat.getLabels(firstSeq);
			if (preds==null) preds = new Hashtable();
			if (doMaxPLP)
				preds.put("?M",maxPLP(tot));
			if (doOptAcc)
				preds.put("?O",optimalPLP(tot));
			if (doViterbi)
				preds.put("?V",doHomologyViterbi(seqArr,weights));
			if (Nbest>0)
				putNbest(preds,doNbest(seqArr,weights,Nbest,true));
			if (printPLP>0) {
				tot.printPLP(System.out,(printPLP & 2) == 0);
			} else {
				LabeledFastaFormat.setLabels(firstSeq,preds);
				OutputHandler.getInstance().printSequence(firstSeq);
			}
		}
		OutputHandler.getInstance().printTail();
	}

	/**
	 * Stores the N-best labelings in preds under the keys "?0", "?1", ...
	 * Entries that are null (doNbest found fewer labelings than requested) are
	 * skipped, because Hashtable rejects null values.
	 */
	private static void putNbest(Hashtable preds, String [] out) {
		for (int n=0;n<out.length;n++) {
			if (out[n]!=null)
				preds.put("?"+n,out[n]);
		}
	}
	
	/**
	 * Runs Viterbi decoding with ScoreType.PROBABILITY and returns the
	 * concatenated "Label" annotations of the states on the resulting path.
	 * States whose label cannot be read contribute nothing to the result.
	 *
	 * @param seq the sequence(s) handed to the underlying viterbi implementation
	 * @return the labels along the Viterbi state path, in path order
	 * @throws Exception propagated from the underlying viterbi call
	 */
	public String viterbi(SymbolList[] seq) throws Exception {
		StatePath sp = this.viterbi(seq,ScoreType.PROBABILITY);
		// A buffer avoids re-copying the whole prediction for every state on the path.
		StringBuffer out = new StringBuffer();
		for(Iterator it = sp.symbolListForLabel(StatePath.STATES).iterator(); it.hasNext();)  {
			State st = (State) it.next();
			try {
				out.append((String) st.getAnnotation().getProperty("Label"));
			}
			catch (Exception ignored) {
				// Deliberately ignored: a state without a readable label is left out.
			}
		}
		return out.toString();
	}

	/**
	 * Returns the log emission scores of {@code sym} for every state below
	 * {@code getDotStatesIndex()}, computing and caching them on first use.
	 * Index 0 is 0 for the gap symbol and negative infinity otherwise; a score
	 * that is not strictly positive maps to negative infinity.
	 *
	 * @param sym the symbol to score
	 * @param scoreType selects both the score calculation and the cache to use
	 * @return the (cached) array of log scores, indexed by state
	 * @throws IllegalSymbolException propagated from the score calculation
	 */
	public double [] getEmission(Symbol sym, ScoreType scoreType)
	throws IllegalSymbolException {
	  final Map cache;
	  if (scoreType == ScoreType.PROBABILITY) {
		cache = emissionsProb;
	  } else if (scoreType == ScoreType.ODDS) {
		cache = emissionsOdds;
	  } else if (scoreType == ScoreType.NULL_MODEL) {
		cache = emissionsNull;
	  } else {
		throw new BioError("Unknown ScoreType object: " + scoreType);
	  }

	  double [] cached = (double []) cache.get(sym);
	  if (cached != null) {
		return cached;
	  }

	  // Not seen before: score the symbol against every emitting state.
	  int emitting = getDotStatesIndex();
	  double [] scores = new double[emitting];
	  State [] allStates = getStates();
	  scores[0] = (sym == AlphabetManager.getGapSymbol()) ? 0 : Double.NEGATIVE_INFINITY;
	  for (int s = 1; s < emitting; s++) {
		Distribution dist = ((EmissionState) allStates[s]).getDistribution();
		double raw = scoreType.calculateScore(dist, sym);
		scores[s] = (raw > 0.0) ? Math.log(raw) : Double.NEGATIVE_INFINITY;
	  }
	  cache.put(sym, scores);
	  return scores;
	}

	/**
	 * Combines the raw emissions of one alignment column into weighted log
	 * scores, one per state below {@code getDotStatesIndex()}.
	 * <p>
	 * A symbol is counted only if its name has at least three characters and it
	 * is not the gap symbol. For counted symbols, log(emission) * weight is
	 * summed per state; a non-positive emission sets that state's score to
	 * negative infinity. If some symbols were skipped the sums are divided by
	 * the total weight of the counted ones; if none was counted, state 0 gets
	 * 0.0 and all other states negative infinity.
	 *
	 * @param syms the symbols of the column, one per sequence
	 * @param weights the sequence weights, parallel to syms
	 * @return the combined log emission scores, indexed by state
	 * @throws IllegalSymbolException propagated from getRawEmission
	 */
	public double [] getProfileEmission(Symbol [] syms, double [] weights)
			throws IllegalSymbolException {
		int emitting = getDotStatesIndex();
		double [] combined = new double[emitting];
		int counted = 0;
		double countedWeight = 0;
		for (int s = 0; s < syms.length; s++) {
			Symbol sym = syms[s];
			if (sym.getName().length() < 3 || sym == AlphabetManager.getGapSymbol()) {
				continue;
			}
			counted++;
			countedWeight += weights[s];
			double [] raw = getRawEmission(sym);
			for (int st = 0; st < emitting; st++) {
				if (raw[st] > 0) {
					combined[st] += Math.log(raw[st])*weights[s];
				} else {
					combined[st] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		if (counted == syms.length) {
			// Every symbol contributed; the sums are returned as they are.
			return combined;
		}
		if (counted > 0) {
			for (int st = 0; st < emitting; st++) {
				combined[st] /= countedWeight;
			}
		} else {
			for (int st = 1; st < emitting; st++) {
				combined[st] = Double.NEGATIVE_INFINITY;
			}
			combined[0] = 0.0;
		}
		return combined;
	}

	/**
	 * Combines the raw emissions of one alignment column into a weighted product
	 * (not log-transformed), one value per state below {@code getDotStatesIndex()}.
	 * <p>
	 * A symbol is counted only if its name has at least three characters and it
	 * is not the gap symbol. For counted symbols, emission^weight is multiplied
	 * into the state's value; a non-positive emission sets the value to 0. If
	 * some symbols were skipped, each value is raised to 1/(total counted
	 * weight); if none was counted, state 0 gets 1.0 and all other states 0.
	 *
	 * @param syms the symbols of the column, one per sequence
	 * @param weights the sequence weights, parallel to syms
	 * @return the combined raw emission values, indexed by state
	 * @throws IllegalSymbolException propagated from getRawEmission
	 */
	public double [] getRawProfileEmission(Symbol [] syms, double [] weights)
			throws IllegalSymbolException {
		int dsi = getDotStatesIndex();
		double [] ems = new double[dsi];
		Arrays.fill(ems,1.0);
		int nums = 0;
		double doneWeight = 0;
		for (int i=0;i<syms.length;i++) {
			if ((syms[i].getName().length()>=3 && syms[i] != AlphabetManager.getGapSymbol())) {
				nums++;
				doneWeight += weights[i];
				double [] em = getRawEmission(syms[i]);
				for (int j=0;j<dsi;j++) {
					if (em[j]>0)
						ems[j] *= Math.pow(em[j],weights[i]);
					else
						ems[j] = 0.0;
				}
			}
		}
		// Turn the counted weight into the exponent used for renormalisation below.
		if (doneWeight>0) doneWeight = 1/doneWeight;
		if (nums<syms.length) {
			if (nums>0)  {
				for (int i=0;i<dsi;i++) {
					ems[i] = Math.pow(ems[i] ,doneWeight);
				}
			} else {
				for (int i=1;i<dsi;i++)
					ems[i] = 0;
				ems[0] = 1.0;
			}
		}
		return ems;
	}

	/**
	 * Returns the raw (non-log) emission weights of {@code sym} for every state
	 * below {@code getDotStatesIndex()}, computing and caching them on first use.
	 * For the gap symbol only index 0 is set (to 1). For any other symbol index 0
	 * is 0 and each remaining entry is the state's distribution weight, where a
	 * NaN weight becomes 0 for an AtomicSymbol and 1/20 otherwise.
	 *
	 * @param sym the symbol to look up
	 * @return the (cached) array of raw emission weights, indexed by state
	 * @throws IllegalSymbolException propagated from the distribution lookup
	 */
	public double [] getRawEmission(Symbol sym)
	throws IllegalSymbolException {
	  double [] cached = (double []) rawEmissions.get(sym);
	  if (cached != null) {
		return cached;
	  }

	  int emitting = getDotStatesIndex();
	  double [] raw = new double[emitting];
	  State [] allStates = getStates();
	  if (sym == gap) {
		raw[0] = 1;
	  } else {
		raw[0] = 0;
		for (int s = 1; s < emitting; s++) {
		  Distribution dist = ((EmissionState) allStates[s]).getDistribution();
		  double w = dist.getWeight(sym);
		  if (Double.isNaN(w)) {
			// Non-atomic symbols must not become zero, or they would probably match nowhere.
			w = (sym instanceof AtomicSymbol) ? 0.0 : 1.0/20.0;
		  }
		  raw[s] = w;
		}
	  }
	  rawEmissions.put(sym, raw);
	  return raw;
	}

	/**
	 * Returns the raw transition weights into each state: entry [i][j] is the
	 * weight of moving from states[transitions[i][j]] to states[i].
	 * The table is computed once and then served from {@code rawTransitionScores}
	 * on every later call, whatever arguments are passed.
	 *
	 * @param model the model supplying the transition weights
	 * @param states the states, indexed consistently with transitions
	 * @param transitions for each state, the indices of its source states
	 * @return the (cached) table of raw transition weights
	 * @throws IllegalSymbolException declared for callers; lookup failures surface as BioError
	 */
	public double [][] forwardTransitionRawScores(
	  MarkovModel model,
	  State [] states,
	  int [][] transitions) throws IllegalSymbolException {
	  if (rawTransitionScores == null) {
		double [][] table = new double[states.length][];
		for (int to = 0; to < table.length; to++) {
		  State target = states[to];
		  int [] sources = transitions[to];
		  double [] row = new double[sources.length];
		  for (int k = 0; k < row.length; k++) {
			try {
			  row[k] = model.getWeights(states[sources[k]]).getWeight(target);
			} catch (IllegalSymbolException ite) {
			  throw new BioError(ite,
				"Transition listed in transitions array has dissapeared.");
			}
		  }
		  table[to] = row;
		}
		rawTransitionScores = table;
	  }
	  return rawTransitionScores;
	}

	/**
	 * Computes the raw transition weights out of each state: entry [i][j] is the
	 * weight of moving from states[i] to states[transitions[i][j]].
	 * Nothing is cached; a fresh table is built on every call.
	 *
	 * @param model the model supplying the transition weights
	 * @param states the states, indexed consistently with transitions
	 * @param transitions for each state, the indices of its destination states
	 * @return a new table of raw transition weights
	 * @throws IllegalSymbolException declared for callers; lookup failures surface as BioError
	 */
	public static double [][] backwardTransitionRawScores(MarkovModel model,
	  State [] states,
	  int [][] transitions
	) throws IllegalSymbolException {
	  double [][] table = new double[states.length][];
	  for (int from = 0; from < table.length; from++) {
		State source = states[from];
		int [] targets = transitions[from];
		double [] row = new double[targets.length];
		for (int k = 0; k < row.length; k++) {
		  try {
			row[k] = model.getWeights(source).getWeight(states[targets[k]]);
		  } catch (IllegalSymbolException ite) {
			throw new BioError(ite,
			  "Transition listed in transitions array has dissapeared");
		  }
		}
		table[from] = row;
	  }
	  return table;
	}

	/**
	 * Fills an FBMatrix with scaled forward (alfa) and backward (beta) values for
	 * the given sequence, working on raw (non-log) emission and transition weights.
	 * <p>
	 * At each position the forward values of states 1..getDotStatesIndex()-1 are
	 * scaled to sum to one; the scaling factor is stored in matrix.c and
	 * matrix.score is set to minus the sum of log(c[t]). Transitions involving
	 * state 0 are skipped inside the recursions. If constraints are set they are
	 * applied to the forward and backward values at every position.
	 *
	 * @param seq the sequence to process; symbols are read through the matrix's ung2g index
	 * @return the filled matrix
	 * @throws IllegalSymbolException propagated from the emission and transition lookups
	 */
	public FBMatrix forwardbackwardMatrix(SymbolList seq) throws IllegalSymbolException {
		FBMatrix matrix 				= new FBMatrix(seq,getDotStatesIndex());
		int len 								= matrix.getUngappedLen();
		// presumably maps an ungapped position to its index in the gapped sequence - confirm in FBMatrix
		int [] ung2g 						= matrix.ung2g;
		double [] alfaTemp 			= new double[getStates().length];
		double [] alfaOld	 			= new double[getStates().length];
		double [] betaTemp 			= new double[getStates().length];
		State[] states 						= getStates();
		int [][] fwdTrans 					= getForwardTransitions();
		double [][] fwdTransScore = forwardTransitionRawScores(getModel(), states, fwdTrans);
		int [][] bkwTrans 				= getBackwardTransitions();
		double [][] bkwTransScore = backwardTransitionRawScores(getModel(), states, bkwTrans);
		int i,j,t;
		// Forward
		// Init
		// Take care of transitions
		int [] trans = bkwTrans[0]; // 0->first
		double [] transScore = bkwTransScore[0];
		for (i = 0; i<trans.length;i++)
		    if (trans[i]!=0)
				alfaTemp[trans[i]] = transScore[i];
	    // from dot states -> other (there could be dot states before seq)
        for (j=getDotStatesIndex(); j<states.length; j++) {
        	trans = bkwTrans[j]; // dot->other
			transScore = bkwTransScore[j];        			
			for (i = 0; i<trans.length;i++)
			    if (trans[i]!=0)
					alfaTemp[trans[i]] += transScore[i]*alfaTemp[j];
        }
		// Now emissions
		// NOTE(review): ung2g[0] is read unconditionally, so a sequence with no
		// ungapped positions is presumably not supported here.
		int tGapped = ung2g[0];
		double[] emissions = getRawEmission(seq.symbolAt(tGapped+1));
		for (j=0;j<getDotStatesIndex(); j++)
			alfaTemp[j] *=emissions[j];
		if (constraints!=null) constraints.constrain(tGapped,alfaTemp);
		// Normalize
		double alfsum=0;
		for (j=1; j<getDotStatesIndex(); j++) 
			alfsum += alfaTemp[j];		
		// Scaling factor for position 0; becomes infinite if alfsum is 0.
		matrix.c[0] = 1.0/alfsum;		
		for (j=1; j<getDotStatesIndex(); j++) 
			matrix.alfa[0][j] = alfaTemp[j]*matrix.c[0];
		System.arraycopy(matrix.alfa[0],0,alfaTemp,0,getDotStatesIndex());
		// And update the dot states
		for (j=getDotStatesIndex(); j<states.length; j++) {
			trans = fwdTrans[j]; // other -> dot
			transScore = fwdTransScore[j];
			alfaTemp[j]=0;        						
			for (i = 0; i<trans.length;i++)
			if (trans[i]!=0)
				alfaTemp[j] += transScore[i]*alfaTemp[trans[i]];
		}
		// Now do the Recursion		
		for (t=1;t<len;t++) {
			tGapped = ung2g[t];
			// Swap buffers: the previous column becomes alfaOld, alfaTemp is reused and cleared.
			double[] tmpArr = alfaOld;
			alfaOld=alfaTemp;
			alfaTemp=tmpArr;
			Arrays.fill(alfaTemp,0);
			emissions = getRawEmission(seq.symbolAt(tGapped+1));
			// Do the emitting ones			
			for (j=0;j<getDotStatesIndex(); j++) {
				trans = fwdTrans[j];
				transScore = fwdTransScore[j];
				for (i = 0; i<trans.length;i++)
					if (trans[i]!=0)
						alfaTemp[j] += transScore[i]*alfaOld[trans[i]];
				alfaTemp[j] *= emissions[j];
			}
			if (constraints!=null) constraints.constrain(tGapped,alfaTemp);
			// Normalize
			alfsum=0;
			for (j=1; j<getDotStatesIndex(); j++) 
				alfsum += alfaTemp[j];		
			matrix.c[t] = 1.0/alfsum;		
			for (j=1; j<getDotStatesIndex(); j++) 
				matrix.alfa[t][j] = alfaTemp[j]*matrix.c[t];
			System.arraycopy(matrix.alfa[t],0,alfaTemp,0,getDotStatesIndex());
			// And update the dot states
			for (j=getDotStatesIndex(); j<states.length; j++) {
				trans = fwdTrans[j]; // other -> dot
				transScore = fwdTransScore[j];
				for (i = 0; i<trans.length;i++)
					if (trans[i]!=0)
						alfaTemp[j] += transScore[i]*alfaTemp[trans[i]];
			}
		}
		// Terminate
		// The transition weights into state 0 are not applied; the predecessors' values are summed as they are.
		trans = fwdTrans[0]; // 0->first
//		transScore = fwdTransScore[0];
		alfsum = 0;
		for (i = 0; i<trans.length;i++) 
			if (trans[i]!=0)
				alfsum += alfaTemp[trans[i]];
//				alfsum += transScore[i]*alfaTemp[trans[i]];
		matrix.c[len] = 1.0/alfsum;
		// Score = -sum(log c[t]) over all scaling factors, including the termination factor.
		double lp=0;
		for (t=0;t<=len;t++) 
			lp -= Math.log(matrix.c[t]);
		matrix.score = lp;
		// Backward
		// Init
		trans = fwdTrans[0]; // 0->first
		transScore = fwdTransScore[0];
		for (i = 0; i<trans.length;i++) 
			if (trans[i]!=0)
				betaTemp[trans[i]] = 1;
//				betaTemp[trans[i]] = transScore[i];
		// from dot states -> other (there could be dot states before seq)
		for (j=states.length-1; j>=getDotStatesIndex(); j--) {
			trans = fwdTrans[j]; // dot->other
			transScore = fwdTransScore[j];        			
			for (i = trans.length-1; i>=0;i--)
				if (trans[i]!=0)
					betaTemp[trans[i]] += transScore[i]*betaTemp[j];
		}
		if (constraints!=null) constraints.constrain(ung2g[len-1],betaTemp);
		// Normalize
		for (j=1; j<getDotStatesIndex(); j++) 
			matrix.beta[len-1][j] = betaTemp[j]*matrix.c[len-1]*matrix.c[len];
		// Now do the Recursion		
		for (t=len-2;t>=0;t--) {
			System.arraycopy(matrix.beta[t+1],0,betaTemp,0,getDotStatesIndex());
			int tGappedp = ung2g[t+1];
			emissions = getRawEmission(seq.symbolAt(tGappedp+1));
			// And update the dot states
			for (j=states.length-1;j>=getDotStatesIndex(); j--) {
				trans = bkwTrans[j]; // other -> dot
				transScore = bkwTransScore[j];
				betaTemp[j]=0;
				for (i = trans.length-1; i>=0; i--)
					if (trans[i]!=0)
						betaTemp[j] += transScore[i]*betaTemp[trans[i]];
			}
			// Do the emitting ones			
			for (j=0;j<getDotStatesIndex(); j++)
				betaTemp[j] *= emissions[j];
			for (j=0;j<getDotStatesIndex(); j++) {
				trans = bkwTrans[j];
				transScore = bkwTransScore[j];
				for (i = 0; i<trans.length;i++)
					if (trans[i]!=0)
						matrix.beta[t][j] += transScore[i]*betaTemp[trans[i]]*matrix.c[t];
			}
			if (constraints!=null) constraints.constrain(ung2g[t],matrix.beta[t]);
		}
		return matrix;
	}

	  /**
	   * Runs the backward algorithm over a single sequence, filling the supplied
	   * matrix through an AlignedMatrixCursor that walks the sequence in reverse.
	   * The model is locked for the duration of the calculation and is always
	   * unlocked again, even when the calculation fails.
	   *
	   * @param seq an array holding exactly one sequence
	   * @param matrix the matrix to fill; must be a SingleDPMatrix
	   * @param scoreType the score type to use
	   * @return the supplied matrix, with its score set
	   * @throws IllegalArgumentException if seq does not hold exactly one sequence
	   * @throws IllegalSymbolException propagated from the backward calculation
	   * @throws IllegalAlphabetException declared for compatibility with callers
	   */
	  public DPMatrix backwardMatrix(SymbolList [] seq, DPMatrix matrix, ScoreType scoreType)
	  	throws IllegalArgumentException, IllegalSymbolException,
	 		IllegalAlphabetException {
	    if(seq.length != 1) {
	      throw new IllegalArgumentException("seq must be 1 long, not " + seq.length);
	    }
	    
	    lockModel();
	    try {
	      SingleDPMatrix sm = (SingleDPMatrix) matrix;
	      DPCursor dpCursor = new AlignedMatrixCursor(sm, new ReverseIterator(seq[0]), -1);
	      sm.setScore(backward(dpCursor, scoreType));
	      return sm;
	    } finally {
	      // Release the lock on every exit path so a failure cannot leave the model locked.
	      unlockModel();
	    }
	  }

	  /**
	   * Convenience overload: allocates a new SingleDPMatrix for seq[0] and runs
	   * the backward calculation into it.
	   *
	   * @param seq the sequences; seq[0] sizes the matrix
	   * @param scoreType the score type to use
	   * @return the newly filled matrix
	   */
	  public DPMatrix backwardMatrix(SymbolList [] seq, ScoreType scoreType)
		  throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException {
		    return backwardMatrix(seq, new SingleDPMatrix(this, seq[0]), scoreType);
		  }

		  /**
		   * Runs the forward algorithm over a single sequence, filling the supplied
		   * matrix through an AlignedMatrixCursor. The model is locked for the
		   * duration of the calculation and is always unlocked again, even when
		   * the calculation fails.
		   *
		   * @param seq an array holding exactly one sequence
		   * @param matrix the matrix to fill; must be a SingleDPMatrix
		   * @param scoreType the score type to use
		   * @return the supplied matrix, with its score set
		   * @throws IllegalArgumentException if seq does not hold exactly one sequence
		   * @throws IllegalSymbolException propagated from the forward calculation
		   * @throws IllegalAlphabetException declared for compatibility with callers
		   */
		  public DPMatrix forwardMatrix(SymbolList [] seq, DPMatrix matrix, ScoreType scoreType)
		  throws IllegalArgumentException, IllegalSymbolException,
		  IllegalAlphabetException {
		    if(seq.length != 1) {
		      throw new IllegalArgumentException("seq must be 1 long, not " + seq.length);
		    }
		    
		    lockModel();
		    try {
		      SingleDPMatrix sm = (SingleDPMatrix) matrix;
		      DPCursor dpCursor = new AlignedMatrixCursor(sm, seq[0].iterator(), +1);
		      sm.setScore(forward(dpCursor, scoreType));
		      return sm;
		    } finally {
		      // Release the lock on every exit path so a failure cannot leave the model locked.
		      unlockModel();
		    }
		  }

		  /**
		   * Convenience overload: allocates a new SingleDPMatrix for seq[0] and runs
		   * the forward calculation into it.
		   *
		   * @param seq the sequences; seq[0] sizes the matrix
		   * @param scoreType the score type to use
		   * @return the newly filled matrix
		   */
		  public DPMatrix forwardMatrix(SymbolList [] seq, ScoreType scoreType)
		  throws IllegalSymbolException, IllegalAlphabetException, IllegalSymbolException {
		    return forwardMatrix(seq, new SingleDPMatrix(this, seq[0]), scoreType);
		  }
	/**
	 * @param model
	 * @throws org.biojava.bio.symbol.IllegalSymbolException
	 * @throws org.biojava.bio.dp.IllegalTransitionException
	 * @throws org.biojava.bio.BioException
	 */
	public AlignedDP(MarkovModel model, Constraints c)
		throws IllegalSymbolException, IllegalTransitionException, BioException {
		super(model);
		rawEmissions = new HashMap();
		mapLabels(this.getStates());
		constraints = c;
	}

	/**
	 * Computes the posterior label probabilities for seq and returns the
	 * labeling that picks the most probable label at every position.
	 *
	 * @param seq the sequences; only the first one is used by the probability calculation
	 * @return the max-posterior labeling
	 * @throws Exception propagated from the probability calculation
	 */
	public String maxPLP(SymbolList[] seq) throws Exception {
		return maxPLP(calcPosteriorLabelProbability(seq));
	}

	/**
	 * Returns the labeling that picks, at every position, the label with the
	 * highest posterior probability in {@code plp.scores}.
	 * <p>
	 * When the matrix covers a single sequence, positions whose symbol is named
	 * "[]" are written as "-". If no score exceeds negative infinity the
	 * position is written as "Q". Positions whose best label equals "X" are
	 * omitted from the output; the comparison is done on the label's content,
	 * not on object identity.
	 *
	 * @param plp the posterior label probability matrix
	 * @return the max-posterior labeling
	 * @throws Exception propagated from the symbol lookups
	 */
	public String maxPLP(LabelProbMatrix plp) throws Exception {
		// Find the max PLP path
		SymbolList[] seq = plp.getSymList();
		// A buffer avoids re-copying the whole prediction for every position.
		StringBuffer out = new StringBuffer();
		for (int i = 0;i<seq[0].length();i++) {
			int typ=0;
			double maxV=Double.NEGATIVE_INFINITY;
			String maxLab ="Q";
			if (seq.length==1 && seq[0].symbolAt(i+1).getName().equals("[]")) {
				maxLab = "-";						
			} else {
				// typ walks the score columns in step with the label iterator.
				for (Iterator it = plp.getLabels().iterator(); it.hasNext(); typ++) {
					String lab = (String) it.next();
					if (plp.scores[i][typ]>maxV) {
						maxV=plp.scores[i][typ];
						maxLab = lab;
					}
				}
			}
			// equals(), not !=: labels built at runtime are not guaranteed to be the interned literal.
			if (!"X".equals(maxLab)) out.append(maxLab);
		}
		return out.toString();
	}
/*	
	public LabelProbMatrix calcPosteriorLabelProbability(SymbolList[] seq) throws Exception {
		SingleDPMatrix forward = (SingleDPMatrix) forwardMatrix(seq,ScoreType.PROBABILITY);
		SingleDPMatrix backward = (SingleDPMatrix) backwardMatrix(seq,ScoreType.PROBABILITY);
		return new LabelProbMatrix(this,forward,backward);	
	} */
	
	/**
	 * Runs the forward-backward calculation on seq[0] and loads the result into
	 * a new LabelProbMatrix.
	 *
	 * @param seq the sequences; only seq[0] is used
	 * @return the label probability matrix for seq[0]
	 * @throws Exception propagated from the forward-backward calculation
	 */
	public LabelProbMatrix calcPosteriorLabelProbability(SymbolList[] seq) throws Exception {
		SymbolList query = seq[0];
		FBMatrix forwardBackward = forwardbackwardMatrix(query);
		LabelProbMatrix result = new LabelProbMatrix(this,query);
		result.setValues(forwardBackward);
		return result;
	}
	
	/**
	 * Computes the posterior label probabilities for seq and decodes them with
	 * optimalPLP(LabelProbMatrix).
	 *
	 * @param seq the sequences; only the first one is used by the probability calculation
	 * @return the decoded labeling
	 * @throws Exception propagated from the probability calculation or the decoding
	 */
	public String optimalPLP(SymbolList [] seq)
	throws Exception {
		return optimalPLP(calcPosteriorLabelProbability(seq));
	}

	/**
	 * Viterbi-like decoding on posterior label probabilities: finds a path
	 * through the model's forward transitions that maximises the accumulated
	 * {@code plp.scores} of the labels of the visited states. Only the existence
	 * of transitions is used here; transition scores are not read.
	 * <p>
	 * The result is the first sequence's string with non-gap positions replaced
	 * by the first character of each visited state's "Label" annotation.
	 *
	 * @param plp the posterior label probability matrix; its first symbol list is decoded
	 * @return the decoded labeling
	 * @throws Exception propagated from the cursor and model calls
	 */
	public String optimalPLP(LabelProbMatrix plp)
	throws Exception {
	  // NOTE(review): unlockModel() is only reached on normal completion; an exception
	  // thrown below leaves the model locked.
	  lockModel();

	  SymbolList seq = plp.getSymList()[0];
	  DPCursor dpCursor = new SmallCursor(getStates(), seq, seq.iterator());
   
	  State [] states = getStates();

	  int [][] transitions = getForwardTransitions();
	  int stateCount = states.length;

	  // Back pointers for the previous and the current column; swapped after each scored column.
	  BackPointer [] oldPointers = new BackPointer[stateCount];
	  BackPointer [] newPointers = new BackPointer[stateCount];

	  double [] currentCol;
	  double [] newC = dpCursor.currentCol();
	  double [] oldC = dpCursor.lastCol();
	  // Row index into plp.scores; advanced once per cursor step.
	  int pos=0;

	  // initialize
	  {
		// Only the magical state starts with a non-zero score (1.0).
		for (int l = 0; l < getDotStatesIndex(); l++) {
		  if(states[l] == getModel().magicalState()) {
			newC[l] = 1.0;
			oldC[l] = 0.0;
			newPointers[l] = new BackPointer(states[l]);
			oldPointers[l] = null;
		  } else {
			newC[l] = oldC[l] = 0.0;
		  }
		}
		// States from getDotStatesIndex() on inherit the best strictly positive predecessor score.
		for (int l = getDotStatesIndex(); l < stateCount; l++) {
		  int [] tr = transitions[l];
		  double transProb = 0.0;
		  int prev = -1;
		  for (int kc = 0; kc < tr.length; kc++) {
			int k = tr[kc];
			double p = newC[k];
			if (p > transProb) {
			  transProb = p;
			  prev = k;
			}
		  }
		  if(prev != -1) {
			newC[l] = transProb;
			newPointers[l] = new BackPointer(
			  states[l],
			  newPointers[prev],
			  transProb
			);
		  } else {
			newC [l] = oldC[l] = 0.0;
			oldPointers[l] = newPointers[l] = null;
		  }
		}          
		BackPointer [] bp = newPointers;
		newPointers = oldPointers;
		oldPointers = bp;
	}
	// viterbi alike
	while (dpCursor.canAdvance()) { // symbol i
		dpCursor.advance();
		Symbol sym = dpCursor.currentRes();
		newC = dpCursor.currentCol();
		oldC = dpCursor.lastCol();
		newC[0] = 0.0;  
		newPointers[0] = null;
        // Symbols named "[]" or "gap" are not scored; their column copies the previous one.
        if (!sym.getName().equals("[]")&&!sym.getName().equals("gap")) {
			currentCol = plp.scores[pos];
			for (int l = 1; l < states.length; l++) { // don't move to magical state
				int [] tr = transitions[l];
				// States below getDotStatesIndex() add the posterior of their label; the others add nothing.
				double addon = (l < getDotStatesIndex()) ?currentCol[plp.state2labelIx[l]]:0.0;
				double trans =0.0;
				int prev = -1;
				for (int kc = 0; kc < tr.length; kc++) {
					  int k = tr[kc];
					  double p = (l < getDotStatesIndex()) ? oldC[k] : newC[k];
					  // Strict comparison against 0.0: predecessors with score 0 are never selected.
					  if (p > trans) {
						trans = p;
						prev = k;
					  }
				}
				if(prev != -1) {
					  newC[l] = trans + addon;
					  newPointers[l] = new BackPointer(
						states[l],
					    (l < getDotStatesIndex()) ? oldPointers[prev] : newPointers[prev],
						newC[l]
					  );
				} else {
					  newC[l] = 0;
					  newPointers[l] = null;
				}
			}
			BackPointer [] bp = newPointers;
			newPointers = oldPointers;
			oldPointers = bp;
		} else {
			for (int i=0;i<oldC.length;i++)
				newC[i] = oldC[i];
		}      
		pos++;
	  }

	  // Termination: pick the best-scoring predecessor of state 0 in the last column.
	  BackPointer best = null;
	  int [] tr = transitions[0];
	  double trans =0.0;
	  int prev = -1;
	  for (int kc = 0; kc < tr.length; kc++) {
			int k = tr[kc];
			double p = newC[k];
			if (p > trans) {
			  trans = p;
			  prev = k;
			}
	  }
	  if(prev != -1) {
			best = new BackPointer(
			  states[0],
			  oldPointers[prev], // Yes the old ones
			  newC[prev]
			);
	  }
	  // find max in last row
	  
	  // Traceback: start from the sequence string and overwrite residues with labels.
	  StringBuffer pred = new StringBuffer(seq.seqString());
      int i = pred.length()-1;
	  // NOTE(review): 'best' is still null here when no predecessor of state 0 had a
	  // positive score; the next line would then throw a NullPointerException.
	  while (best.state instanceof MagicalState) { best = best.back;}
      while(i>=0){
      	while (i>=0 && pred.charAt(i)=='-') i--;
		while (i>=0 && best.state instanceof DotState) {best=best.back;}
		if (i>=0 && best.state instanceof SimpleEmissionState) {
			String lab;
			try {lab=(String)best.state.getAnnotation().getProperty("Label");}
			// A state whose label cannot be read is reported as "X".
			catch (Exception e) { lab = "X"; }
			pred.setCharAt(i,lab.charAt(0));
		}		
		i--;
		best = best.back;
      }

	  unlockModel();
	  return pred.toString();
	}

	/**
	 * One forward step for the states 1..getDotStatesIndex()-1: adds the
	 * transition-weighted values of each state's source states in oldCol to
	 * newCol, then multiplies by the state's emission. newCol is accumulated
	 * into, so its previous contents contribute to the result.
	 *
	 * @param newCol column being computed (updated in place)
	 * @param oldCol previous column
	 * @param transitions for each state, the indices of its source states
	 * @param transScore transition weights, parallel to transitions
	 * @param emissions emission weight per state for the current column
	 */
	protected final void forwardCalcEmissionStates(double [] newCol, double [] oldCol, int [][] transitions, double [][] transScore, double [] emissions)
		throws IllegalSymbolException {
		final int emitting = getDotStatesIndex();
		int state = 1;
		while (state < emitting) {
			int [] sources = transitions[state];
			int k = 0;
			while (k < sources.length) {
				newCol[state] += transScore[state][k]*oldCol[sources[k]];
				k++;
			}
			newCol[state] *= emissions[state];
			state++;
		}
	}

	/**
	 * Propagates forward values into the states from getDotStatesIndex() up to
	 * the end of the column, in increasing index order. The column is updated in
	 * place, so a state processed later sees the values already written for
	 * earlier ones.
	 *
	 * @param oldCol the column to update in place
	 * @param transitions for each state, the indices of its source states
	 * @param transScore transition weights, parallel to transitions
	 */
	protected final void forwardCalcDotStates(double [] oldCol, int [][] transitions, double [][] transScore)
		throws IllegalSymbolException {
		int state = this.getDotStatesIndex();
		while (state < oldCol.length) {
			int [] sources = transitions[state];
			int k = 0;
			while (k < sources.length) {
				oldCol[state] += transScore[state][k]*oldCol[sources[k]];
				k++;
			}
			state++;
		}
	}


	/**
	 * N-best decoding: keeps a list of labeling hypotheses, extends them column
	 * by column, and returns the labelings of the best ones.
	 * <p>
	 * For every column each hypothesis is advanced with
	 * forwardCalcEmissionStates; then, for each state in turn, the hypotheses
	 * are sorted and the first N with a non-infinite comparison score have that
	 * state's label marked alive. Hypotheses with no live label are dropped and
	 * those with several are branched into one hypothesis per live label.
	 *
	 * @param sequences the sequences to decode; sequences[0] supplies the output template
	 * @param sWeights per-sequence weights, used only when homologs is true
	 * @param N the number of labelings to keep and return
	 * @param homologs if true, emissions come from getRawProfileEmission over the whole
	 *        column; otherwise from getRawEmission of the cursor's current residue
	 * @return an array of length N; entries beyond the number of surviving
	 *         hypotheses are null
	 * @throws Exception propagated from the cursor, emission or model calls
	 */
	public String [] doNbest(Sequence[] sequences,double [] sWeights, int N, boolean homologs) throws Exception {
		// NOTE(review): unlockModel() is only reached on normal completion; an exception
		// thrown below leaves the model locked.
		lockModel();

		AlignmentCursor dpCursor = new AlignmentCursor(getStates(),sequences,1);
    
		State [] states = getStates();

		int [][] transitions = getForwardTransitions();
		double [][] transitionScore = forwardTransitionRawScores(getModel(), states, transitions);
		int stateCount = states.length;
		double [] emissions;
		
		ArrayList hypList = new ArrayList();
		Iterator hypIt;
		Hypothesis startHyp = new Hypothesis(stateCount, labels.length);
		// NOTE(review): this writes a static field of Hypothesis, so it is shared by
		// every AlignedDP instance.
		Hypothesis.dp = this;
		forwardCalcDotStates(startHyp.oldCol,transitions,transitionScore);
		hypList.add(startHyp);
		
		// Now step forward
		while (dpCursor.canAdvance()) { // symbol i
			dpCursor.advance();
		  	if (homologs) {
				emissions = getRawProfileEmission(dpCursor.currentResidues(),sWeights);
		  	} else {
		  		emissions = getRawEmission(dpCursor.currentRes());
		  	}
		  	hypIt = hypList.iterator();
//		  	System.err.println(dpCursor.currentIndex());
		  	// Advance every hypothesis by one column and apply the constraints, if any.
		  	while (hypIt.hasNext()) {
		  		Hypothesis hyp = (Hypothesis) hypIt.next();
		  		forwardCalcEmissionStates(hyp.newCol,hyp.oldCol,transitions,transitionScore,emissions);
			  	if (constraints !=null) constraints.constrain(dpCursor.currentIndex()-1,hyp.newCol);
		  	}
		  	// For each state: rank the hypotheses for that state and mark the state's label
		  	// alive in the first N whose comparison score is not infinite.
		  	for(int i=1;i<this.getDotStatesIndex();i++) {
		  		hypIt = hypList.iterator();
				while (hypIt.hasNext()) {
					((Hypothesis) hypIt.next()).setCompareState(i);
				}
				Collections.sort(hypList);
				for(int n=0;n<hypList.size() && n<N;n++) {
					Hypothesis hyp = (Hypothesis) hypList.get(n);
					if (!Double.isInfinite(hyp.cmpScore))
						hyp.alive[this.state2labelIx[i]] = true;
				}
		  	}
			// Ok who survived and who should be branched?
			int hypIx =0;
			while(hypIx<hypList.size()) {
				Hypothesis hyp = (Hypothesis) hypList.get(hypIx);
				int alives = 0;
				for (int lax = 0;lax < hyp.alive.length;lax++) {
					if (hyp.alive[lax])
						alives += 1;				
				}
				if (alives==0) {
					// No label survived: drop the hypothesis; hypIx now points at its successor.
					hypList.remove(hypIx);
					continue;
				}
				int labelIx = 0;
				// Spawn off new hypothesis
				while (alives>1) {
					while (! hyp.alive[labelIx])
						labelIx++;
					// The branch is inserted before 'hyp'; hypIx is advanced so it keeps pointing at 'hyp'.
					hypList.add(hypIx++,new Hypothesis(hyp,labelIx++));
					alives--;
				}
				// propagate old
				while (! hyp.alive[labelIx])
					labelIx++;
				hyp.move(labelIx);
				hypIx++;				
			}
			hypIt = hypList.iterator();
			while (hypIt.hasNext()) {
				Hypothesis hyp = (Hypothesis) hypIt.next();
				forwardCalcDotStates(hyp.oldCol,transitions,transitionScore);
			}
		}
		// Termination: sum the values of state 0's source states into newCol[0] for each hypothesis.
		hypIt = hypList.iterator();
		while (hypIt.hasNext()) {
			Hypothesis hyp = (Hypothesis) hypIt.next();
			int [] transTo0 = transitions[0];
			for (int i = 0; i< transTo0.length;i++)
				hyp.newCol[0] += hyp.oldCol[transTo0[i]]; // Trans prob=1
			hyp.finishUp();
		}
		Collections.sort(hypList);
		// Always N long; slots beyond hypList.size() stay null.
		String [] retVal = new String[N];
		for(int n=0;n<hypList.size() && n<N;n++) {
			// Overlay the labeling on the first sequence's string, leaving '-' positions untouched.
			StringBuffer outPred = new StringBuffer(sequences[0].seqString());
			String donePred = ((Hypothesis)hypList.get(n)).getLabeling();
			int i=0,j=0;
			while(j<donePred.length()) {
				while (outPred.charAt(i)=='-')
					i++;
				outPred.setCharAt(i++,donePred.charAt(j++));
			}
			retVal[n] = outPred.toString();
		}
		unlockModel();
		return retVal;	
	}
}
