/*
 * Decompiled with CFR 0.152.
 */
package de.biozentrum.bioinformatik.cama.hmm;

import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import de.biozentrum.bioinformatik.cama.BioJavaHelper;
import de.biozentrum.bioinformatik.cama.hmm.ProfileHMMTools;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;
import org.biojava.bio.dp.ProfileHMM;
import org.biojava.bio.symbol.AlphabetIndex;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;

/**
 * Computes Fisher-score style vectors for sequences against a BioJava {@link ProfileHMM}
 * using forward/backward passes over a flattened state array.
 *
 * <p>State layout: state index {@code s = 3 * column + type}, where type 0 is a match state,
 * type 1 an insert state and type 2 a (silent) delete state. State 0 serves as the start
 * state (the forward pass gives it probability 1 at position 0). Index {@code stateCount} is
 * used as the end state of the backward pass; indices {@code stateCount + 1} and
 * {@code stateCount + 2} are virtual states past the model that may be referenced by the
 * last column's outgoing transitions.
 *
 * <p>All probabilities are held in a negative-log representation (see {@link LogSpace}):
 * a stored value {@code v} stands for {@code exp(-v)}, and {@code LogSpace.ZERO}
 * ({@code Double.MAX_VALUE}) stands for probability 0.
 */
public class FisherScoring {
    /** Parameters supplied at construction; only the log strategy is read. */
    FisherScoringParameters parameters = null;
    /** Negative-log arithmetic helper built from the parameters' log strategy. */
    LogSpace logSpace = null;
    DoubleFactory2D F2D = DoubleFactory2D.dense;
    DoubleFactory1D F1D = DoubleFactory1D.dense;
    ProfileHMM model = null;
    AlphabetIndex alphabetIndex = null;
    /** Size of the alphabet behind {@link #alphabetIndex}. */
    int alphabetSize = 0;
    /** {@code model.columns() + 1}: the model's columns plus column 0 (holding the start state). */
    int columnCount = 0;
    /** Number of real states: three per column. */
    int stateCount = 0;
    /** Entry {@code a} is the incoming-transition triplet of state {@code a + 1}. */
    double[][][] incomingTransitionsCache = null;
    /** Entry {@code a} is the outgoing-transition triplet of state {@code a}. */
    double[][][] outgoingTransitionsCache = null;
    /** Symbol to alphabet index; {@code -1} marks symbols the index rejected. */
    HashMap<Symbol, Integer> alphabetIndexCache = null;
    /** Symbol to the alphabet indices of the symbols it matches ({@code Symbol.getMatches()}). */
    HashMap<Symbol, int[]> matchingSymbolsCache = null;
    /** {@code [state][symbolIndex]} emission weight in negative-log space. */
    double[][] emissionCache = null;

    /**
     * Creates a parameter object that supplies the given log strategy.
     *
     * @param logStrategy strategy used for log/exp in the negative-log arithmetic
     * @return parameters returning {@code logStrategy}
     */
    public static FisherScoringParameters createFisherScoringParameters(final LogStrategy logStrategy) {
        return new FisherScoringParameters() {
            @Override
            public LogStrategy getLogStrategy() {
                return logStrategy;
            }
        };
    }

    /**
     * Creates parameters that use {@link LogStrategyExact} ({@code Math.log}/{@code Math.exp}).
     *
     * @return default parameters
     */
    public static FisherScoringParameters createDefaultFisherScoringParameters() {
        return FisherScoring.createFisherScoringParameters(new LogStrategyExact());
    }

    /**
     * Builds a scorer for {@code model} with the default (exact) log strategy.
     *
     * @param model         profile HMM whose parameters are scored
     * @param alphabetIndex index mapping symbols to emission-matrix rows
     */
    public FisherScoring(ProfileHMM model, AlphabetIndex alphabetIndex) {
        this(model, alphabetIndex, FisherScoring.createDefaultFisherScoringParameters());
    }

    /**
     * Builds a scorer for {@code model} and precomputes transition/emission caches.
     *
     * @param model         profile HMM whose parameters are scored
     * @param alphabetIndex index mapping symbols to emission-matrix rows
     * @param parameters    supplies the log strategy
     */
    public FisherScoring(ProfileHMM model, AlphabetIndex alphabetIndex, FisherScoringParameters parameters) {
        this.model = model;
        this.alphabetIndex = alphabetIndex;
        this.alphabetSize = alphabetIndex.getAlphabet().size();
        this.parameters = parameters;
        this.logSpace = new LogSpace(parameters.getLogStrategy());
        this.columnCount = this.model.columns() + 1;
        this.stateCount = this.columnCount * 3;
        this.initializeMatrices();
    }

    /**
     * Converts the model's transition and emission matrices to negative-log space and fills
     * the symbol, transition and emission caches used by {@link #matchFisherVector}.
     */
    private void initializeMatrices() {
        // Index 0 = match, 1 = insert, 2 = delete, matching the state-type encoding.
        DoubleMatrix2D[] transitionMatrices = new DoubleMatrix2D[3];
        for (int type = 0; type < 3; ++type) {
            transitionMatrices[type] = ProfileHMMTools.getTransitionMatrix(this.model, type).assign(this.logSpace.LOG);
        }
        // Index 0 = match, 1 = insert (delete states are silent).
        DoubleMatrix2D[] emissionMatrices = new DoubleMatrix2D[2];
        for (int type = 0; type < 2; ++type) {
            emissionMatrices[type] = ProfileHMMTools.getEmissionMatrix(this.model, type, this.alphabetIndex).assign(this.logSpace.LOG);
        }

        Set<Symbol> allSymbols = BioJavaHelper.getAllSymbols(this.alphabetIndex.getAlphabet());
        this.alphabetIndexCache = new HashMap<Symbol, Integer>(this.alphabetSize, 1.0f);
        for (Symbol s : allSymbols) {
            int index = -1;
            try {
                index = this.alphabetIndex.indexForSymbol(s);
            } catch (IllegalSymbolException e) {
                // Best effort: keep going and record the symbol as unindexable (-1).
                e.printStackTrace();
            }
            this.alphabetIndexCache.put(s, index);
        }

        this.matchingSymbolsCache = new HashMap<Symbol, int[]>(this.alphabetSize, 1.0f);
        for (Symbol s : allSymbols) {
            FiniteAlphabet matchAlphabet = (FiniteAlphabet) s.getMatches();
            int[] matches = new int[matchAlphabet.size()];
            Iterator<Symbol> matchIterator = matchAlphabet.iterator();
            for (int i = 0; matchIterator.hasNext(); ++i) {
                matches[i] = this.alphabetIndexCache.get(matchIterator.next());
            }
            this.matchingSymbolsCache.put(s, matches);
        }

        this.outgoingTransitionsCache = new double[this.stateCount][][];
        this.incomingTransitionsCache = new double[this.stateCount][][];
        for (int state = 0; state < this.stateCount; ++state) {
            this.outgoingTransitionsCache[state] = this.getOutgoingTransitionTriplet(state, transitionMatrices);
            this.incomingTransitionsCache[state] = this.getIncomingTransitionTriplet(state + 1, transitionMatrices);
        }

        // Wide enough for every recorded index, never narrower than the original sizing.
        int symbolSlots = this.alphabetIndexCache.size();
        for (int symbolIndex : this.alphabetIndexCache.values()) {
            symbolSlots = Math.max(symbolSlots, symbolIndex + 1);
        }
        // Three extra rows cover the end state and the virtual states past the model.
        this.emissionCache = new double[this.stateCount + 3][symbolSlots];
        for (int state = 0; state < this.stateCount + 3; ++state) {
            for (int symbolIndex : this.alphabetIndexCache.values()) {
                if (symbolIndex < 0) {
                    continue; // unindexable symbol recorded by the catch above
                }
                this.emissionCache[state][symbolIndex] = this.getEmission(state, symbolIndex, emissionMatrices);
            }
        }
    }

    /**
     * Builds the outgoing-transition triplet of a state.
     *
     * @param stateIndex         flattened state index
     * @param transitionMatrices negative-log transition matrices (match, insert, delete)
     * @return {@code [0][k]} = transition weight, {@code [1][k]} = target state index, for targets
     *         next match, same-column insert, next delete
     */
    private double[][] getOutgoingTransitionTriplet(int stateIndex, DoubleMatrix2D[] transitionMatrices) {
        int column = stateIndex / 3;
        int stateType = stateIndex % 3;
        if (column >= transitionMatrices[2].columns()) {
            // NOTE(review): weights 0.0 mean probability 1 in negative-log space; kept as in the
            // original for columns the transition matrices do not describe.
            return new double[][]{new double[3], {column * 3, column * 3, column * 3}};
        }
        DoubleMatrix2D source = transitionMatrices[stateType];
        return new double[][]{
            {source.getQuick(0, column), source.getQuick(1, column), source.getQuick(2, column)},
            {(column + 1) * 3, column * 3 + 1, (column + 1) * 3 + 2}
        };
    }

    /**
     * Builds the incoming-transition triplet of a state.
     *
     * @param stateIndex         flattened state index (always {@code >= 1} here)
     * @param transitionMatrices negative-log transition matrices (match, insert, delete)
     * @return {@code [0][k]} = weight from predecessor {@code k}, {@code [1][k]} = predecessor
     *         index, where predecessors are the match/insert/delete states of the previous
     *         column (match, delete) or the same column (insert). A delete state in column 0
     *         has no predecessors and gets all-{@code ZERO} weights.
     */
    private double[][] getIncomingTransitionTriplet(int stateIndex, DoubleMatrix2D[] transitionMatrices) {
        double[][] triplet = new double[][]{{this.logSpace.ZERO, this.logSpace.ZERO, this.logSpace.ZERO}, {0.0, 0.0, 0.0}};
        int column = stateIndex / 3;
        int stateType = stateIndex % 3;
        if (stateType == 2 && column == 0) {
            return triplet;
        }
        // Inserts are entered from their own column; matches and deletes from the previous one.
        int sourceColumn = stateType == 1 ? column : column - 1;
        for (int sourceType = 0; sourceType < 3; ++sourceType) {
            // Row of the source-type matrix selects the destination type.
            triplet[0][sourceType] = transitionMatrices[sourceType].getQuick(stateType, sourceColumn);
            triplet[1][sourceType] = sourceColumn * 3 + sourceType;
        }
        return triplet;
    }

    /**
     * Returns the negative-log emission weight of a symbol in a state.
     *
     * @return the match/insert emission, or {@code ZERO} for delete states and for columns
     *         beyond the emission matrices
     */
    private double getEmission(int stateIndex, int symbolIndex, DoubleMatrix2D[] emissionMatrices) {
        int column = stateIndex / 3;
        int stateType = stateIndex % 3;
        if (stateType == 2 || column >= emissionMatrices[0].columns()) {
            return this.logSpace.ZERO;
        }
        return emissionMatrices[stateType].getQuick(symbolIndex, column);
    }

    /**
     * Returns the length of a Fisher vector: one entry per (model column, alphabet symbol).
     *
     * @return {@code model.columns() * alphabetSize}
     */
    public int getFisherScoreDimensions() {
        return this.model.columns() * this.alphabetSize;
    }

    /**
     * Computes the Fisher vector of the first sequence in {@code symbolListArray}.
     *
     * @param symbolListArray sequences; only element 0 is used
     * @return a new vector of length {@link #getFisherScoreDimensions()}
     */
    public DoubleMatrix1D matchFisherVector(SymbolList[] symbolListArray) {
        DoubleMatrix2D scoreMat = this.F2D.make(1, this.getFisherScoreDimensions());
        this.matchFisherVector(symbolListArray, scoreMat, 0);
        return scoreMat.viewRow(0);
    }

    /**
     * Runs forward and backward passes for the first sequence in {@code symbolListArray} and
     * adds its match-emission score terms into row {@code scoreIndex} of {@code scoreMat}.
     *
     * <p>For each position and each match state in column {@code j >= 1}, the real-space value
     * of {@code back + forw - emit - seqScore} is added to entry {@code (j - 1) * alphabetSize + m}
     * for every alphabet index {@code m} matched by the emitted symbol, and the real-space value
     * of {@code back + forw - seqScore} is subtracted from all {@code alphabetSize} entries of
     * that column. Existing values in the row are accumulated into, not overwritten.
     *
     * @param symbolListArray sequences; only element 0 is used
     * @param scoreMat        destination matrix
     * @param scoreIndex      row of {@code scoreMat} to accumulate into
     */
    public void matchFisherVector(SymbolList[] symbolListArray, DoubleMatrix2D scoreMat, int scoreIndex) {
        SymbolList sequence = symbolListArray[0];
        int sequenceLength = sequence.length();
        DoubleMatrix2D forw = this.computeForward(sequence);
        // Total sequence weight: end state fed from the last column at the final position.
        double seqScore = this.sumOverTriplet(this.logSpace.ZERO, this.incomingTransitionsCache[this.stateCount - 1], forw, sequenceLength);

        DoubleMatrix2D back = this.F2D.make(this.stateCount + 3, sequenceLength + 1);
        // Every cell not explicitly set below (the end state before the final position and the
        // virtual states past the model) must mean probability 0, not the 0.0 (= probability 1)
        // that a fresh matrix holds.
        back.assign(this.logSpace.ZERO);
        back.set(this.stateCount, sequenceLength, this.logSpace.toLog(1.0));

        // Final position: only silent (delete) successors and the end state can follow.
        for (int state = this.stateCount - 1; state >= 0; --state) {
            double[][] triplet = this.outgoingTransitionsCache[state];
            double value = this.logSpace.ZERO;
            for (int k = 0; k < 3; ++k) {
                int target = (int) triplet[1][k];
                if (target % 3 == 2 || target == this.stateCount) {
                    value = this.logSpace.sum(value, triplet[0][k] + back.getQuick(target, sequenceLength));
                }
            }
            back.set(state, sequenceLength, value);
        }

        for (int position = sequenceLength - 1; position >= 0; --position) {
            Symbol nextSymbol = sequence.symbolAt(position + 1);
            int symbolIndex = this.alphabetIndexCache.get(nextSymbol);
            int[] matches = this.matchingSymbolsCache.get(nextSymbol);
            for (int state = this.stateCount - 1; state >= 0; --state) {
                double[][] triplet = this.outgoingTransitionsCache[state];
                double value = this.logSpace.ZERO;
                for (int k = 0; k < 3; ++k) {
                    int target = (int) triplet[1][k];
                    if (target % 3 != 2) {
                        // Emitting successor consumes the next symbol.
                        value = this.logSpace.sum(value, triplet[0][k] + this.emissionCache[target][symbolIndex] + back.getQuick(target, position + 1));
                    } else {
                        // Silent successor stays at the same position.
                        value = this.logSpace.sum(value, triplet[0][k] + back.getQuick(target, position));
                    }
                }
                back.setQuick(state, position, value);

                int column = state / 3;
                if (state % 3 == 0 && column > 0) {
                    this.addMatchEmissionTerms(scoreMat, scoreIndex, column, this.emissionCache[state][symbolIndex], matches,
                            back.getQuick(state, position + 1), forw.getQuick(state, position + 1), seqScore);
                }
            }
        }
    }

    /** Fills the forward matrix ({@code stateCount x (length + 1)}) for one sequence. */
    private DoubleMatrix2D computeForward(SymbolList sequence) {
        int sequenceLength = sequence.length();
        DoubleMatrix2D forw = this.F2D.make(this.stateCount, sequenceLength + 1);
        // State 0 is the start state: probability 1 before any symbol, 0 afterwards.
        forw.setQuick(0, 0, this.logSpace.toLog(1.0));
        for (int position = 1; position <= sequenceLength; ++position) {
            forw.setQuick(0, position, this.logSpace.ZERO);
        }
        // Position 0: only silent (delete) states are reachable without emitting.
        for (int state = 1; state < this.stateCount; ++state) {
            double value = this.logSpace.ZERO;
            if (state % 3 == 2) {
                value = this.sumOverTriplet(value, this.incomingTransitionsCache[state - 1], forw, 0);
            }
            forw.setQuick(state, 0, value);
        }
        for (int position = 1; position <= sequenceLength; ++position) {
            int symbolIndex = this.alphabetIndexCache.get(sequence.symbolAt(position));
            for (int state = 1; state < this.stateCount; ++state) {
                double[][] triplet = this.incomingTransitionsCache[state - 1];
                double value;
                if (state % 3 != 2) {
                    // Emitting state: predecessors at the previous position, then emit.
                    value = this.sumOverTriplet(this.logSpace.ZERO, triplet, forw, position - 1)
                            + this.emissionCache[state][symbolIndex];
                } else {
                    // Silent state: predecessors at the same position.
                    value = this.sumOverTriplet(this.logSpace.ZERO, triplet, forw, position);
                }
                forw.setQuick(state, position, value);
            }
        }
        return forw;
    }

    /** Log-space sums {@code triplet[0][k] + matrix[triplet[1][k]][position]} over k into {@code start}. */
    private double sumOverTriplet(double start, double[][] triplet, DoubleMatrix2D matrix, int position) {
        double value = start;
        for (int k = 0; k < 3; ++k) {
            value = this.logSpace.sum(value, triplet[0][k] + matrix.getQuick((int) triplet[1][k], position));
        }
        return value;
    }

    /** Adds one match state's contribution for one position into the score row. */
    private void addMatchEmissionTerms(DoubleMatrix2D scoreMat, int scoreIndex, int column, double emit, int[] matches,
            double backVal, double forwVal, double seqScore) {
        int offset = (column - 1) * this.alphabetSize;
        double matchedTerm = this.logSpace.toReal(backVal + forwVal - emit - seqScore);
        for (int fisherIndex : matches) {
            if (fisherIndex < 0 || fisherIndex >= this.alphabetSize) {
                // Colt's getQuick/setQuick do not bounds-check; an out-of-range index would
                // silently land in a neighbouring column (e.g. -1 for unindexable symbols).
                continue;
            }
            try {
                int cell = offset + fisherIndex;
                scoreMat.setQuick(scoreIndex, cell, scoreMat.getQuick(scoreIndex, cell) + matchedTerm);
            } catch (RuntimeException e) {
                // Best effort, as before: report and continue with the next index.
                e.printStackTrace();
            }
        }
        double occupancyTerm = this.logSpace.toReal(backVal + forwVal - seqScore);
        for (int sigma = 0; sigma < this.alphabetSize; ++sigma) {
            int cell = offset + sigma;
            scoreMat.setQuick(scoreIndex, cell, scoreMat.getQuick(scoreIndex, cell) - occupancyTerm);
        }
    }

    /** Configuration for {@link FisherScoring}. */
    public static interface FisherScoringParameters {
        /** @return the strategy used for log/exp in the negative-log arithmetic */
        public LogStrategy getLogStrategy();
    }

    /**
     * Negative-log arithmetic: a value {@code v} represents probability {@code exp(-v)};
     * {@link #ZERO} ({@code Double.MAX_VALUE}) represents probability 0.
     */
    private static class LogSpace {
        public final double ZERO = Double.MAX_VALUE;
        /** Element-wise {@link #toLog} for Colt {@code assign}. */
        public final LogFunction LOG = new LogFunction();
        /** Element-wise {@link #toReal} for Colt {@code assign}. */
        public final RealFunction REAL = new RealFunction();
        public final LogStrategy logStrategy;

        public LogSpace(LogStrategy strategy) {
            this.logStrategy = strategy;
        }

        public double log(double val) {
            return this.logStrategy.log(val);
        }

        public double exp(double val) {
            return this.logStrategy.exp(val);
        }

        /**
         * Returns the representation of the sum of two represented probabilities, i.e.
         * {@code -log(exp(-p1) + exp(-p2))} (exactly so with {@link LogStrategyExact}).
         */
        public double sum(double p1, double p2) {
            double smaller;
            double larger;
            if (p1 > p2) {
                larger = p1;
                smaller = p2;
            } else {
                larger = p2;
                smaller = p1;
            }
            if (larger >= this.ZERO) {
                // The larger operand is probability 0 (ZERO or overflowed to +Infinity), so the sum
                // is the other operand. Avoids Infinity - Infinity = NaN when both are infinite.
                return smaller;
            }
            return smaller - Math.log1p(this.exp(smaller - larger));
        }

        /** Converts a probability to its negative log; 0 maps to {@link #ZERO}. */
        public double toLog(double p1) {
            assert (p1 >= 0.0);
            if (p1 == 0.0) {
                return this.ZERO;
            }
            return -this.log(p1);
        }

        /** Converts a negative-log value back to a probability. */
        public double toReal(double p1) {
            return this.exp(-p1);
        }

        private class LogFunction implements DoubleFunction {
            private LogFunction() {
            }

            @Override
            public double apply(double p1) {
                return LogSpace.this.toLog(p1);
            }
        }

        private class RealFunction implements DoubleFunction {
            private RealFunction() {
            }

            @Override
            public double apply(double p1) {
                return LogSpace.this.toReal(p1);
            }
        }
    }

    /** Pluggable natural log / exponential implementation. */
    public static interface LogStrategy {
        public double log(double var1);

        public double exp(double var1);
    }

    /** Exact strategy backed by {@code Math.log} and {@code Math.exp}. */
    public static class LogStrategyExact implements LogStrategy {
        @Override
        public double exp(double val) {
            return Math.exp(val);
        }

        @Override
        public double log(double val) {
            return Math.log(val);
        }
    }

    /**
     * Approximate strategy. {@code exp} builds the high 32 bits of an IEEE double directly
     * (presumably Schraudolph's approximation) and returns 0 below -700; {@code log} is a
     * rational approximation that is only accurate near {@code val = 1} — confirm suitability
     * for the value ranges in use.
     */
    public static class LogStrategyFast implements LogStrategy {
        @Override
        public double exp(double val) {
            if (val < -700.0) {
                return 0.0;
            }
            long tmp = (long) (1512775.0 * val + 1.072632447E9);
            return Double.longBitsToDouble(tmp << 32);
        }

        @Override
        public double log(double val) {
            return 6.0 * (val - 1.0) / (val + 1.0 + 4.0 * Math.sqrt(val));
        }
    }
}

