/*
 * Decompiled with CFR 0.152.
 */
package se.ki.cgb.labeledhmm;

import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.biojava.bio.dp.DP;
import org.biojava.bio.dp.State;
import org.biojava.bio.dp.onehead.SingleDPMatrix;
import org.biojava.bio.symbol.SymbolList;
import se.ki.cgb.labeledhmm.FBMatrix;

public class LabelProbMatrix
extends SingleDPMatrix {
    protected HashMap label2states;
    protected int[] state2labelIx;
    protected Set labels;
    public static final long serialVersionUID = 9080834L;

    public LabelProbMatrix(DP dP, SymbolList symbolList) {
        super(dP, symbolList);
        this.mapLabels(dP.getStates());
        this.scores = new double[symbolList.length()][this.labels.size()];
    }

    public void setValues(FBMatrix fBMatrix) {
        int n = this.symList[0].length();
        for (int i = 0; i < n; ++i) {
            int n2 = 0;
            double[] dArray = fBMatrix.getGamma(i);
            Iterator iterator = this.labels.iterator();
            while (iterator.hasNext()) {
                String string = (String)iterator.next();
                if (dArray == null) {
                    this.scores[i][n2] = 0.0;
                } else {
                    Iterator iterator2 = ((HashSet)this.label2states.get(string)).iterator();
                    double d = 0.0;
                    while (iterator2.hasNext()) {
                        int n3 = (Integer)iterator2.next();
                        d += dArray[n3];
                    }
                    this.scores[i][n2] = d;
                }
                if (this.scores[i][n2] > 1.0) {
                    System.err.println("Error detected at pos=" + i + ", label=" + string + ", Symbol is " + this.symList[0].symbolAt(i + 1).getName() + ", prob=" + this.scores[i][n2]);
                    this.scores[i][n2] = 1.0;
                }
                if (this.scores[i][n2] < 0.0) {
                    System.err.println("Error detected at " + i + ", " + n2 + ", Symbol is " + this.symList[0].symbolAt(i + 1).getName());
                    this.scores[i][n2] = 0.0;
                }
                ++n2;
            }
        }
    }

    public LabelProbMatrix(DP dP, SingleDPMatrix singleDPMatrix, SingleDPMatrix singleDPMatrix2) {
        super(dP, singleDPMatrix.symList()[0]);
        this.initiateMatrix(singleDPMatrix, singleDPMatrix2);
    }

    public void initiateMatrix(SingleDPMatrix singleDPMatrix, SingleDPMatrix singleDPMatrix2) {
        this.mapLabels(singleDPMatrix.states());
        int n = this.symList[0].length();
        double d = singleDPMatrix.getScore();
        this.scores = new double[n][this.labels.size()];
        int n2 = 0;
        int n3 = n - 1;
        while (this.symList[0].symbolAt(n2 + 1).getName().equals("[]")) {
            ++n2;
        }
        while (this.symList[0].symbolAt(n3 + 1).getName().equals("[]")) {
            --n3;
        }
        for (int i = 0; i < n; ++i) {
            int n4 = 0;
            Iterator iterator = this.labels.iterator();
            while (iterator.hasNext()) {
                String string = (String)iterator.next();
                if (i < n2 || i > n3 || this.symList[0].symbolAt(i + 1).getName().equals("[]")) {
                    this.scores[i][n4] = -Math.log(this.labels.size());
                } else if (i >= 1 && i <= n - 1) {
                    Iterator iterator2 = ((HashSet)this.label2states.get(string)).iterator();
                    double d2 = 0.0;
                    while (iterator2.hasNext()) {
                        int n5 = (Integer)iterator2.next();
                        if (!Double.isInfinite(d2 += Math.exp(singleDPMatrix.scores[i][n5] + singleDPMatrix2.scores[i][n5] - d)) || !(d2 > 0.0)) continue;
                        d2 = 0.0;
                    }
                    this.scores[i][n4] = Math.log(d2);
                }
                if (Double.isNaN(this.scores[i][n4])) {
                    System.err.println("NaN detected at " + i + ", " + n4 + ", Symbol is " + this.symList[0].symbolAt(i).getName());
                    this.scores[i][n4] = Double.NEGATIVE_INFINITY;
                }
                ++n4;
            }
        }
    }

    public void mapLabels(State[] stateArray) {
        int n;
        String string = "";
        this.label2states = new HashMap();
        for (n = 0; n < stateArray.length; ++n) {
            try {
                HashSet hashSet;
                string = (String)stateArray[n].getAnnotation().getProperty((Object)"Label");
                if (!this.label2states.containsKey(string)) {
                    hashSet = new HashSet();
                    hashSet.add(new Integer(n));
                    this.label2states.put(string, hashSet);
                    continue;
                }
                hashSet = (HashSet)this.label2states.get(string);
                hashSet.add(new Integer(n));
                continue;
            }
            catch (Exception exception) {
                // empty catch block
            }
        }
        this.labels = this.label2states.keySet();
        this.state2labelIx = new int[stateArray.length];
        n = 0;
        int n2 = 0;
        Iterator iterator = this.labels.iterator();
        while (iterator.hasNext()) {
            string = (String)iterator.next();
            HashSet hashSet = (HashSet)this.label2states.get(string);
            Iterator iterator2 = hashSet.iterator();
            while (iterator2.hasNext()) {
                n2 = (Integer)iterator2.next();
                this.state2labelIx[n2] = n;
            }
            ++n;
        }
    }

    public void printPLP(PrintStream printStream, boolean bl) {
        DecimalFormat decimalFormat = new DecimalFormat("0.00E00");
        String string = "          ";
        printStream.print("#" + string.substring(1));
        Object object = this.labels.iterator();
        while (object.hasNext()) {
            String string2 = (String)object.next();
            printStream.print(string2 + string.substring(1));
        }
        printStream.println();
        int n = 0;
        for (int i = 0; i < this.symList[0].length(); ++i) {
            if (bl && this.symList[0].symbolAt(i + 1).getName().length() < 3) continue;
            object = Integer.toString(++n);
            printStream.print((String)object + string.substring(((String)object).length()));
            for (int j = 0; j < this.labels.size(); ++j) {
                object = decimalFormat.format(this.scores[i][j]);
                printStream.print((String)object + string.substring(((String)object).length()));
            }
            printStream.println();
        }
    }

    public void weight(double d) {
        for (int i = 0; i < this.symList[0].length(); ++i) {
            int n = 0;
            while (n < this.labels.size()) {
                double[] dArray = this.scores[i];
                int n2 = n++;
                dArray[n2] = dArray[n2] * d;
            }
        }
    }

    public void join(LabelProbMatrix labelProbMatrix, double d) {
        for (int i = 0; i < this.symList[0].length(); ++i) {
            for (int j = 0; j < this.labels.size(); ++j) {
                double[] dArray = this.scores[i];
                int n = j;
                dArray[n] = dArray[n] + labelProbMatrix.scores[i][j] * d;
                if (!Double.isNaN(this.scores[i][j])) continue;
                System.err.println("NaN detected at " + i + ", " + j + ", Symbol is " + this.symList[0].symbolAt(i + 1).getName());
                this.scores[i][j] = 0.0;
            }
        }
    }

    public void normalize() {
        for (int i = 0; i < this.symList[0].length(); ++i) {
            int n;
            double d = 0.0;
            for (n = 0; n < this.labels.size(); ++n) {
                d += this.scores[i][n];
            }
            if (!(d > 0.0)) continue;
            n = 0;
            while (n < this.labels.size()) {
                double[] dArray = this.scores[i];
                int n2 = n++;
                dArray[n2] = dArray[n2] / d;
            }
        }
    }

    public HashMap getLabel2states() {
        return this.label2states;
    }

    public Set getLabels() {
        return this.labels;
    }

    public State[] getStates() {
        return this.states;
    }

    public SymbolList[] getSymList() {
        return this.symList;
    }
}

