package marytts.tools.voiceimport;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import marytts.cart.StringPredictionTree;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.machinelearning.GmmDiscretizer;
import weka.classifiers.trees.j48.BinC45ModelSelection;
import weka.classifiers.trees.j48.C45PruneableClassifierTree;
import weka.classifiers.trees.j48.TreeConverter;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:marytts/tools/voiceimport/PauseDurationTrainer.class */
/**
 * Voice import component that trains a classification tree for pause durations.
 *
 * <p>For every basename in the basename list it reads a pause feature file
 * ({@code <featureDir><basename>.pfeats}) and the matching label file
 * ({@code <labDir><basename>.lab}), aligns the non-pause phones of both, and collects one
 * training instance per feature vector whose {@code breakindex} is greater than 1. The target
 * of each instance is the duration of the pause label that follows the phone in the label file
 * (0 if there is none), discretized by a {@link GmmDiscretizer}. The trained tree is written as
 * text to the file configured under {@link #TRAINEDTREE}.
 */
public class PauseDurationTrainer extends VoiceImportComponent {
    /** Symbol that marks a pause, both in label files and in the phone feature. */
    private static final String PAUSE_SYMBOL = "_";

    /** Features copied from each feature vector into its training instance. */
    public final String[] featureNames = {"breakindex", "ph_cplace", "ph_ctype", "next_pos", "next_wordbegin_ctype", "next_wordbegin_cplace", "words_from_phrase_end", "words_from_phrase_start"};
    /** Property key: directory containing the pause feature files. */
    public final String FVFILES = "PauseDurationTrainer.featureDir";
    /** Property key: directory containing the label files. */
    public final String LABFILES = "PauseDurationTrainer.lab";
    /** Property key: output file for the trained tree. */
    public final String TRAINEDTREE = "PauseDurationTrainer.tree";
    /** Database layout handed in through {@link #getDefaultProps(DatabaseLayout)}. */
    protected DatabaseLayout db = null;
    /** File extension of pause feature files. */
    private final String fvExt = ".pfeats";
    /** File extension of label files. */
    private final String labExt = ".lab";

    /**
     * Simple holder pairing the feature vectors read from one feature file with the feature
     * definition that describes them.
     */
    public class VectorsAndDefinition {
        private List<FeatureVector> fv;
        private FeatureDefinition fd;

        /**
         * @param list feature vectors read from a feature file
         * @param featureDefinition definition describing the layout of those vectors
         */
        public VectorsAndDefinition(List<FeatureVector> list, FeatureDefinition featureDefinition) {
            this.fv = list;
            this.fd = featureDefinition;
        }

        /** @return the feature vectors */
        public List<FeatureVector> getFv() {
            return this.fv;
        }

        /** @param list the feature vectors to store */
        public void setFv(List<FeatureVector> list) {
            this.fv = list;
        }

        /** @return the feature definition */
        public FeatureDefinition getFd() {
            return this.fd;
        }

        /** @param featureDefinition the feature definition to store */
        public void setFd(FeatureDefinition featureDefinition) {
            this.fd = featureDefinition;
        }
    }

    /** Nothing to initialise for this component. */
    @Override
    public void initialiseComp() {
    }

    /**
     * Remembers the database layout and, on first call, fills the property map with defaults.
     * Each default is taken from the JVM system property of the same name if set; otherwise it
     * is derived from the layout's {@code db.rootDir} property.
     *
     * @param databaseLayout layout providing {@code db.rootDir}
     * @return the property map of this component
     */
    @Override
    public SortedMap getDefaultProps(DatabaseLayout databaseLayout) {
        this.db = databaseLayout;
        if (this.props == null) {
            this.props = new TreeMap();
            String separator = System.getProperty("file.separator");
            this.props.put(FVFILES, defaultFor(FVFILES, databaseLayout, "pausefeatures" + separator));
            this.props.put(LABFILES, defaultFor(LABFILES, databaseLayout, "lab" + separator));
            this.props.put(TRAINEDTREE, defaultFor(TRAINEDTREE, databaseLayout, "durations.tree"));
        }
        return this.props;
    }

    /**
     * Resolves one default property value.
     *
     * @param key property key, also used as the system property name
     * @param databaseLayout layout providing {@code db.rootDir}
     * @param relativePath path appended to the root directory when no system property is set
     * @return the system property value if present, else root directory + {@code relativePath}
     */
    private String defaultFor(String key, DatabaseLayout databaseLayout, String relativePath) {
        String configured = System.getProperty(key);
        if (configured != null) {
            return configured;
        }
        return databaseLayout.getProp("db.rootDir") + relativePath;
    }

    /**
     * Collects training data from all feature/label file pairs, trains the duration tree and
     * writes it to the configured tree file. Basenames without a feature file are skipped.
     *
     * @return {@code true} once the tree has been written
     * @throws IllegalStateException if no feature file was found for any basename
     * @throws IllegalArgumentException if a label file is malformed or does not line up with
     *         the feature vectors
     * @throws Exception on I/O or training errors
     */
    @Override
    public boolean compute() throws Exception {
        Instances data = null;
        FeatureDefinition featureDefinition = null;
        List<Integer> pauseDurations = new ArrayList<>();

        for (int i = 0; i < this.bnl.getLength(); i++) {
            String basename = this.bnl.getName(i);
            VectorsAndDefinition features = readFeaturesFor(basename);
            if (features == null) {
                // no feature file for this basename
                continue;
            }
            featureDefinition = features.getFd();
            if (data == null) {
                // the attribute layout is taken from the first feature file encountered
                data = initData(featureDefinition);
            }

            List<String> labelSymbols = new ArrayList<>();
            List<Integer> labelDurations = new ArrayList<>();
            readLabelFile(getProp(LABFILES) + basename + this.labExt, labelSymbols, labelDurations);
            collectPauseInstances(features, labelSymbols, labelDurations, data, pauseDurations);
        }

        if (data == null) {
            throw new IllegalStateException("No pause feature files (*" + this.fvExt + ") found in "
                    + getProp(FVFILES) + "; nothing to train on.");
        }

        StringPredictionTree tree = trainTree(enterDurations(data, pauseDurations), featureDefinition);
        try (FileWriter writer = new FileWriter(getProp(TRAINEDTREE))) {
            writer.write(tree.toString());
        }
        return true;
    }

    /**
     * Reads a label file into parallel lists. Lines starting with {@code #} are skipped; every
     * other line must split on whitespace into exactly three columns. The third column is the
     * symbol; the first column is multiplied by 1000 and truncated to an int (presumably an end
     * time in seconds turned into milliseconds), and the difference to the previous line's value
     * is stored as that label's duration.
     *
     * @param path label file to read
     * @param symbols receives one symbol per label line
     * @param durations receives one duration per label line
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a line does not have three columns
     */
    private void readLabelFile(String path, List<String> symbols, List<Integer> durations) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            int previousTime = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("#")) {
                    continue;
                }
                String[] columns = line.split("\\s+");
                if (columns.length != 3) {
                    throw new IllegalArgumentException("Expected three columns in label file, got " + columns.length);
                }
                symbols.add(columns[2]);
                int currentTime = (int) (1000.0f * Float.parseFloat(columns[0]));
                durations.add(Integer.valueOf(currentTime - previousTime));
                previousTime = currentTime;
            }
        }
    }

    /**
     * Walks the feature vectors and the labels of one utterance in parallel and appends one
     * instance (plus its pause duration) for every non-pause vector with {@code breakindex > 1}.
     * Pauses on the feature-vector side are ignored; pause durations come from the label side.
     *
     * @param features vectors and definition of the utterance
     * @param labelSymbols symbols read from the label file
     * @param labelDurations durations read from the label file, parallel to {@code labelSymbols}
     * @param data dataset receiving the new instances
     * @param pauseDurations receives one duration per added instance
     * @throws IllegalArgumentException if phones of labels and feature vectors do not match
     */
    private void collectPauseInstances(VectorsAndDefinition features, List<String> labelSymbols,
            List<Integer> labelDurations, Instances data, List<Integer> pauseDurations) {
        FeatureDefinition fd = features.getFd();
        int phoneFeature = fd.getFeatureIndex(PhoneUnitFeatureComputer.PHONEFEATURE);
        int breakIndexFeature = fd.getFeatureIndex("breakindex");
        int labelCount = labelSymbols.size();

        // skip leading pause labels
        int labelPos = 0;
        while (labelPos < labelCount && PAUSE_SYMBOL.equals(labelSymbols.get(labelPos))) {
            labelPos++;
        }

        for (FeatureVector vector : features.getFv()) {
            String vectorSymbol = vector.getFeatureAsString(phoneFeature, fd);
            if (vectorSymbol.equals(PAUSE_SYMBOL)) {
                continue;
            }
            if (labelPos >= labelCount) {
                throw new IllegalArgumentException("Label file has fewer phone labels than there are feature vectors"
                        + " (no label left for feature vector phone '" + vectorSymbol
                        + "'). Run CorrectedTranscriptionAligner first.");
            }
            String labelSymbol = labelSymbols.get(labelPos);
            if (!vectorSymbol.equals(labelSymbol)) {
                throw new IllegalArgumentException("Phone symbol of label file (" + labelSymbol
                        + ") and of feature vector (" + vectorSymbol
                        + ") don't correspond. Run CorrectedTranscriptionAligner first.");
            }

            // a pause label directly after this phone supplies the duration and is consumed
            int pauseDuration = 0;
            int next = labelPos + 1;
            if (next < labelCount && PAUSE_SYMBOL.equals(labelSymbols.get(next))) {
                labelPos = next;
                pauseDuration = labelDurations.get(labelPos).intValue();
            }

            if (vector.getFeatureAsInt(breakIndexFeature) > 1) {
                pauseDurations.add(Integer.valueOf(pauseDuration));
                data.add(createInstance(data, fd, vector));
            }
            labelPos++;
        }
    }

    /**
     * Builds a binary-split C4.5 tree on the given data and converts it to a
     * {@link StringPredictionTree}.
     *
     * @param instances training data with the class attribute set
     * @param featureDefinition feature definition handed to the tree converter
     * @return the converted tree
     * @throws Exception if Weka fails to build the classifier
     */
    private StringPredictionTree trainTree(Instances instances, FeatureDefinition featureDefinition) throws Exception {
        System.out.println("training duration tree (" + instances.numInstances() + " instances) ...");
        BinC45ModelSelection modelSelection = new BinC45ModelSelection(2, instances);
        C45PruneableClassifierTree classifier = new C45PruneableClassifierTree(modelSelection, true, 0.25f, true, true);
        classifier.buildClassifier(instances);
        System.out.println("...done");
        return TreeConverter.c45toStringPredictionTree(classifier, featureDefinition, instances);
    }

    /**
     * Appends a nominal {@code target} attribute to the dataset, fills it with the discretized
     * durations (formatted as {@code <value>ms}) and makes it the class attribute.
     *
     * @param instances dataset whose i-th instance corresponds to the i-th duration
     * @param list durations, one per instance
     * @return the same dataset, now including the class attribute
     */
    private Instances enterDurations(Instances instances, List<Integer> list) {
        GmmDiscretizer discretizer = GmmDiscretizer.trainDiscretizer(list, 6, true);

        FastVector targetValues = new FastVector();
        for (int value : discretizer.getPossibleValues()) {
            targetValues.addElement(value + "ms");
        }
        instances.insertAttributeAt(new Attribute("target", targetValues), instances.numAttributes());

        int targetIndex = instances.numAttributes() - 1;
        for (int i = 0; i < list.size(); i++) {
            String label = discretizer.discretize(list.get(i).intValue()) + "ms";
            instances.instance(i).setValue(targetIndex, label);
        }
        instances.setClassIndex(targetIndex);
        return instances;
    }

    /**
     * Creates an instance for the dataset and fills in the values of all {@link #featureNames}
     * from the feature vector. Other attributes are left unset.
     *
     * @param instances dataset the instance will belong to
     * @param featureDefinition definition used to look up feature indices and string values
     * @param featureVector source of the feature values
     * @return the new instance (not yet added to the dataset)
     */
    private Instance createInstance(Instances instances, FeatureDefinition featureDefinition, FeatureVector featureVector) {
        Instance instance = new Instance(instances.numAttributes());
        instance.setDataset(instances);
        for (String name : this.featureNames) {
            int featureIndex = featureDefinition.getFeatureIndex(name);
            String value = featureVector.getFeatureAsString(featureIndex, featureDefinition);
            instance.setValue(instances.attribute(name), value);
        }
        return instance;
    }

    /**
     * Creates an empty dataset named {@code pausedurations} with one nominal attribute per
     * feature of the definition, except the phone feature.
     *
     * @param featureDefinition definition supplying feature names and their possible values
     * @return the empty dataset
     */
    private Instances initData(FeatureDefinition featureDefinition) {
        FastVector attributes = new FastVector();
        for (int i = 0; i < featureDefinition.getNumberOfFeatures(); i++) {
            String featureName = featureDefinition.getFeatureName(i);
            if (featureName.equals(PhoneUnitFeatureComputer.PHONEFEATURE)) {
                continue;
            }
            FastVector possibleValues = new FastVector();
            for (String value : featureDefinition.getPossibleValues(i)) {
                possibleValues.addElement(value);
            }
            attributes.addElement(new Attribute(featureName, possibleValues));
        }
        return new Instances("pausedurations", attributes, 0);
    }

    /**
     * Reads a textual feature table: the feature definition, then everything up to and including
     * the next empty line is skipped, then one feature vector per remaining line.
     *
     * @param lineNumberReader reader positioned at the start of the feature definition
     * @return the vectors together with their definition
     * @throws IOException if the definition lacks the phone or {@code breakindex} feature, the
     *         input ends before the empty separator line, or a vector line cannot be parsed
     */
    private VectorsAndDefinition readFeatureTable(LineNumberReader lineNumberReader) throws IOException {
        FeatureDefinition featureDefinition = new FeatureDefinition(lineNumberReader, false);
        try {
            // lookups are only done to verify that both features exist
            featureDefinition.getFeatureIndex(PhoneUnitFeatureComputer.PHONEFEATURE);
            featureDefinition.getFeatureIndex("breakindex");
        } catch (IllegalArgumentException e) {
            throw new IOException("Unexpected FeatureDefinition: Does not contain the features 'phone' and 'breakindex'.", e);
        }

        // skip ahead to the first empty line
        String line;
        do {
            line = lineNumberReader.readLine();
            if (line == null) {
                throw new IOException("Unexpected end of input in line " + lineNumberReader.getLineNumber()
                        + ": no empty line before the feature vectors.");
            }
        } while (!line.equals(""));

        List<FeatureVector> vectors = new ArrayList<>();
        while ((line = lineNumberReader.readLine()) != null) {
            try {
                vectors.add(featureDefinition.toFeatureVector(0, line));
            } catch (Exception e) {
                throw new IOException("Unexpected Input in line " + String.valueOf(lineNumberReader.getLineNumber()), e);
            }
        }
        return new VectorsAndDefinition(vectors, featureDefinition);
    }

    /**
     * Loads the pause feature file belonging to a basename.
     *
     * @param str basename of the utterance
     * @return the parsed features, or {@code null} if the feature file does not exist
     * @throws IOException if the file cannot be read or parsed
     */
    private VectorsAndDefinition readFeaturesFor(String str) throws IOException {
        String path = getProp(FVFILES) + str + this.fvExt;
        File file = new File(path);
        if (!file.exists()) {
            return null;
        }
        try (FileInputStream in = new FileInputStream(file)) {
            System.out.println("processing " + path);
            return readFeatureTable(new LineNumberReader(new InputStreamReader(in)));
        }
    }

    /** @return the component name, {@code "PauseDurationTrainer"} */
    @Override
    public String getName() {
        return "PauseDurationTrainer";
    }

    /** @return always 0; this component does not report progress */
    @Override
    public int getProgress() {
        return 0;
    }

    /** Fills the help texts for the three properties of this component. */
    @Override
    public void setupHelp() {
        this.props2Help = new TreeMap();
        this.props2Help.put(FVFILES, "Directory containing the pause feature files.");
        this.props2Help.put(LABFILES, "Directory containing label files from which pause durations are taken.");
        this.props2Help.put(TRAINEDTREE, "Result of training.");
    }
}
