/*
 * Decompiled with CFR 0.152.
 */
package de.biozentrum.bioinformatik.cama.hmm;

import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dist.DistributionTools;
import org.biojava.bio.dp.DP;
import org.biojava.bio.dp.DPFactory;
import org.biojava.bio.dp.ProfileHMM;
import org.biojava.bio.dp.ScoreType;
import org.biojava.bio.dp.StatePath;
import org.biojava.bio.seq.DNATools;
import org.biojava.bio.seq.Sequence;
import org.biojava.bio.seq.SequenceIterator;
import org.biojava.bio.seq.db.SequenceDB;
import org.biojava.bio.seq.io.SymbolTokenization;
import org.biojava.bio.symbol.Alignment;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.SymbolList;

/**
 * Builds a {@link ProfileHMM} from a multiple alignment and a database of the
 * aligned sequences, and aligns sequences against that profile.
 *
 * <p>Alignment columns are classified as match columns when the weight that the
 * column distribution assigns to {@code DNATools.n()} reaches a threshold.
 * Emission distributions are copied from the match columns; transition weights
 * are estimated by counting, per sequence, which residues/gaps occur at and
 * between consecutive match columns.
 *
 * <p>NOTE(review): the match-column test is hard-wired to {@code DNATools.n()},
 * so the class presumably only works for DNA alignments - confirm with callers.
 */
public class SeqEngine {
    private final Alignment baseAlignment;
    private final Alphabet baseAlphabet;
    private final SequenceDB baseDB;
    private final ProfileHMM baseProfile;
    private final DP baseDP;
    private final int matchColumns;
    private final double matchThreshold;

    /**
     * Creates the profile for {@code alignment}, trains it from {@code db} and
     * prepares the dynamic-programming object used by the align methods.
     *
     * @param alignment      alignment the profile is derived from
     * @param db             sequences used to count transitions; each sequence is
     *                       indexed with alignment column numbers, so it is assumed
     *                       to be gapped to alignment length - TODO confirm
     * @param matchThreshold minimum {@code DNATools.n()} weight for a match column
     * @throws Exception if any BioJava call fails or the match columns are inconsistent
     */
    public SeqEngine(Alignment alignment, SequenceDB db, double matchThreshold) throws Exception {
        this.matchThreshold = matchThreshold;
        this.baseAlignment = alignment;
        this.baseDB = db;
        this.baseAlphabet = this.baseAlignment.getAlphabet();
        this.matchColumns = this.getMatchColumns(this.baseAlignment, matchThreshold);
        this.baseProfile = this.createProfile(this.baseAlphabet, this.matchColumns);
        this.configureProfile(this.baseProfile, this.baseDB, this.baseAlignment);
        this.baseDP = DPFactory.DEFAULT.createDP(this.baseProfile);
    }

    /**
     * Runs {@link #alignSequence(SymbolList)} for every sequence of the database
     * handed to the constructor.
     *
     * @throws Exception if iterating the database or aligning a sequence fails
     */
    public void alignBaseSequences() throws Exception {
        SequenceIterator it = this.baseDB.sequenceIterator();
        while (it.hasNext()) {
            this.alignSequence(it.nextSequence());
        }
    }

    /**
     * Aligns {@code s} against the profile and prints one summary line to
     * standard output: the emitted symbols, the first character of each state
     * name along the Viterbi path, and the Viterbi, forward and backward scores
     * (all computed with {@code ScoreType.ODDS}).
     *
     * @param s sequence to align
     * @throws Exception if the DP computation or symbol tokenization fails
     */
    public void alignSequence(SymbolList s) throws Exception {
        SymbolList[] sl = new SymbolList[]{s};
        StatePath sp = this.baseDP.viterbi(sl, ScoreType.ODDS);
        double fwScore = this.baseDP.forward(sl, ScoreType.ODDS);
        double bwScore = this.baseDP.backward(sl, ScoreType.ODDS);
        SymbolTokenization token = s.getAlphabet().getTokenization("token");
        StringBuilder seqString = new StringBuilder();
        StringBuilder stateString = new StringBuilder();
        for (int i = 1; i <= sp.length(); ++i) {
            seqString.append(token.tokenizeSymbol(sp.symbolAt(StatePath.SEQUENCE, i)));
            stateString.append(sp.symbolAt(StatePath.STATES, i).getName().charAt(0));
        }
        // The alias is the literal text "name"; a SymbolList carries no name to print.
        System.out.println("Alignment-Sequence [" + seqString + "] [" + stateString + "] alias [" + "name" + "]: viterbi[" + sp.getScore() + "] fw[" + fwScore + "] bw[" + bwScore + "]");
    }

    /**
     * Aligns the fixed DNA string {@code "ag-ag-a"} against the profile using
     * {@code ScoreType.PROBABILITY}.
     *
     * <p>NOTE(review): the symbol and state strings are built but never printed
     * or returned, so the only observable effect is an exception on failure.
     *
     * @throws Exception if sequence creation, the Viterbi run or tokenization fails
     */
    public void alignTestSequence() throws Exception {
        String test = "ag-ag-a";
        Sequence testSeq = DNATools.createDNASequence(test, "testSeq");
        SymbolList[] sl = new SymbolList[]{testSeq};
        StatePath sp = this.baseDP.viterbi(sl, ScoreType.PROBABILITY);
        SymbolTokenization token = DNATools.getDNA().getTokenization("token");
        StringBuilder seqString = new StringBuilder();
        StringBuilder stateString = new StringBuilder();
        for (int i = 1; i <= sp.length(); ++i) {
            seqString.append(token.tokenizeSymbol(sp.symbolAt(StatePath.SEQUENCE, i)));
            stateString.append(sp.symbolAt(StatePath.STATES, i).getName().charAt(0));
        }
    }

    /**
     * Trains {@code p}: transition weights first (with a pseudo-count of 1),
     * then emission distributions.
     *
     * @param p  profile to configure
     * @param db sequences used for the transition counts
     * @param al alignment the profile columns are derived from
     * @throws Exception if either training step fails
     */
    private void configureProfile(ProfileHMM p, SequenceDB db, Alignment al) throws Exception {
        this.calcTransitionWeight(p, al, db, 1);
        this.setEmissionWeight(p, al);
    }

    /**
     * Creates an untrained profile with {@code cols} columns over {@code alpha},
     * using the default distribution factory for matches and inserts.
     */
    private ProfileHMM createProfile(Alphabet alpha, int cols) throws Exception {
        return new ProfileHMM(alpha, cols, DistributionFactory.DEFAULT, DistributionFactory.DEFAULT, null);
    }

    /**
     * Counts the alignment columns that qualify as match columns.
     *
     * @param alignment alignment to inspect
     * @param threshold minimum {@code DNATools.n()} weight for a match column
     * @return number of match columns
     */
    private int getMatchColumns(Alignment alignment, double threshold) throws Exception {
        Distribution[] dis = DistributionTools.distOverAlignment(alignment, true, 1.0);
        int matchCols = 0;
        for (int i = 0; i < dis.length; ++i) {
            if (isMatchColumn(dis, i, threshold)) {
                ++matchCols;
            }
        }
        return matchCols;
    }

    /**
     * Tells whether the 0-based alignment column {@code alignmentCol} is a match
     * column. Recomputes the column distributions of the whole alignment, so
     * loops should use {@link #isMatchColumn(Distribution[], int, double)}.
     */
    private boolean isMatchColumn(Alignment alignment, int alignmentCol, double threshold) throws Exception {
        Distribution[] dis = DistributionTools.distOverAlignment(alignment, true, 1.0);
        return isMatchColumn(dis, alignmentCol, threshold);
    }

    /**
     * Match-column test on already computed column distributions.
     *
     * @param dis       per-column distributions of the alignment
     * @param col       0-based index into {@code dis}
     * @param threshold minimum {@code DNATools.n()} weight for a match column
     */
    private static boolean isMatchColumn(Distribution[] dis, int col, double threshold) throws Exception {
        return dis[col].getWeight(DNATools.n()) >= threshold;
    }

    /**
     * Copies the column distribution of every match column into the match state
     * and the insert state of the corresponding profile column (profile columns
     * are numbered from 1 in alignment order).
     */
    private void setEmissionWeight(ProfileHMM p, Alignment alignment) throws Exception {
        // Computed once and reused for both the match test and the emissions.
        Distribution[] d = DistributionTools.distOverAlignment(alignment, true, 1.0);
        int pCol = 0;
        for (int i = 0; i < d.length; ++i) {
            if (isMatchColumn(d, i, this.matchThreshold)) {
                ++pCol;
                p.getMatch(pCol).setDistribution(d[i]);
                p.getInsert(pCol).setDistribution(d[i]);
            }
        }
    }

    /**
     * Maps a 1-based profile column to the 0-based alignment column of the
     * {@code pCol}-th match column.
     *
     * @return the alignment column, or -1 when {@code pCol} is outside
     *         {@code 1..p.columns()} or the alignment has fewer match columns
     */
    private int getAColFromPCol(ProfileHMM p, Alignment a, int pCol, double threshold) throws Exception {
        if (pCol <= 0 || pCol > p.columns()) {
            return -1;
        }
        // One distribution pass instead of one per inspected column.
        Distribution[] dis = DistributionTools.distOverAlignment(a, true, 1.0);
        int remaining = pCol;
        for (int i = 0; i < dis.length; ++i) {
            if (isMatchColumn(dis, i, threshold) && --remaining == 0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Estimates the transition weights of {@code p} by counting.
     *
     * <p>For every profile column {@code pCol} in {@code 0..p.columns()} the
     * alignment columns of match column {@code pCol} ({@code aCol1}) and
     * {@code pCol + 1} ({@code aCol2}) are looked up; -1 stands for "before the
     * first" or "after the last" match column. Each database sequence then adds
     * to the nine match/insert/delete transition counters depending on whether
     * it has a residue at {@code aCol1}, at {@code aCol2} and in between. The
     * counters start at {@code pseudoCnt} and are normalised per source state.
     *
     * @param p         profile whose transition distributions are written
     * @param a         alignment used to locate the match columns
     * @param db        sequences to count over
     * @param pseudoCnt initial value of every counter
     * @throws Exception if two consecutive profile columns map to the same
     *                   alignment column, or a BioJava call fails
     */
    private void calcTransitionWeight(ProfileHMM p, Alignment a, SequenceDB db, int pseudoCnt) throws Exception {
        for (int pCol = 0; pCol <= p.columns(); ++pCol) {
            int aCol1 = this.getAColFromPCol(p, a, pCol, this.matchThreshold);
            int aCol2 = this.getAColFromPCol(p, a, pCol + 1, this.matchThreshold);
            if (aCol1 == aCol2) {
                throw new Exception("Invalid column numbers!");
            }
            int M2M = pseudoCnt;
            int M2I = pseudoCnt;
            int M2D = pseudoCnt;
            int I2M = pseudoCnt;
            int I2I = pseudoCnt;
            int I2D = pseudoCnt;
            int D2M = pseudoCnt;
            int D2I = pseudoCnt;
            int D2D = pseudoCnt;
            SequenceIterator it = db.sequenceIterator();
            while (it.hasNext()) {
                Sequence s = it.nextSequence();
                if (aCol1 == -1) {
                    // Leading edge (pCol == 0): there is no source column, so only
                    // aCol2 and the columns in front of it are inspected.
                    if (this.isSymbol(s, aCol2) && !this.SymbolBetween(s, aCol1, aCol2)) {
                        ++M2M;
                    }
                    if (!this.isSymbol(s, aCol2) && !this.SymbolBetween(s, aCol1, aCol2)) {
                        ++M2D;
                    }
                    if (this.SymbolBetween(s, aCol1, aCol2)) {
                        ++M2I;
                    }
                    if (this.SymbolBetween(s, aCol1, aCol2) && this.isSymbol(s, aCol2)) {
                        ++I2M;
                    }
                    if (this.SymbolBetween(s, aCol1, aCol2) && !this.isSymbol(s, aCol2)) {
                        ++I2D;
                    }
                    if (this.countSymbolsBetween(s, aCol1, aCol2) > 1) {
                        I2I += this.countSymbolsBetween(s, aCol1, aCol2) - 1;
                    }
                    // No delete state is written for column 0 below, so its
                    // counters are cleared.
                    D2M = 0;
                    D2D = 0;
                    D2I = 0;
                    continue;
                }
                if (aCol2 == -1) {
                    // Trailing edge (pCol == p.columns()).
                    // NOTE(review): with aCol2 == -1 the "between" range is empty,
                    // so SymbolBetween is always false and countSymbolsBetween is
                    // always 0 here; residues after the last match column are
                    // therefore never counted. Looks unintended - confirm.
                    if (this.isSymbol(s, aCol1) && !this.SymbolBetween(s, aCol1, aCol2)) {
                        ++M2M;
                    }
                    if (!this.isSymbol(s, aCol1) && !this.SymbolBetween(s, aCol1, aCol2)) {
                        ++D2M;
                    }
                    if (this.countSymbolsBetween(s, aCol1, aCol2) <= 1) continue;
                    I2M += this.countSymbolsBetween(s, aCol1, aCol2) - 1;
                    continue;
                }
                // Interior column pair.
                if (this.isSymbol(s, aCol1) && this.isSymbol(s, aCol2) && !this.SymbolBetween(s, aCol1, aCol2)) {
                    ++M2M;
                }
                if (this.isSymbol(s, aCol1) && !this.isSymbol(s, aCol2) && !this.SymbolBetween(s, aCol1, aCol2)) {
                    ++M2D;
                }
                if (this.isSymbol(s, aCol1) && this.SymbolBetween(s, aCol1, aCol2)) {
                    ++M2I;
                }
                if (!this.isSymbol(s, aCol1) && this.isSymbol(s, aCol2) && !this.SymbolBetween(s, aCol1, aCol2)) {
                    ++D2M;
                }
                if (!(this.isSymbol(s, aCol1) || this.isSymbol(s, aCol2) || this.SymbolBetween(s, aCol1, aCol2))) {
                    ++D2D;
                }
                if (!this.isSymbol(s, aCol1) && this.SymbolBetween(s, aCol1, aCol2)) {
                    ++D2I;
                }
                if (this.SymbolBetween(s, aCol1, aCol2) && this.isSymbol(s, aCol2)) {
                    ++I2M;
                }
                if (this.SymbolBetween(s, aCol1, aCol2) && !this.isSymbol(s, aCol2)) {
                    ++I2D;
                }
                if (this.countSymbolsBetween(s, aCol1, aCol2) <= 1) continue;
                // Every in-between residue after the first is an insert-to-insert step.
                I2I += this.countSymbolsBetween(s, aCol1, aCol2) - 1;
            }
            int matchSum = M2M + M2D + M2I;
            int insertSum = I2M + I2D + I2I;
            int deleteSum = D2M + D2D + D2I;
            double M2M_Prob = (double)M2M / (double)matchSum;
            double M2D_Prob = (double)M2D / (double)matchSum;
            double M2I_Prob = (double)M2I / (double)matchSum;
            double D2M_Prob = (double)D2M / (double)deleteSum;
            double D2D_Prob = (double)D2D / (double)deleteSum;
            double D2I_Prob = (double)D2I / (double)deleteSum;
            double I2M_Prob = (double)I2M / (double)insertSum;
            double I2D_Prob = (double)I2D / (double)insertSum;
            double I2I_Prob = (double)I2I / (double)insertSum;
            if (pCol == 0) {
                // Column 0: only the match and insert states get outgoing weights.
                Distribution m_di = p.getWeights(p.getMatch(pCol));
                Distribution i_di = p.getWeights(p.getInsert(pCol));
                m_di.setWeight(p.getMatch(pCol + 1), M2M_Prob);
                m_di.setWeight(p.getDelete(pCol + 1), M2D_Prob);
                m_di.setWeight(p.getInsert(pCol), M2I_Prob);
                i_di.setWeight(p.getMatch(pCol + 1), I2M_Prob);
                i_di.setWeight(p.getDelete(pCol + 1), I2D_Prob);
                i_di.setWeight(p.getInsert(pCol), I2I_Prob);
            } else if (pCol >= p.columns()) {
                // Last column: no transition into a delete state is written, so
                // the sums are rebuilt without the *2D counters before normalising.
                matchSum = M2M + M2I;
                insertSum = I2M + I2I;
                deleteSum = D2M + D2I;
                M2M_Prob = (double)M2M / (double)matchSum;
                M2I_Prob = (double)M2I / (double)matchSum;
                D2M_Prob = (double)D2M / (double)deleteSum;
                D2I_Prob = (double)D2I / (double)deleteSum;
                I2M_Prob = (double)I2M / (double)insertSum;
                I2I_Prob = (double)I2I / (double)insertSum;
                Distribution m_di = p.getWeights(p.getMatch(pCol));
                Distribution i_di = p.getWeights(p.getInsert(pCol));
                Distribution d_di = p.getWeights(p.getDelete(pCol));
                m_di.setWeight(p.getMatch(pCol + 1), M2M_Prob);
                m_di.setWeight(p.getInsert(pCol), M2I_Prob);
                d_di.setWeight(p.getMatch(pCol + 1), D2M_Prob);
                d_di.setWeight(p.getInsert(pCol), D2I_Prob);
                i_di.setWeight(p.getMatch(pCol + 1), I2M_Prob);
                i_di.setWeight(p.getInsert(pCol), I2I_Prob);
            } else {
                Distribution m_di = p.getWeights(p.getMatch(pCol));
                Distribution i_di = p.getWeights(p.getInsert(pCol));
                Distribution d_di = p.getWeights(p.getDelete(pCol));
                m_di.setWeight(p.getMatch(pCol + 1), M2M_Prob);
                m_di.setWeight(p.getDelete(pCol + 1), M2D_Prob);
                m_di.setWeight(p.getInsert(pCol), M2I_Prob);
                i_di.setWeight(p.getMatch(pCol + 1), I2M_Prob);
                i_di.setWeight(p.getDelete(pCol + 1), I2D_Prob);
                i_di.setWeight(p.getInsert(pCol), I2I_Prob);
                d_di.setWeight(p.getMatch(pCol + 1), D2M_Prob);
                d_di.setWeight(p.getDelete(pCol + 1), D2D_Prob);
                d_di.setWeight(p.getInsert(pCol), D2I_Prob);
            }
        }
    }

    /**
     * Counts the non-gap positions of {@code s} strictly between the 0-based
     * columns {@code col1} and {@code col2}.
     *
     * @return the count; 0 when the open range is empty (including {@code col2 < col1})
     */
    private int countSymbolsBetween(Sequence s, int col1, int col2) {
        int cnt = 0;
        for (int i = col1 + 1; i < col2; ++i) {
            if (this.isSymbol(s, i)) {
                ++cnt;
            }
        }
        return cnt;
    }

    /**
     * Tells whether {@code s} has at least one non-gap position strictly between
     * the 0-based columns {@code col1} and {@code col2}.
     *
     * <p>Every column of the open range is checked; looking only at column
     * {@code col1 + 1} would miss residues that start later in the range.
     *
     * @return {@code false} when the open range is empty (including {@code col2 < col1})
     */
    private boolean SymbolBetween(Sequence s, int col1, int col2) {
        for (int i = col1 + 1; i < col2; ++i) {
            if (this.isSymbol(s, i)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether {@code s} holds something other than the gap symbol of its
     * alphabet at the 0-based column {@code col1} (symbol positions are 1-based,
     * hence the {@code + 1}). The result is also written to standard output.
     *
     * <p>NOTE(review): the print runs on every call and looks like leftover
     * debugging output - confirm before removing.
     */
    private boolean isSymbol(Sequence s, int col1) {
        boolean present = s.symbolAt(col1 + 1) != s.getAlphabet().getGapSymbol();
        System.out.println("isSymbol:" + present);
        return present;
    }

    /**
     * Returns the trained profile built by the constructor.
     *
     * @return the profile HMM
     */
    public ProfileHMM getProfile() throws Exception {
        return this.baseProfile;
    }
}

