package marytts.tools.voiceimport;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import marytts.cart.DecisionNode;
import marytts.cart.DirectedGraph;
import marytts.cart.DirectedGraphNode;
import marytts.cart.LeafNode;
import marytts.cart.Node;
import marytts.cart.io.DirectedGraphWriter;
import marytts.features.FeatureVector;
import marytts.tools.voiceimport.traintrees.AgglomerativeClusterer;
import marytts.tools.voiceimport.traintrees.DurationDistanceMeasure;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.UnitFileReader;
import marytts.util.math.MathUtils;

/**
 * Voice-import component that trains a duration graph for phone units.
 *
 * <p>The phone feature vectors are clustered agglomeratively, using a duration-based distance
 * measure. The graph produced after each clustering level is written to
 * {@code <durTree>.levelN}. Each leaf of the final graph is then replaced by a float leaf
 * holding {@code {stddev, mean}} of the durations (in seconds) of the units it contains. That
 * graph is written to the {@code durTree} file.
 */
public class DurationTreeTrainer extends VoiceImportComponent {
    protected DatabaseLayout db = null;
    private final String name = "DurationTreeTrainer";
    public final String DURTREE = "DurationTreeTrainer.durTree";
    public final String FEATUREFILE = "DurationTreeTrainer.featureFile";
    public final String UNITFILE = "DurationTreeTrainer.unitFile";
    public final String MAXDATA = "DurationTreeTrainer.maxData";
    public final String PROPORTIONTESTDATA = "DurationTreeTrainer.propTestData";

    /** Returns the component name, {@code "DurationTreeTrainer"}. */
    @Override
    public String getName() {
        return name;
    }

    /** No component-specific initialisation is needed. */
    @Override
    public void initialiseComp() {
    }

    /**
     * Returns the default properties, creating them on first call from the database layout's
     * {@code db.fileDir} and {@code db.maryExtension} settings.
     *
     * @param databaseLayout the layout to read directory/extension settings from; it is also
     *                       remembered in {@link #db}
     * @return the (possibly previously created) property map
     */
    @Override
    public SortedMap<String, String> getDefaultProps(DatabaseLayout databaseLayout) {
        this.db = databaseLayout;
        if (props == null) {
            props = new TreeMap<String, String>();
            String fileDir = db.getProp("db.fileDir");
            String maryExtension = db.getProp("db.maryExtension");
            props.put(FEATUREFILE, fileDir + "phoneFeatures" + maryExtension);
            props.put(UNITFILE, fileDir + "phoneUnits" + maryExtension);
            props.put(DURTREE, fileDir + "dur.graph.mry");
            props.put(MAXDATA, "0");
            props.put(PROPORTIONTESTDATA, "0.1");
        }
        return props;
    }

    /** Fills {@code props2Help} with a description of every property of this component. */
    @Override
    public void setupHelp() {
        props2Help = new TreeMap<String, String>();
        props2Help.put(FEATUREFILE, "file containing all phone units and their target cost features");
        props2Help.put(UNITFILE, "file containing all phone units");
        props2Help.put(DURTREE, "file containing the duration tree. Will be created by this module");
        props2Help.put(MAXDATA, "if >0, gives the maximum number of syllables to use for training the tree");
        props2Help.put(PROPORTIONTESTDATA, "the proportion of the data to use as test data (choose so that 1/value is an integer)");
    }

    /**
     * Trains the duration graph and writes it to disk.
     *
     * @return {@code false} if the clusterer produced no graph on its last iteration,
     *         {@code true} once the final graph has been saved
     * @throws IOException if reading the feature/unit files or writing a graph fails
     */
    @Override
    public boolean compute() throws IOException {
        logger.info("Duration tree trainer started.");
        FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
        UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE));

        FeatureVector[] allFeatureVectors = featureFile.getFeatureVectors();
        int maxData = Integer.parseInt(getProp(MAXDATA));
        // Only a positive maxData limits the training data (see help text). A negative value
        // would otherwise produce a negative array size below.
        if (maxData <= 0) {
            maxData = allFeatureVectors.length;
        }
        FeatureVector[] featureVectors = new FeatureVector[Math.min(maxData, allFeatureVectors.length)];
        System.arraycopy(allFeatureVectors, 0, featureVectors, 0, featureVectors.length);
        logger.debug("Total of " + allFeatureVectors.length + " feature vectors -- will use " + featureVectors.length);

        AgglomerativeClusterer clusterer = new AgglomerativeClusterer(featureVectors, featureFile.getFeatureDefinition(),
                (List) null, new DurationDistanceMeasure(unitFile), Float.parseFloat(getProp(PROPORTIONTESTDATA)));
        DirectedGraphWriter writer = new DirectedGraphWriter();

        // Cluster level by level, saving every intermediate graph as <durTree>.levelN.
        DirectedGraph graph;
        int level = 0;
        do {
            graph = clusterer.cluster();
            level++;
            if (graph != null) {
                writer.saveGraph(graph, getProp(DURTREE) + ".level" + level);
            }
        } while (clusterer.canClusterMore());
        if (graph == null) {
            return false;
        }

        // Replace each feature-vector leaf by a float leaf holding {stddev, mean} of its durations.
        Iterator<?> leaves = graph.getLeafNodes().iterator();
        while (leaves.hasNext()) {
            LeafNode.FeatureVectorLeafNode fvLeaf = (LeafNode.FeatureVectorLeafNode) leaves.next();
            double[] durations = durationsInSeconds(fvLeaf.getFeatureVectors(), unitFile);
            double mean = MathUtils.mean(durations);
            double stddev = MathUtils.standardDeviation(durations, mean);
            LeafNode.FloatLeafNode floatLeaf = new LeafNode.FloatLeafNode(new float[] { (float) stddev, (float) mean });
            replaceLeaf(fvLeaf, floatLeaf);
        }
        writer.saveGraph(graph, getProp(DURTREE));
        return true;
    }

    /**
     * Returns the duration, in seconds, of the unit referenced by each feature vector.
     *
     * @param featureVectors vectors whose {@code getUnitIndex()} selects a unit in {@code units}
     * @param units          the unit file providing per-unit durations and the sample rate
     * @return one duration per feature vector, in the same order
     */
    private static double[] durationsInSeconds(FeatureVector[] featureVectors, UnitFileReader units) {
        // Divide in floating point. Unit.duration and getSampleRate() are integers in MaryTTS
        // (confirm), so plain division would truncate every sub-second duration to 0.
        double sampleRate = units.getSampleRate();
        double[] durations = new double[featureVectors.length];
        for (int i = 0; i < featureVectors.length; i++) {
            durations[i] = units.getUnit(featureVectors[i].getUnitIndex()).duration / sampleRate;
        }
        return durations;
    }

    /**
     * Replaces {@code oldLeaf} with {@code newLeaf} in the node returned by
     * {@code oldLeaf.getMother()}: via {@code replaceDaughter} for a decision node, otherwise via
     * {@code setLeafNode} on the directed-graph node.
     */
    private static void replaceLeaf(LeafNode oldLeaf, LeafNode newLeaf) {
        Node mother = oldLeaf.getMother();
        assert mother != null;
        if (mother.isDecisionNode()) {
            ((DecisionNode) mother).replaceDaughter(newLeaf, oldLeaf.getNodeIndex());
        } else {
            assert mother.isDirectedGraphNode();
            DirectedGraphNode graphNode = (DirectedGraphNode) mother;
            assert graphNode.getLeafNode() == oldLeaf;
            graphNode.setLeafNode(newLeaf);
        }
    }

    /** Progress is not tracked by this component; always returns -1. */
    @Override
    public int getProgress() {
        return -1;
    }

    /**
     * Command-line entry point: builds a {@link DatabaseLayout} for a new trainer and runs
     * {@link #compute()}.
     */
    public static void main(String[] args) throws Exception {
        DurationTreeTrainer trainer = new DurationTreeTrainer();
        // The layout instance is not kept; presumably its constructor initialises the
        // component's properties (TODO confirm against DatabaseLayout).
        new DatabaseLayout(trainer);
        trainer.compute();
    }
}
