package de.dfki.lt.mary.unitselection.viterbi;

import de.dfki.lt.mary.modules.synthesis.SynthesisException;
import de.dfki.lt.mary.unitselection.DiphoneTarget;
import de.dfki.lt.mary.unitselection.DiphoneUnit;
import de.dfki.lt.mary.unitselection.HalfPhoneTarget;
import de.dfki.lt.mary.unitselection.JoinCostFunction;
import de.dfki.lt.mary.unitselection.SelectedUnit;
import de.dfki.lt.mary.unitselection.Target;
import de.dfki.lt.mary.unitselection.TargetCostFunction;
import de.dfki.lt.mary.unitselection.Unit;
import de.dfki.lt.mary.unitselection.UnitDatabase;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/* loaded from: input_file:de/dfki/lt/mary/unitselection/viterbi/Viterbi.class */
/**
 * Viterbi search over a chain of {@link ViterbiPoint}s, one per target, followed by a
 * terminal point whose target is {@code null}.
 *
 * <p>A path's score is the sum over its steps of {@code wTargetCosts * targetCost +
 * wJoinCosts * joinCost}. When two paths reach the same candidate, the lower-scoring one
 * is kept (see {@link #addPath}).
 */
public class Viterbi {
    /** Weight applied to target costs. */
    protected final float wTargetCosts;
    /** Weight applied to join costs; always {@code 1 - wTargetCosts}. */
    protected final float wJoinCosts;
    /** First point of the lattice; equals {@link #lastPoint} if there were no targets. */
    protected ViterbiPoint firstPoint;
    /** Terminal point (null target); collects the complete paths. */
    protected ViterbiPoint lastPoint;
    /** Free-form feature map exposed through {@link #setFeature}/{@link #getFeature}. */
    protected LinkedHashMap f;
    private UnitDatabase database;
    protected TargetCostFunction targetCostFunction;
    protected JoinCostFunction joinCostFunction;
    /**
     * Running debug averages per voice database.
     * NOTE(review): shared by all instances without synchronization; confirm whether
     * concurrent synthesis can reach the debug code path.
     */
    private static final Map<UnitDatabase, DebugStats> debugStats =
            new HashMap<UnitDatabase, DebugStats>();
    /**
     * Maximum number of paths expanded from each point.
     * -1 means unlimited; 0 (general beam search) is not implemented.
     */
    protected int searchStrategy = 25;
    protected Logger logger = Logger.getLogger("Viterbi");
    protected double cumulJoinCosts = 0.0d;
    protected int nJoinCosts = 0;
    protected double cumulTargetCosts = 0.0d;
    protected int nTargetCosts = 0;

    /**
     * Running averages accumulated over all utterances synthesised with one database.
     * Static so that entries in the static {@code debugStats} map do not pin the
     * Viterbi instance that created them.
     */
    private static final class DebugStats {
        int n;
        double avgLength;
        double avgCostBestPath;
        double avgTargetCost;
        double avgJoinCost;
    }

    /**
     * Builds the search lattice for the given targets.
     *
     * @param list the targets to synthesise, in order; may be empty
     * @param unitDatabase the database providing candidates and cost functions
     * @param f the target-cost weight; join costs are weighted by {@code 1 - f}
     */
    public Viterbi(List<Target> list, UnitDatabase unitDatabase, float f) {
        this.database = unitDatabase;
        this.targetCostFunction = unitDatabase.getTargetCostFunction();
        this.joinCostFunction = unitDatabase.getJoinCostFunction();
        this.wTargetCosts = f;
        this.wJoinCosts = 1.0f - f;
        this.f = new LinkedHashMap();
        ViterbiPoint previous = null;
        for (Target target : list) {
            ViterbiPoint point = new ViterbiPoint(target);
            if (previous == null) {
                this.firstPoint = point;
                // Seed the search with an empty path (no candidate, no cost).
                this.firstPoint.getPaths().add(new ViterbiPath());
            } else {
                previous.setNext(point);
            }
            previous = point;
        }
        this.lastPoint = new ViterbiPoint(null);
        if (previous == null) {
            // No targets: the lattice is just the terminal point, so apply() returns
            // immediately and getSelectedUnits() yields an empty list.
            this.firstPoint = this.lastPoint;
        } else {
            previous.setNext(this.lastPoint);
        }
        if (this.searchStrategy == 0) {
            throw new IllegalStateException("General beam search not implemented");
        }
    }

    /** Stores an arbitrary feature value under {@code str}. */
    public void setFeature(String str, Object obj) {
        this.f.put(str, obj);
    }

    /** Returns the feature value stored under {@code str}, or {@code null}. */
    public Object getFeature(String str) {
        return this.f.get(str);
    }

    /**
     * Runs the search: for each point, fetches candidates and extends up to
     * {@code searchStrategy} of the point's paths by every candidate into the next point.
     *
     * <p>If a diphone target has no candidates, the point is split in place into its left
     * and right halfphone targets. The right halfphone gets its own point, which is
     * processed on the next iteration.
     *
     * @throws SynthesisException if a (non-diphone or halfphone) target has no candidates
     */
    public void apply() throws SynthesisException {
        // The update reads getNext() after the body, so a point inserted by the
        // halfphone fallback is visited next.
        for (ViterbiPoint point = this.firstPoint; point.getNext() != null; point = point.getNext()) {
            Target target = point.getTarget();
            ViterbiCandidate[] candidates = this.database.getCandidates(target);
            if (candidates.length == 0) {
                if (!(target instanceof DiphoneTarget)) {
                    throw new SynthesisException("Cannot find any units for target " + target);
                }
                this.logger.debug("No diphone '" + target.getName() + "' -- will build from halfphones");
                DiphoneTarget diphoneTarget = (DiphoneTarget) target;
                HalfPhoneTarget left = diphoneTarget.getLeft();
                HalfPhoneTarget right = diphoneTarget.getRight();
                point.setTarget(left);
                ViterbiPoint rightPoint = new ViterbiPoint(right);
                rightPoint.setNext(point.getNext());
                point.setNext(rightPoint);
                candidates = this.database.getCandidates(left);
                if (candidates.length == 0) {
                    throw new SynthesisException("Cannot even find any halfphone unit for target " + left);
                }
            }
            assert candidates.length > 0;
            point.setCandidates(candidates);
            assert this.searchStrategy != 0;

            SortedSet paths = point.getPaths();
            int nPathsToExpand = paths.size();
            if (this.searchStrategy != -1 && this.searchStrategy < nPathsToExpand) {
                nPathsToExpand = this.searchStrategy;
            }
            ViterbiCandidate[] pointCandidates = point.getCandidates();
            assert pointCandidates != null;
            ViterbiPoint next = point.getNext();
            Iterator it = paths.iterator();
            for (int i = 0; i < nPathsToExpand; i++) {
                ViterbiPath path = (ViterbiPath) it.next();
                assert path != null;
                for (ViterbiCandidate candidate : pointCandidates) {
                    addPath(next, getPath(path, candidate));
                }
            }
        }
    }

    /**
     * Adds {@code viterbiPath} to the paths of {@code viterbiPoint}. A path is added only
     * if its candidate has no best path yet, or if it scores lower than that best path,
     * which it then replaces.
     */
    void addPath(ViterbiPoint viterbiPoint, ViterbiPath viterbiPath) {
        ViterbiCandidate candidate = viterbiPath.getCandidate();
        assert candidate != null;
        ViterbiPath bestPath = candidate.getBestPath();
        SortedSet paths = viterbiPoint.getPaths();
        if (bestPath == null) {
            paths.add(viterbiPath);
            candidate.setBestPath(viterbiPath);
        } else if (viterbiPath.getScore() < bestPath.getScore()) {
            paths.remove(bestPath);
            paths.add(viterbiPath);
            candidate.setBestPath(viterbiPath);
        }
    }

    /**
     * Returns the units on the best path, in utterance order. Diphone units are split into
     * their left and right halfphone units.
     *
     * @return the selected units; an empty list if there were no targets; {@code null} if
     *     the terminal point holds no path (e.g. {@link #apply()} was not run)
     */
    public List<SelectedUnit> getSelectedUnits() {
        LinkedList<SelectedUnit> selected = new LinkedList<SelectedUnit>();
        if (this.firstPoint == null || this.firstPoint.getNext() == null) {
            return selected;
        }
        ViterbiPath bestPath = findBestPath();
        if (bestPath == null) {
            return null;
        }
        // Walk backwards from the end, prepending, so the list ends up in forward order.
        for (ViterbiPath path = bestPath; path != null; path = path.getPrevious()) {
            ViterbiCandidate candidate = path.getCandidate();
            if (candidate == null) {
                continue; // the seed path carries no unit
            }
            Unit unit = candidate.getUnit();
            Target target = candidate.getTarget();
            if (unit instanceof DiphoneUnit) {
                assert target instanceof DiphoneTarget;
                DiphoneUnit diphoneUnit = (DiphoneUnit) unit;
                DiphoneTarget diphoneTarget = (DiphoneTarget) target;
                selected.addFirst(new SelectedUnit(diphoneUnit.getRight(), diphoneTarget.getRight()));
                selected.addFirst(new SelectedUnit(diphoneUnit.getLeft(), diphoneTarget.getLeft()));
            } else {
                selected.addFirst(new SelectedUnit(unit, target));
            }
        }
        if (this.logger.isDebugEnabled()) {
            logSelectionStatistics(selected, bestPath);
        }
        return selected;
    }

    /**
     * Logs the selected units grouped into runs of consecutive database indices, plus
     * average run length and costs, and updates the per-database running averages.
     */
    private void logSelectionStatistics(List<SelectedUnit> selected, ViterbiPath bestPath) {
        int nUnits = selected.size();
        if (nUnits == 0) {
            return; // nothing to report; also avoids dividing by a zero run count
        }
        // Random-access copy: the caller's LinkedList has O(n) get(i).
        List<SelectedUnit> units = new ArrayList<SelectedUnit>(selected);
        StringWriter stringWriter = new StringWriter();
        PrintWriter out = new PrintWriter(stringWriter);
        int[] runLengthCounts = new int[10]; // runLengthCounts[k] = number of runs of length k
        int runStart = 0;
        StringBuilder phones = new StringBuilder();
        for (int k = 0; k < nUnits; k++) {
            SelectedUnit current = units.get(k);
            if (k > 0 && current.getUnit().getIndex() != units.get(k - 1).getUnit().getIndex() + 1) {
                // Run of consecutive database units ended before k.
                runLengthCounts = countRun(runLengthCounts, k - runStart);
                printRun(out, phones.toString(), units.get(runStart));
                runStart = k;
                phones = new StringBuilder();
            }
            phones.append(this.database.getTargetCostFunction().getFeature(current.getUnit(), "mary_phoneme")
                    + "(" + current.getUnit().getIndex() + ")");
        }
        runLengthCounts = countRun(runLengthCounts, nUnits - runStart);
        printRun(out, phones.toString(), units.get(runStart));
        this.logger.debug("Selected units:\n" + stringWriter.toString());

        int unitsInRuns = 0;
        int nRuns = 0;
        for (int length = 1; length < runLengthCounts.length; length++) {
            unitsInRuns += runLengthCounts[length] * length;
            nRuns += runLengthCounts[length];
        }
        float avgRunLength = (float) unitsInRuns / nRuns; // float division, not integer
        DecimalFormat decimalFormat = new DecimalFormat("0.000");
        this.logger.debug("Avg. consecutive length: " + decimalFormat.format(avgRunLength) + " units");
        double score = bestPath.getScore() / (nUnits - 1);
        double avgTarget = this.cumulTargetCosts / this.nTargetCosts;
        double avgJoin = this.cumulJoinCosts / this.nJoinCosts;
        this.logger.debug("Avg. cost: best path " + decimalFormat.format(score) + ", avg. target " + decimalFormat.format(avgTarget) + ", join " + decimalFormat.format(avgJoin) + " (n=" + this.nTargetCosts + ")");

        DebugStats stats = debugStats.get(this.database);
        if (stats == null) {
            stats = new DebugStats();
            debugStats.put(this.database, stats);
        }
        // Incremental (running) means.
        stats.n++;
        stats.avgLength += (avgRunLength - stats.avgLength) / stats.n;
        stats.avgCostBestPath += (score - stats.avgCostBestPath) / stats.n;
        stats.avgTargetCost += (avgTarget - stats.avgTargetCost) / stats.n;
        stats.avgJoinCost += (avgJoin - stats.avgJoinCost) / stats.n;
        this.logger.debug("Total average of " + stats.n + " utterances for this voice:");
        this.logger.debug("Avg. length: " + decimalFormat.format(stats.avgLength) + ", avg. cost best path: " + decimalFormat.format(stats.avgCostBestPath) + ", avg. target cost: " + decimalFormat.format(stats.avgTargetCost) + ", avg. join cost: " + decimalFormat.format(stats.avgJoinCost));
    }

    /**
     * Increments the count for {@code runLength}, growing the array if needed.
     * Returns the (possibly new) array.
     */
    private static int[] countRun(int[] counts, int runLength) {
        int[] result = counts;
        if (result.length <= runLength) {
            result = new int[runLength + 1];
            System.arraycopy(counts, 0, result, 0, counts.length);
        }
        result[runLength]++;
        return result;
    }

    /**
     * Prints one run's phone labels, padded to column 80, followed by the file name and
     * time of the run's first unit.
     */
    private void printRun(PrintWriter out, String phones, SelectedUnit firstOfRun) {
        String filenameAndTime = this.database.getFilenameAndTime(firstOfRun.getUnit());
        out.print(phones);
        for (int column = phones.length(); column < 80; column++) {
            out.print(" ");
        }
        out.print(filenameAndTime);
        out.println();
    }

    /**
     * Creates the path extending {@code viterbiPath} by {@code viterbiCandidate}.
     * The step cost is {@code wTargetCosts * targetCost + wJoinCosts * joinCost}. The join
     * cost is 0 when the previous path has no candidate. The step cost is added to the
     * previous path's score.
     */
    private ViterbiPath getPath(ViterbiPath viterbiPath, ViterbiCandidate viterbiCandidate) {
        ViterbiPath path = new ViterbiPath();
        Unit unit = viterbiCandidate.getUnit();
        path.setCandidate(viterbiCandidate);
        path.setPrevious(viterbiPath);
        double targetCost = viterbiCandidate.getTargetCost(this.targetCostFunction);
        double joinCost;
        if (viterbiPath == null || viterbiPath.getCandidate() == null) {
            joinCost = 0.0d;
        } else {
            joinCost = this.joinCostFunction.cost(viterbiPath.getCandidate().getUnit(), unit);
        }
        double weightedTargetCost = targetCost * this.wTargetCosts;
        double weightedJoinCost = joinCost * this.wJoinCosts;
        double stepCost = weightedJoinCost + weightedTargetCost;
        // Debug statistics: infinite join costs are left out of the sum but still counted.
        if (weightedJoinCost < Double.POSITIVE_INFINITY) {
            this.cumulJoinCosts += weightedJoinCost;
        }
        this.nJoinCosts++;
        this.cumulTargetCosts += weightedTargetCost;
        this.nTargetCosts++;
        if (viterbiPath == null) {
            path.setScore(stepCost);
        } else {
            path.setScore(stepCost + viterbiPath.getScore());
        }
        return path;
    }

    /**
     * Returns the first path of the terminal point's path set and links each path on it
     * to its successor via {@code setNext}.
     *
     * @return the best path, or {@code null} if the terminal point has no paths
     */
    private ViterbiPath findBestPath() {
        assert this.searchStrategy != 0;
        SortedSet paths = this.lastPoint.getPaths();
        if (paths.isEmpty()) {
            return null; // SortedSet.first() would throw NoSuchElementException
        }
        ViterbiPath best = (ViterbiPath) paths.first();
        ViterbiPath current = best;
        while (current != null) {
            ViterbiPath previous = current.getPrevious();
            if (previous != null) {
                previous.setNext(current);
            }
            current = previous;
        }
        return best;
    }
}
