/*
 * Decompiled with CFR 0.152.
 */
package de.biozentrum.bioinformatik.cama.hmm;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.biojava.bio.BioException;
import org.biojava.bio.dp.DPMatrix;
import org.biojava.bio.dp.EmissionState;
import org.biojava.bio.dp.IllegalTransitionException;
import org.biojava.bio.dp.MarkovModel;
import org.biojava.bio.dp.ProfileHMM;
import org.biojava.bio.dp.ScoreType;
import org.biojava.bio.dp.State;
import org.biojava.bio.dp.onehead.SingleDP;
import org.biojava.bio.dp.onehead.SingleDPMatrix;
import org.biojava.bio.symbol.AlphabetIndex;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;

/**
 * {@link SingleDP} extended with single-sequence posterior and Fisher-score computations for
 * {@link ProfileHMM} models.
 *
 * <p>All methods treat forward/backward cells and {@link DPMatrix#getScore()} as natural-log
 * values and address cells as {@code {stateIndex, position}} with 1-based sequence positions.
 * The Fisher-score methods are built from the per-position term
 * {@code post(j, l) = exp(F(j, l) + B(j, l) - score)}. Here {@code F}/{@code B} are the
 * forward/backward cells of the match state of column {@code j} at position {@code l}, and
 * {@code score} is the forward matrix score. From it:
 * <ul>
 *   <li>{@code xi(a, j)} = sum of {@code post(j, l)} over positions whose symbol equals {@code a};</li>
 *   <li>{@code xi(j)} = sum of {@code post(j, l)} over all positions.</li>
 * </ul>
 * Each term is exponentiated only after subtracting {@code score}. This avoids exponentiating
 * {@code F + B} or {@code score} on their own: those underflow to 0 once they drop below
 * roughly -745, which would turn every score for a long sequence into 0, Infinity or NaN.
 */
public class ExtendedSingleDP
extends SingleDP {
    public ExtendedSingleDP(MarkovModel model) throws IllegalSymbolException, IllegalTransitionException, BioException {
        super(model);
    }

    /**
     * Computes the forward and backward matrices of {@code symbolList} with {@code scoreType} and
     * returns the posterior matrix built from them.
     *
     * @see #posteriorMatrix(SymbolList[], DPMatrix, DPMatrix, ScoreType)
     */
    public DPMatrix posteriorMatrix(SymbolList[] symbolList, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalArgumentException {
        return this.posteriorMatrix(symbolList, this.forwardMatrix(symbolList, scoreType), this.backwardMatrix(symbolList, scoreType), scoreType);
    }

    /**
     * Builds a matrix whose cell for state {@code s} at position {@code i} holds
     * {@code F(s, i) + B(s, i) - score} (log space). This is done for positions
     * {@code 1..length} and states {@code 1..getDotStatesIndex() - 1}. All other cells are left
     * as created by the {@link SingleDPMatrix} constructor, and the matrix's own score is not set.
     *
     * @param symbolList     array whose first element is the sequence; further elements are ignored
     * @param forwardMatrix  forward matrix of that sequence
     * @param backwardMatrix backward matrix of that sequence
     * @param scoreType      not used; kept for API compatibility
     * @return the posterior matrix described above
     */
    public DPMatrix posteriorMatrix(SymbolList[] symbolList, DPMatrix forwardMatrix, DPMatrix backwardMatrix, ScoreType scoreType) throws IllegalSymbolException, IllegalAlphabetException, IllegalArgumentException {
        SymbolList sequence = symbolList[0];
        SingleDPMatrix matrix = new SingleDPMatrix(this, sequence);
        double logScore = forwardMatrix.getScore();
        int dotStatesIndex = this.getDotStatesIndex();
        for (int i = 1; i <= sequence.length(); ++i) {
            // State 0 is skipped, as is every state from the dot-state index on.
            for (int s = 1; s < dotStatesIndex; ++s) {
                double forwardValue = forwardMatrix.getCell(new int[]{s, i});
                double backwardValue = backwardMatrix.getCell(new int[]{s, i});
                matrix.scores[i][s] = forwardValue + backwardValue - logScore;
            }
        }
        return matrix;
    }

    /**
     * Computes {@code xi(a, j) - xi(j)} for every match column {@code j} and every symbol
     * {@code a} of the alphabet of the first match state's distribution.
     *
     * <p>NOTE(review): {@code matchFisherVector} and {@code matchFisherVector2} compute
     * {@code xi(a, j) / theta(a, j) - xi(j)}, which divides by the emission probability. This
     * method does not. An earlier version fetched the emission score here without using it, so
     * the division may have been intended; confirm with callers before changing the result.
     *
     * @param symbolList     exactly one sequence
     * @param forwardMatrix  forward matrix of that sequence
     * @param backwardMatrix backward matrix of that sequence
     * @param score          not used; kept for API compatibility
     * @return a {@code columns x alphabetSize} matrix. Row {@code j - 1} holds match column
     *         {@code j}, and matrix columns follow the alphabet's iteration order.
     * @throws IllegalArgumentException if {@code symbolList} does not hold exactly one sequence
     */
    public DoubleMatrix2D calculateFisherScores(SymbolList[] symbolList, DPMatrix forwardMatrix, DPMatrix backwardMatrix, ScoreType score) {
        if (symbolList.length != 1) {
            throw new IllegalArgumentException("seq must be 1 long, not " + symbolList.length);
        }
        SymbolList sequence = symbolList[0];
        ProfileHMM profileHMM = (ProfileHMM)this.getModel();
        double logScore = forwardMatrix.getScore();
        FiniteAlphabet alph = (FiniteAlphabet)profileHMM.getMatch(1).getDistribution().getAlphabet();
        Symbol[] symbols = new Symbol[alph.size()];
        Iterator<Symbol> alphabetIterator = alph.iterator();
        for (int a = 0; alphabetIterator.hasNext(); ++a) {
            symbols[a] = alphabetIterator.next();
        }
        int[] slots = symbolSlots(sequence, symbols);
        int[] matchIndices = this.matchStateIndices(profileHMM);
        int columns = profileHMM.columns();
        DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(new double[columns][symbols.length]);
        for (int j = 1; j <= columns; ++j) {
            double[] perSymbol = new double[symbols.length];
            double total = accumulatePosteriors(forwardMatrix, backwardMatrix, matchIndices[j], logScore, slots, perSymbol);
            for (int a = 0; a < symbols.length; ++a) {
                matrix.setQuick(j - 1, a, perSymbol[a] - total);
            }
        }
        return matrix;
    }

    /**
     * Computes {@code xi(a, j) / theta(a, j) - xi(j)} for every match column {@code j} and every
     * symbol {@code a} of {@code alphabetIndex}. Here {@code theta(a, j)} is the weight of
     * {@code a} in the distribution of match state {@code j}.
     *
     * @param symbolListArray array whose first element is the sequence; further elements are ignored
     * @param alphabetIndex   index fixing the symbol order within each column
     * @param forwardMatrix   forward matrix of that sequence
     * @param backwardMatrix  backward matrix of that sequence
     * @param scoreType       not used; kept for API compatibility
     * @return a vector of length {@code columns * alphabetSize}; entry
     *         {@code (j - 1) * alphabetSize + i} belongs to column {@code j} and symbol index {@code i}
     * @throws IllegalArgumentException if a symbol of the index is rejected by a match state's distribution
     */
    public DoubleMatrix1D matchFisherVector(SymbolList[] symbolListArray, AlphabetIndex alphabetIndex, DPMatrix forwardMatrix, DPMatrix backwardMatrix, ScoreType scoreType) {
        ProfileHMM profileHMM = (ProfileHMM)this.getModel();
        int alphabetSize = alphabetIndex.getAlphabet().size();
        int columns = profileHMM.columns();
        DenseDoubleMatrix1D fisherVector = new DenseDoubleMatrix1D(new double[columns * alphabetSize]);
        double logScore = forwardMatrix.getScore();
        Symbol[] symbols = new Symbol[alphabetSize];
        for (int i = 0; i < alphabetSize; ++i) {
            symbols[i] = alphabetIndex.symbolForIndex(i);
        }
        int[] slots = symbolSlots(symbolListArray[0], symbols);
        int[] matchIndices = this.matchStateIndices(profileHMM);
        for (int j = 1; j <= columns; ++j) {
            EmissionState matchState = profileHMM.getMatch(j);
            double[] perSymbol = new double[alphabetSize];
            double total = accumulatePosteriors(forwardMatrix, backwardMatrix, matchIndices[j], logScore, slots, perSymbol);
            for (int i = 0; i < alphabetSize; ++i) {
                double weight = matchEmissionWeight(matchState, symbols[i]);
                fisherVector.setQuick((j - 1) * alphabetSize + i, perSymbol[i] / weight - total);
            }
        }
        return fisherVector;
    }

    /**
     * Same quantity and layout as {@code matchFisherVector}, accumulated position by position.
     * Each position {@code l} adds {@code post(j, l) / theta(x_l, j)} to the entry of its own
     * symbol, and subtracts {@code post(j, l)} from every entry of column {@code j}.
     *
     * @param symbolListArray array whose first element is the sequence; further elements are ignored
     * @param alphabetIndex   index fixing the symbol order within each column
     * @param forwardMatrix   forward matrix of that sequence
     * @param backwardMatrix  backward matrix of that sequence
     * @param scoreType       not used; kept for API compatibility
     * @return a vector of length {@code columns * alphabetSize}
     * @throws IllegalArgumentException if a sequence symbol is not in {@code alphabetIndex}, or is
     *         rejected by a match state's distribution
     */
    public DoubleMatrix1D matchFisherVector2(SymbolList[] symbolListArray, AlphabetIndex alphabetIndex, DPMatrix forwardMatrix, DPMatrix backwardMatrix, ScoreType scoreType) {
        ProfileHMM profileHMM = (ProfileHMM)this.getModel();
        SymbolList sequence = symbolListArray[0];
        int alphabetSize = alphabetIndex.getAlphabet().size();
        int columns = profileHMM.columns();
        DenseDoubleMatrix1D score = new DenseDoubleMatrix1D(new double[columns * alphabetSize]);
        double logScore = forwardMatrix.getScore();
        int[] matchIndices = this.matchStateIndices(profileHMM);
        for (int l = 1; l <= sequence.length(); ++l) {
            int i;
            try {
                i = alphabetIndex.indexForSymbol(sequence.symbolAt(l));
            }
            catch (IllegalSymbolException e) {
                // A -1 index would make the unchecked getQuick/setQuick below read or write the
                // wrong slot, so fail loudly instead.
                throw new IllegalArgumentException("Symbol at position " + l + " is not in the indexed alphabet", e);
            }
            Symbol indexedSymbol = alphabetIndex.symbolForIndex(i);
            for (int j = 1; j <= columns; ++j) {
                EmissionState matchState = profileHMM.getMatch(j);
                double post = posteriorAt(forwardMatrix, backwardMatrix, matchIndices[j], l, logScore);
                double weight = matchEmissionWeight(matchState, indexedSymbol);
                int base = (j - 1) * alphabetSize;
                score.setQuick(base + i, score.getQuick(base + i) + post / weight);
                for (int sigma = 0; sigma < alphabetSize; ++sigma) {
                    score.setQuick(base + sigma, score.getQuick(base + sigma) - post);
                }
            }
        }
        return score;
    }

    /**
     * Returns the state-array index of each match state, indexed by 1-based column. Slot 0 is
     * unused.
     */
    private int[] matchStateIndices(ProfileHMM profileHMM) {
        List<State> states = Arrays.asList(this.getStates());
        int[] indices = new int[profileHMM.columns() + 1];
        for (int j = 1; j < indices.length; ++j) {
            indices[j] = states.indexOf(profileHMM.getMatch(j));
        }
        return indices;
    }

    /**
     * Returns {@code exp(F + B - logScore)} for one cell. Subtracting the log score before
     * exponentiating keeps the result representable even when {@code F + B} alone would
     * underflow.
     */
    private static double posteriorAt(DPMatrix forwardMatrix, DPMatrix backwardMatrix, int stateIndex, int position, double logScore) {
        double forward = forwardMatrix.getCell(new int[]{stateIndex, position});
        double backward = backwardMatrix.getCell(new int[]{stateIndex, position});
        return Math.exp(forward + backward - logScore);
    }

    /**
     * Sums {@code posteriorAt} for one state over positions {@code 1..slots.length - 1}. Each term
     * is also added to {@code perSymbol[slots[l]]} when {@code slots[l] >= 0}.
     *
     * @return the sum over all positions
     */
    private static double accumulatePosteriors(DPMatrix forwardMatrix, DPMatrix backwardMatrix, int stateIndex, double logScore, int[] slots, double[] perSymbol) {
        double total = 0.0;
        for (int l = 1; l < slots.length; ++l) {
            double post = posteriorAt(forwardMatrix, backwardMatrix, stateIndex, l, logScore);
            total += post;
            if (slots[l] >= 0) {
                perSymbol[slots[l]] += post;
            }
        }
        return total;
    }

    /**
     * Maps each 1-based position {@code l} of {@code sequence} to the first index {@code a} for
     * which {@code symbols[a].equals(sequence.symbolAt(l))}. Positions with no match, and slot 0,
     * map to -1.
     */
    private static int[] symbolSlots(SymbolList sequence, Symbol[] symbols) {
        int[] slots = new int[sequence.length() + 1];
        Arrays.fill(slots, -1);
        for (int l = 1; l < slots.length; ++l) {
            Symbol current = sequence.symbolAt(l);
            for (int a = 0; a < symbols.length; ++a) {
                if (symbols[a].equals(current)) {
                    slots[l] = a;
                    break;
                }
            }
        }
        return slots;
    }

    /**
     * Returns the weight of {@code symbol} in {@code state}'s distribution.
     *
     * @throws IllegalArgumentException wrapping the {@link IllegalSymbolException} if the
     *         distribution rejects the symbol. Continuing would yield Infinity/NaN scores.
     */
    private static double matchEmissionWeight(EmissionState state, Symbol symbol) {
        try {
            return state.getDistribution().getWeight(symbol);
        }
        catch (IllegalSymbolException e) {
            throw new IllegalArgumentException("Symbol is not in the emission alphabet of " + state.getName(), e);
        }
    }
}

