/*
 * Decompiled with CFR 0.152.
 */
package de.biozentrum.bioinformatik.cama.hmm;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import de.biozentrum.bioinformatik.cama.BioJavaHelper;
import de.biozentrum.bioinformatik.cama.hmm.ExtendedSingleDP;
import java.util.Iterator;
import org.biojava.bio.BioException;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dist.DistributionTools;
import org.biojava.bio.dp.DP;
import org.biojava.bio.dp.EmissionState;
import org.biojava.bio.dp.ProfileHMM;
import org.biojava.bio.dp.ScoreType;
import org.biojava.bio.dp.State;
import org.biojava.bio.dp.StatePath;
import org.biojava.bio.seq.io.SymbolTokenization;
import org.biojava.bio.symbol.Alignment;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.AlphabetIndex;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.bio.symbol.SymbolList;

/**
 * Static helpers for building, inspecting and aligning against BioJava
 * {@link ProfileHMM}s.
 *
 * <p>Profile columns are addressed the way {@link ProfileHMM} addresses them in
 * this file: match index {@code 0} and {@code columns() + 1} are used as the
 * begin/end state, insert states are addressed from {@code 0} to
 * {@code columns()}, and delete states from {@code 1} to {@code columns()}.
 * Alignment columns handled by the private helpers are 0-based, with
 * {@code -1} meaning "no such column".
 */
public class ProfileHMMTools {
    private static final DoubleFactory2D F2D = DoubleFactory2D.dense;
    /** State type selector (and transition matrix row) for match states. */
    public static final int StateTypeMatch = 0;
    /** State type selector (and transition matrix row) for insert states. */
    public static final int StateTypeInsert = 1;
    /** State type selector (and transition matrix row) for delete states. */
    public static final int StateTypeDelete = 2;

    /**
     * Looks up a state of the given type at the given profile column.
     *
     * @param profileHMM the model to query
     * @param column     the profile column index, passed straight through to the model
     * @param stateType  one of {@link #StateTypeMatch}, {@link #StateTypeInsert},
     *                   {@link #StateTypeDelete}
     * @return the requested state, or {@code null} if {@code stateType} is not one
     *         of the three known constants
     */
    public static State getState(ProfileHMM profileHMM, int column, int stateType) {
        switch (stateType) {
            case StateTypeDelete:
                return profileHMM.getDelete(column);
            case StateTypeInsert:
                return profileHMM.getInsert(column);
            case StateTypeMatch:
                return profileHMM.getMatch(column);
            default:
                return null;
        }
    }

    /**
     * Builds a (symbols x profile columns + 1) matrix of emission weights for the
     * match or insert states of a profile.
     *
     * <p>Each cell holds the mean weight, under the state's emission distribution,
     * of all symbols matched by the indexed symbol. For
     * {@link #StateTypeDelete} the matrix is returned as created by the factory,
     * and for {@link #StateTypeMatch} column 0 is left untouched. Symbol and index
     * errors are reported on stderr and the affected cell is skipped.
     *
     * @param profileHMM the model to read
     * @param stateType  {@link #StateTypeMatch} or {@link #StateTypeInsert}
     *                   (or {@link #StateTypeDelete} for an untouched matrix)
     * @param index      maps matrix rows to symbols
     * @return the emission matrix, rows indexed by symbol, columns by profile column
     */
    public static DoubleMatrix2D getEmissionMatrix(ProfileHMM profileHMM, int stateType, AlphabetIndex index) {
        int columns = profileHMM.columns() + 1;
        int rows = BioJavaHelper.getAllSymbols(index.getAlphabet()).size();
        DoubleMatrix2D matrix = F2D.make(rows, columns);
        if (stateType == StateTypeDelete) {
            // Delete states are not read here; leave the matrix as created.
            return matrix;
        }
        for (int modelPos = 0; modelPos < columns; modelPos++) {
            if (modelPos == 0 && stateType == StateTypeMatch) {
                // Match index 0 is skipped for match states.
                continue;
            }
            EmissionState state = (EmissionState) ProfileHMMTools.getState(profileHMM, modelPos, stateType);
            for (int alphIndex = 0; alphIndex < rows; alphIndex++) {
                try {
                    double weightSum = 0.0;
                    int count = 0;
                    Iterator<Symbol> matches = ((FiniteAlphabet) index.symbolForIndex(alphIndex).getMatches()).iterator();
                    while (matches.hasNext()) {
                        weightSum += state.getDistribution().getWeight(matches.next());
                        ++count;
                    }
                    // NOTE(review): a symbol with no matches yields 0.0 / 0 = NaN here.
                    matrix.setQuick(alphIndex, modelPos, weightSum / (double) count);
                }
                catch (IllegalSymbolException e) {
                    e.printStackTrace();
                }
                catch (IndexOutOfBoundsException e) {
                    e.printStackTrace();
                }
            }
        }
        return matrix;
    }

    /**
     * Builds a (3 x profile columns + 1) matrix of the transition weights leaving
     * the states of one type.
     *
     * <p>Row {@link #StateTypeMatch} holds the weight towards the next match
     * state, row {@link #StateTypeInsert} towards the insert state of the same
     * column, and row {@link #StateTypeDelete} towards the next delete state
     * (not filled for the last column). Column 0 is skipped for delete states.
     * Any exception raised while reading a column is printed to stdout and the
     * remaining cells of that column are left untouched.
     *
     * @param profileHMM the model to read
     * @param stateType  the type of the source states
     * @return the transition matrix
     */
    public static DoubleMatrix2D getTransitionMatrix(ProfileHMM profileHMM, int stateType) {
        int columns = profileHMM.columns() + 1;
        DoubleMatrix2D matrix = F2D.make(3, columns);
        for (int modelPos = 0; modelPos < columns; modelPos++) {
            if (modelPos == 0 && stateType == StateTypeDelete) {
                continue;
            }
            State state = ProfileHMMTools.getState(profileHMM, modelPos, stateType);
            try {
                Distribution d = profileHMM.getWeights(state);
                if (modelPos < columns - 1) {
                    // There is no delete state after the last column.
                    matrix.setQuick(StateTypeDelete, modelPos, d.getWeight(profileHMM.getDelete(modelPos + 1)));
                }
                matrix.setQuick(StateTypeInsert, modelPos, d.getWeight(profileHMM.getInsert(modelPos)));
                matrix.setQuick(StateTypeMatch, modelPos, d.getWeight(profileHMM.getMatch(modelPos + 1)));
            }
            catch (Exception e) {
                System.out.println(e);
            }
        }
        return matrix;
    }

    /**
     * Prints, for every match index from 0 to {@code p.columns()}, the emission
     * weights of that state followed by its outgoing transition weights.
     *
     * <p>Only match states are printed. Emission lines that cannot be tokenized
     * are silently skipped.
     *
     * @param p the model to print
     * @throws Exception if the model rejects a state or symbol lookup
     */
    public static void printProfileHMM(ProfileHMM p) throws Exception {
        for (int i = 0; i <= p.columns(); i++) {
            // Typed as EmissionState so the emission distribution is accessible.
            EmissionState match = p.getMatch(i);
            Distribution transitions = p.getWeights(match);
            FiniteAlphabet targets = (FiniteAlphabet) transitions.getAlphabet();

            FiniteAlphabet emitted = (FiniteAlphabet) match.getDistribution().getAlphabet();
            System.out.println(emitted.getClass());
            Iterator<Symbol> emittedIt = emitted.iterator();
            while (emittedIt.hasNext()) {
                Symbol sym = emittedIt.next();
                try {
                    System.out.println(String.valueOf(emitted.getTokenization("token").tokenizeSymbol(sym)) + ": " + match.getDistribution().getWeight(sym));
                }
                catch (Exception ignored) {
                    // Best effort: symbols that cannot be tokenized or weighted are not printed.
                }
            }

            Iterator<Symbol> targetIt = targets.iterator();
            while (targetIt.hasNext()) {
                Symbol target = targetIt.next();
                double w = transitions.getWeight(target);
                System.out.println(String.valueOf(match.getName()) + " to " + target.getName() + " Gewicht: " + w);
            }
            System.out.println();
        }
    }

    /**
     * Creates a profile HMM from an alignment and trains its transition and
     * emission weights from the alignment's columns.
     *
     * @param alignment      the source alignment; must contain at least one sequence
     * @param matchThreshold columns whose gap weight is below this value become
     *                       match columns
     * @param pseudoCounts   pseudo count added to every transition count
     * @return the configured profile HMM
     * @throws Exception if the model cannot be built or configured
     */
    public static ProfileHMM createProfile(Alignment alignment, double matchThreshold, int pseudoCounts) throws Exception {
        boolean[] columnMask = ProfileHMMTools.profileColumnMask(alignment, matchThreshold);
        ProfileHMM profileHMM = new ProfileHMM(((SymbolList)alignment.symbolListIterator().next()).getAlphabet(), ProfileHMMTools.countProfileColumns(columnMask), DistributionFactory.DEFAULT, DistributionFactory.DEFAULT, null);
        ProfileHMMTools.configureProfile(profileHMM, alignment, columnMask, pseudoCounts);
        return profileHMM;
    }

    /**
     * Aligns a sequence against a model and prints the Viterbi path, the
     * Viterbi/forward/backward scores and the posterior matrix to stdout.
     *
     * @param baseDP    the dynamic-programming object; must be an
     *                  {@link ExtendedSingleDP} for the posterior matrix
     * @param s         the sequence to align
     * @param scoreType the score type passed to the DP algorithms
     * @throws Exception if any DP run or symbol tokenization fails
     */
    public static void alignSequence(DP baseDP, SymbolList s, ScoreType scoreType) throws Exception {
        SymbolList[] sl = new SymbolList[]{s};
        StatePath sp = baseDP.viterbi(sl, scoreType);
        double fwScore = baseDP.forward(sl, scoreType);
        double bwScore = baseDP.backward(sl, scoreType);
        Alphabet alphabet = s.getAlphabet();
        SymbolTokenization token = alphabet.getTokenization("token");
        StringBuilder seqString = new StringBuilder();
        StringBuilder stateString = new StringBuilder();
        // State path positions are 1-based.
        for (int i = 1; i <= sp.length(); i++) {
            seqString.append(token.tokenizeSymbol(sp.symbolAt(StatePath.SEQUENCE, i)));
            // Each state is abbreviated to the first character of its name.
            stateString.append(sp.symbolAt(StatePath.STATES, i).getName().charAt(0));
        }
        System.out.println("Alignment-Sequence [" + seqString + "] [" + stateString + "] alias [" + "name" + "]: viterbi[" + sp.getScore() + "] fw[" + fwScore + "] bw[" + bwScore + "]");
        System.out.println(((ExtendedSingleDP)baseDP).posteriorMatrix(sl, scoreType));
    }

    /**
     * Creates the DP object used by this package for a profile HMM.
     *
     * @param profileHMM the model to wrap
     * @return a new {@link ExtendedSingleDP} for the model
     * @throws IllegalArgumentException if the DP constructor rejects the model
     * @throws BioException             if the DP constructor fails
     */
    public static ExtendedSingleDP createDP(ProfileHMM profileHMM) throws IllegalArgumentException, BioException {
        return new ExtendedSingleDP(profileHMM);
    }

    /**
     * Counts how many alignment columns qualify as profile (match) columns.
     *
     * @param alignment the alignment to inspect
     * @param threshold gap-weight threshold, see {@link #profileColumnMask}
     * @return the number of match columns
     */
    public static int countProfileColumns(Alignment alignment, double threshold) {
        return ProfileHMMTools.countProfileColumns(ProfileHMMTools.profileColumnMask(alignment, threshold));
    }

    /**
     * Sets the transition weights and then the emission distributions of a
     * profile from an alignment.
     *
     * @param p                 the model to configure
     * @param al                the training alignment
     * @param profileColumnMask per alignment column, whether it is a match column
     * @param pseudoCounts      pseudo count added to every transition count
     * @throws Exception if either initialization step fails
     */
    public static void configureProfile(ProfileHMM p, Alignment al, boolean[] profileColumnMask, int pseudoCounts) throws Exception {
        ProfileHMMTools.initializeTransitionWeight(p, al, profileColumnMask, pseudoCounts);
        ProfileHMMTools.initializeEmmitions(p, al, profileColumnMask);
    }

    /**
     * Counts the {@code true} entries of a column mask.
     *
     * @param profileColumnMask the mask to count
     * @return the number of {@code true} entries
     */
    public static int countProfileColumns(boolean[] profileColumnMask) {
        int count = 0;
        for (boolean isProfileColumn : profileColumnMask) {
            if (isProfileColumn) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Decides for every alignment column whether it becomes a match column.
     *
     * <p>A column is a match column when the weight of the gap symbol in that
     * column's distribution is strictly below {@code threshold}. If the column
     * distributions cannot be computed or queried, the error is printed to stderr
     * and the columns not yet processed stay {@code false}.
     *
     * @param alignment the alignment to inspect; must contain at least one sequence
     * @param threshold upper bound (exclusive) on the gap weight of a match column
     * @return one flag per alignment column
     */
    public static boolean[] profileColumnMask(Alignment alignment, double threshold) {
        Symbol gapSymbol = ((SymbolList)alignment.symbolListIterator().next()).getAlphabet().getGapSymbol();
        boolean[] mask = new boolean[alignment.length()];
        try {
            Distribution[] dis = DistributionTools.distOverAlignment(alignment, true, 1.0);
            for (int i = 0; i < dis.length; i++) {
                mask[i] = dis[i].getWeight(gapSymbol) < threshold;
            }
        }
        catch (IllegalAlphabetException e) {
            e.printStackTrace();
        }
        catch (IllegalSymbolException e) {
            e.printStackTrace();
        }
        return mask;
    }

    /**
     * Sets the emission distributions of the match and insert states from the
     * column distributions of an alignment.
     *
     * <p>Insert state 0 receives the average over all alignment columns. For the
     * k-th match column of the alignment, match state k and insert state k both
     * receive that column's distribution.
     *
     * @param p                 the model to configure
     * @param alignment         the training alignment; must contain at least one sequence
     * @param profileColumnMask per alignment column, whether it is a match column
     * @throws Exception if the distributions cannot be computed or assigned
     */
    public static void initializeEmmitions(ProfileHMM p, Alignment alignment, boolean[] profileColumnMask) throws Exception {
        SymbolList first = (SymbolList)alignment.symbolListIterator().next();
        int alphabetSize = ((FiniteAlphabet)first.getAlphabet()).size();
        Distribution[] d = DistributionTools.distOverAlignment(alignment, false, alphabetSize);
        p.getInsert(0).setDistribution(DistributionTools.average(d));
        int pCol = 0;
        for (int i = 0; i < d.length; i++) {
            if (profileColumnMask[i]) {
                ++pCol;
                p.getMatch(pCol).setDistribution(d[i]);
                p.getInsert(pCol).setDistribution(d[i]);
            }
        }
    }

    /**
     * Maps a 1-based profile column to the 0-based alignment column that is the
     * {@code pCol}-th match column of the mask.
     *
     * @return the alignment column, or {@code -1} if {@code pCol} is outside
     *         {@code 1..p.columns()} or the mask has fewer match columns
     */
    private static int getAlignmentColumnFromProfileColumn(ProfileHMM p, Alignment a, int pCol, boolean[] profileColumnMask) {
        if (pCol <= 0 || pCol > p.columns()) {
            return -1;
        }
        int remaining = pCol;
        for (int i = 0; i < a.length(); i++) {
            if (profileColumnMask[i] && --remaining == 0) {
                return i;
            }
        }
        return -1;
    }

    /** Returns whether the 0-based alignment column {@code col1} of {@code s} holds a non-gap symbol. */
    private static boolean isSymbol(SymbolList s, int col1) {
        // Symbol lists are addressed 1-based.
        return s.symbolAt(col1 + 1) != s.getAlphabet().getGapSymbol();
    }

    /** Returns whether any non-gap symbol lies strictly between the 0-based columns {@code col1} and {@code col2}. */
    private static boolean hasSymbolBetween(SymbolList s, int col1, int col2) {
        for (int i = col1 + 1; i < col2; i++) {
            if (ProfileHMMTools.isSymbol(s, i)) {
                return true;
            }
        }
        return false;
    }

    /** Counts the non-gap symbols lying strictly between the 0-based columns {@code col1} and {@code col2}. */
    private static int countSymbolsBetween(SymbolList s, int col1, int col2) {
        int cnt = 0;
        for (int i = col1 + 1; i < col2; i++) {
            if (ProfileHMMTools.isSymbol(s, i)) {
                ++cnt;
            }
        }
        return cnt;
    }

    /**
     * Estimates all transition weights of a profile from an alignment.
     *
     * <p>For each profile column {@code pCol} (0 to {@code p.columns()}) the
     * transitions taken by every sequence between match column {@code pCol} and
     * match column {@code pCol + 1} are counted, starting every count at
     * {@code pseudoCnt}, and normalized per source state type. The begin state
     * (before the first match column) and the end state (after the last one) are
     * treated like match states that every sequence visits. Symbols in the
     * alignment columns between the two match columns are insertions.
     *
     * <p>For the last profile column the insertion region runs to the end of the
     * alignment, so trailing insertions contribute to the M/D-to-I, I-to-I and
     * I-to-M counts of that column. No delete target exists for the last column,
     * so its weights are normalized over the match and insert targets only.
     *
     * <p>NOTE(review): with {@code pseudoCnt == 0} a source state that is never
     * observed has a zero count sum and its weights become NaN.
     *
     * @param p                 the model whose transition weights are set
     * @param a                 the training alignment
     * @param profileColumnMask per alignment column, whether it is a match column
     * @param pseudoCnt         pseudo count added to every transition count
     * @throws Exception if the mask yields the same alignment column for two
     *                   consecutive profile columns, or if the model rejects a weight
     */
    public static void initializeTransitionWeight(ProfileHMM p, Alignment a, boolean[] profileColumnMask, int pseudoCnt) throws Exception {
        int lastCol = p.columns();
        for (int pCol = 0; pCol <= lastCol; pCol++) {
            int aCol1 = ProfileHMMTools.getAlignmentColumnFromProfileColumn(p, a, pCol, profileColumnMask);
            int aCol2 = ProfileHMMTools.getAlignmentColumnFromProfileColumn(p, a, pCol + 1, profileColumnMask);
            if (aCol1 == aCol2) {
                throw new Exception("Invalid column numbers!");
            }
            // aCol1 != aCol2 holds here, so at most one of them is -1.
            boolean fromBegin = aCol1 == -1;
            boolean toEnd = aCol2 == -1;
            // The end state sits on a virtual column just past the alignment, so
            // that trailing insertions fall strictly between aCol1 and upperCol.
            int upperCol = toEnd ? a.length() : aCol2;

            int m2m = pseudoCnt;
            int m2i = pseudoCnt;
            int m2d = pseudoCnt;
            int i2m = pseudoCnt;
            int i2i = pseudoCnt;
            int i2d = pseudoCnt;
            int d2m = pseudoCnt;
            int d2i = pseudoCnt;
            int d2d = pseudoCnt;

            Iterator it = a.symbolListIterator();
            while (it.hasNext()) {
                SymbolList s = (SymbolList)it.next();
                // Begin and end behave like match states occupied by every sequence.
                boolean fromMatch = fromBegin || ProfileHMMTools.isSymbol(s, aCol1);
                boolean toMatch = toEnd || ProfileHMMTools.isSymbol(s, aCol2);
                int inserted = ProfileHMMTools.countSymbolsBetween(s, aCol1, upperCol);
                if (inserted == 0) {
                    // Direct transition between the two columns.
                    if (fromMatch) {
                        if (toMatch) {
                            ++m2m;
                        } else {
                            ++m2d;
                        }
                    } else if (toMatch) {
                        ++d2m;
                    } else {
                        ++d2d;
                    }
                } else {
                    // Enter the insert state, loop (inserted - 1) times, then leave it.
                    if (fromMatch) {
                        ++m2i;
                    } else {
                        ++d2i;
                    }
                    if (toMatch) {
                        ++i2m;
                    } else {
                        ++i2d;
                    }
                    i2i += inserted - 1;
                }
            }

            Distribution m_di = p.getWeights(p.getMatch(pCol));
            Distribution i_di = p.getWeights(p.getInsert(pCol));
            if (pCol != 0 && pCol >= lastCol) {
                // Last column: there is no following delete state to transition to.
                double matchSum = m2m + m2i;
                double insertSum = i2m + i2i;
                double deleteSum = d2m + d2i;
                Distribution d_di = p.getWeights(p.getDelete(pCol));
                m_di.setWeight(p.getMatch(pCol + 1), (double)m2m / matchSum);
                m_di.setWeight(p.getInsert(pCol), (double)m2i / matchSum);
                d_di.setWeight(p.getMatch(pCol + 1), (double)d2m / deleteSum);
                d_di.setWeight(p.getInsert(pCol), (double)d2i / deleteSum);
                i_di.setWeight(p.getMatch(pCol + 1), (double)i2m / insertSum);
                i_di.setWeight(p.getInsert(pCol), (double)i2i / insertSum);
            } else {
                double matchSum = m2m + m2d + m2i;
                double insertSum = i2m + i2d + i2i;
                m_di.setWeight(p.getMatch(pCol + 1), (double)m2m / matchSum);
                m_di.setWeight(p.getDelete(pCol + 1), (double)m2d / matchSum);
                m_di.setWeight(p.getInsert(pCol), (double)m2i / matchSum);
                i_di.setWeight(p.getMatch(pCol + 1), (double)i2m / insertSum);
                i_di.setWeight(p.getDelete(pCol + 1), (double)i2d / insertSum);
                i_di.setWeight(p.getInsert(pCol), (double)i2i / insertSum);
                if (pCol != 0) {
                    // Column 0 has no delete state to configure.
                    double deleteSum = d2m + d2d + d2i;
                    Distribution d_di = p.getWeights(p.getDelete(pCol));
                    d_di.setWeight(p.getMatch(pCol + 1), (double)d2m / deleteSum);
                    d_di.setWeight(p.getDelete(pCol + 1), (double)d2d / deleteSum);
                    d_di.setWeight(p.getInsert(pCol), (double)d2i / deleteSum);
                }
            }
        }
    }
}

