package marytts.tools.voiceimport;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.Vector;
import marytts.cart.CART;
import marytts.cart.LeafNode;
import marytts.cart.io.HTSCARTReader;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureVector;
import marytts.htsengine.HMMData;
import marytts.htsengine.PhoneTranslator;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.UnitFileReader;
import marytts.unitselection.select.JoinCostFeatures;
import marytts.util.math.MathUtils;
import org.apache.log4j.BasicConfigurator;

/* loaded from: input_file:marytts/tools/voiceimport/JoinModeller.class */
/**
 * Voice-import component that trains join models.
 *
 * <p>The join-cost difference is the first unit's right join-cost frame minus the second unit's
 * left join-cost frame. It is computed for every pair of consecutive halfphone units whose
 * {@code edge} feature is {@code "0"} in both units, and grouped under the full-context name of
 * the first unit.
 *
 * <p>For each context, the mean and variance of these differences are written as a one-state HTK
 * model. An HTK stats file and a list of model names are written alongside. HHEd is then run
 * twice: once to cluster the models with a decision tree, and once to convert the trees and
 * models to the hts_engine format.
 */
public class JoinModeller extends VoiceImportComponent {
    private DatabaseLayout db = null;
    private int percent = 0;
    /** Context features read from the feature list file; used for model names and HHEd questions. */
    private Vector<String> featureList = null;
    // Package-private output writers, kept for compatibility. compute() assigns them while it
    // writes the stats/mmf/full-list files and closes them before it returns or throws.
    FileWriter statsStream = null;
    FileWriter mmfStream = null;
    FileWriter fullStream = null;
    public final String JOINCOSTFEATURESFILE = "JoinModeller.joinCostFeaturesFile";
    public final String UNITFEATURESFILE = "JoinModeller.unitFeaturesFile";
    public final String UNITFILE = "JoinModeller.unitFile";
    public final String STATSFILE = "JoinModeller.statsFile";
    public final String MMFFILE = "JoinModeller.mmfFile";
    public final String FULLFILE = "JoinModeller.fullFile";
    public final String CXCHEDFILE = "JoinModeller.cxcJoinFile";
    public final String JOINTREEFILE = "JoinModeller.joinTreeFile";
    public final String CNVHEDFILE = "JoinModeller.cnvJoinFile";
    public final String TRNCONFFILE = "JoinModeller.trnFile";
    public final String CNVCONFFILE = "JoinModeller.cnvFile";
    public final String HHEDCOMMAND = "JoinModeller.hhedCommand";
    public final String FEATURELISTFILE = "JoinModeller.featureListFile";
    public final String ALLOPHONESFILE = "JoinModeller.allophonesFile";
    public final String TRICKYPHONESFILE = "JoinModeller.trickyPhonesFile";

    @Override // marytts.tools.voiceimport.VoiceImportComponent
    public String getName() {
        return "JoinModeller";
    }

    /**
     * Returns the default properties of this component, creating them on first call.
     *
     * <p>Most paths are derived from the database layout's {@code db.fileDir} and
     * {@code db.maryExtension}. The allophones file and the HHEd command default to absolute
     * paths under {@code /project/mary/marcela/}, so they normally have to be configured.
     *
     * @param databaseLayout the database layout; it is also remembered for use by {@link #compute()}
     * @return the (possibly newly created) property map
     */
    @Override // marytts.tools.voiceimport.VoiceImportComponent
    public SortedMap<String, String> getDefaultProps(DatabaseLayout databaseLayout) {
        this.db = databaseLayout;
        if (props == null) {
            props = new TreeMap<>();
            String fileDir = databaseLayout.getProp("db.fileDir");
            String maryExt = databaseLayout.getProp("db.maryExtension");
            props.put(JOINCOSTFEATURESFILE, fileDir + "joinCostFeatures" + maryExt);
            props.put(UNITFEATURESFILE, fileDir + "halfphoneFeatures" + maryExt);
            props.put(UNITFILE, fileDir + "halfphoneUnits" + maryExt);
            props.put(STATSFILE, fileDir + "stats" + maryExt);
            props.put(MMFFILE, fileDir + "join_mmf" + maryExt);
            props.put(FULLFILE, fileDir + "fullList" + maryExt);
            props.put(FEATURELISTFILE, fileDir + "/mary/featureListFile.txt");
            props.put(ALLOPHONESFILE, "/project/mary/marcela/openmary/lib/modules/en/us/lexicon/allophones.en_US.xml");
            props.put(TRICKYPHONESFILE, fileDir + "/mary/trickyPhones.txt");
            props.put(CXCHEDFILE, fileDir + "cxc_join.hed");
            props.put(JOINTREEFILE, fileDir + "join_tree.inf");
            props.put(CNVHEDFILE, fileDir + "cnv_join.hed");
            props.put(TRNCONFFILE, fileDir + "trn.cnf");
            props.put(CNVCONFFILE, fileDir + "cnv.cnf");
            props.put(HHEDCOMMAND, "/project/mary/marcela/sw/HTS_2.0.1/htk/bin/HHEd");
        }
        return props;
    }

    /** Fills {@code props2Help} with a description of every property of this component. */
    @Override // marytts.tools.voiceimport.VoiceImportComponent
    public void setupHelp() {
        props2Help = new TreeMap<>();
        props2Help.put(JOINCOSTFEATURESFILE, "file containing all halfphone units and their join cost features");
        props2Help.put(UNITFEATURESFILE, "file containing all halfphone units and their target cost features");
        props2Help.put(UNITFILE, "file containing all halfphone units");
        props2Help.put(STATSFILE, "output file containing statistics of the models in HTK stats format");
        props2Help.put(MMFFILE, "output file containing one state HMM models, HTK format, representing join models (mean and variances are calculated in this class)");
        props2Help.put(FULLFILE, "output file containing the full list of HMM model names");
        props2Help.put(FEATURELISTFILE, "feature list for making fullcontext names and questions");
        props2Help.put(TRICKYPHONESFILE, "list of aliases for tricky phones, so HTK-HHEd command can handle them.");
        props2Help.put(CXCHEDFILE, "HTK hed file used by HHEd, load stats file, contains questions for decision tree-based context clustering and outputs result in join-tree.inf");
        props2Help.put(CNVHEDFILE, "HTK hed file used by HHEd to convert trees and mmf into hts_engine format");
        props2Help.put(ALLOPHONESFILE, "allophones set (language dependent, an example can be found in ../openmary/lib/modules/language/...)");
        props2Help.put(TRNCONFFILE, "HTK configuration file for context clustering");
        props2Help.put(CNVCONFFILE, "HTK configuration file for converting to hts_engine format");
        props2Help.put(HHEDCOMMAND, "HTS-HTK HHEd command, HTS version minimum HTS_2.0.1");
    }

    /**
     * Trains the join models and converts them to hts_engine format.
     *
     * <p>Writes the stats, mmf and full-list files, then writes the HHEd clustering script and
     * runs HHEd. It then writes the conversion script and runs HHEd again. Finally it renames
     * {@code trees.1} and {@code pdf.1} in {@code db.fileDir} to {@code tree-joinModeller.inf}
     * and {@code joinModeller.pdf}.
     *
     * @return {@code true} when all steps have been launched
     * @throws IllegalStateException if the unit, unit-feature and join-feature files disagree on
     *     the number of units; this is checked before any output file is created
     * @throws IOException if an input file cannot be read or an output file cannot be written
     */
    @Override // marytts.tools.voiceimport.VoiceImportComponent
    public boolean compute() throws IOException, Exception {
        System.out.println("\n---- Training join models\n");
        FeatureFileReader unitFeatures = FeatureFileReader.getFeatureFileReader(getProp(UNITFEATURESFILE));
        JoinCostFeatures joinFeatures = new JoinCostFeatures(getProp(JOINCOSTFEATURESFILE));
        UnitFileReader units = new UnitFileReader(getProp(UNITFILE));
        FeatureDefinition featureDef = unitFeatures.getFeatureDefinition();
        PhoneTranslator translator = createPhoneTranslator();
        featureList = new Vector<>();
        readFeatureList(getProp(FEATURELISTFILE), featureDef);

        // Validate the inputs before any output file is created or truncated.
        int numberOfUnits = unitFeatures.getNumberOfUnits();
        if (numberOfUnits != joinFeatures.getNumberOfUnits()) {
            throw new IllegalStateException("Number of units in unit and join feature files does not match!");
        }
        if (numberOfUnits != units.getNumberOfUnits()) {
            throw new IllegalStateException("Number of units in unit file and unit feature file does not match!");
        }

        Map<String, List<double[]>> contexts = collectJoinDifferences(unitFeatures, joinFeatures, featureDef, translator);
        writeModelFiles(contexts, joinFeatures.getNumberOfFeatures());

        String fileDir = db.getProp("db.fileDir");
        System.out.println(contexts.size() + " unique feature vectors, " + numberOfUnits + " units");
        System.out.println("Generated files: " + getProp(STATSFILE) + " " + getProp(MMFFILE) + " " + getProp(FULLFILE));

        System.out.println("\n---- Creating tree clustering command file for HHEd\n");
        writeClusteringScript(featureDef, translator);

        System.out.println("\n---- Tree-based context clustering for joinModeller\n");
        General.launchProc(getProp(HHEDCOMMAND) + " -A -C " + getProp(TRNCONFFILE) + " -D -T 2 -p -i -H "
                + getProp(MMFFILE) + " -m -a 1.0 -w " + getProp(MMFFILE) + ".clustered " + getProp(CXCHEDFILE)
                + " " + getProp(FULLFILE), "HHEd", fileDir);

        System.out.println("\n---- Creating conversion-to-hts command file for HHEd\n");
        writeConversionScript(fileDir);

        System.out.println("\n---- Converting mmfs to the hts_engine file format\n");
        General.launchProc(getProp(HHEDCOMMAND) + " -A -C " + getProp(CNVCONFFILE) + " -D -T 1 -p -i -H "
                + getProp(MMFFILE) + ".clustered " + getProp(CNVHEDFILE) + " " + getProp(FULLFILE), "HHEd", fileDir);
        General.launchProc("mv " + fileDir + "trees.1 " + fileDir + "tree-joinModeller.inf", "mv", fileDir);
        General.launchProc("mv " + fileDir + "pdf.1 " + fileDir + "joinModeller.pdf", "mv", fileDir);
        System.out.println("\n---- Created files: tree-joinModeller.inf, joinModeller.pdf");
        return true;
    }

    /** Builds the phone translator, using the tricky-phones file only if {@code checkTrickyPhones} reports one is needed. */
    private PhoneTranslator createPhoneTranslator() throws Exception {
        String trickyPhonesFile = getProp(TRICKYPHONESFILE);
        boolean useTrickyPhones = HMMVoiceMakeData.checkTrickyPhones(getProp(ALLOPHONESFILE), trickyPhonesFile);
        return new PhoneTranslator(useTrickyPhones ? trickyPhonesFile : "");
    }

    /**
     * Collects join-cost frame differences between consecutive non-edge units.
     *
     * <p>Updates {@link #percent} as it walks the units.
     *
     * @return map from the first unit's full-context name to its difference vectors, in unit order
     */
    private Map<String, List<double[]>> collectJoinDifferences(FeatureFileReader unitFeatures,
            JoinCostFeatures joinFeatures, FeatureDefinition featureDef, PhoneTranslator translator) {
        int numberOfUnits = unitFeatures.getNumberOfUnits();
        int phoneIndex = featureDef.getFeatureIndex(PhoneUnitFeatureComputer.PHONEFEATURE);
        int numberOfPhones = featureDef.getNumberOfValues(phoneIndex);
        int edgeIndex = featureDef.getFeatureIndex("edge");
        byte notAnEdge = featureDef.getFeatureValueAsByte(edgeIndex, "0");

        // A list (not a HashSet of arrays, which compares by identity anyway) keeps insertion
        // order, so mean/variance summation order is deterministic across runs.
        Map<String, List<double[]>> contexts = new HashMap<>();
        if (numberOfUnits == 0) {
            return contexts;
        }
        FeatureVector next = unitFeatures.getFeatureVector(0);
        for (int i = 0; i < numberOfUnits - 1; i++) {
            percent = (100 * (i + 1)) / numberOfUnits;
            FeatureVector current = next;
            next = unitFeatures.getFeatureVector(i + 1);
            if (current.getByteFeature(edgeIndex) != notAnEdge || next.getByteFeature(edgeIndex) != notAnEdge) {
                continue; // only joins between two non-edge units are modelled
            }
            int phone = current.getFeatureAsInt(phoneIndex);
            assert 0 <= phone && phone < numberOfPhones : "phone value " + phone + " out of range";

            double[] difference = joinDifference(joinFeatures.getRightJCF(i), joinFeatures.getLeftJCF(i + 1), i);
            String context = translator.features2LongContext(featureDef, current, featureList);
            List<double[]> samples = contexts.get(context);
            if (samples == null) {
                samples = new ArrayList<>();
                contexts.put(context, samples);
            }
            samples.add(difference);
        }
        return contexts;
    }

    /**
     * Returns {@code rightFrame - leftFrame} element-wise.
     *
     * <p>Special case: if either frame's last coefficient is positive infinity, that difference
     * is set to 0 and a warning is printed.
     */
    private static double[] joinDifference(float[] rightFrame, float[] leftFrame, int unitIndex) {
        int size = rightFrame.length;
        double[] difference = new double[size];
        for (int k = 0; k < size; k++) {
            boolean infiniteLastCoefficient = k == size - 1
                    && (rightFrame[k] == Float.POSITIVE_INFINITY || leftFrame[k] == Float.POSITIVE_INFINITY);
            if (infiniteLastCoefficient) {
                difference[k] = 0.0d;
                System.out.println("WARNING: numUnit=" + unitIndex + " myRightFrame[k]=" + rightFrame[k]
                        + " nextLeftFrame[k]=" + leftFrame[k]);
            } else {
                difference[k] = rightFrame[k] - leftFrame[k];
            }
        }
        return difference;
    }

    /**
     * Writes the HTK stats file, the mmf file and the full model list.
     *
     * <p>Each context gets one model. Its mean and variance are computed over all samples;
     * a single sample gives that sample as the mean and a zero variance. All three files are
     * closed even if writing fails.
     */
    private void writeModelFiles(Map<String, List<double[]>> contexts, int numberOfFeatures) throws IOException {
        try (FileWriter stats = new FileWriter(getProp(STATSFILE));
                FileWriter mmf = new FileWriter(getProp(MMFFILE));
                FileWriter full = new FileWriter(getProp(FULLFILE))) {
            statsStream = stats;
            mmfStream = mmf;
            fullStream = full;
            mmf.write("~o\n<VECSIZE> " + numberOfFeatures + " <USER><DIAGC>\n~t \"trP_1\"\n<TRANSP> 3\n0 1 0\n0 0 1\n0 0 0\n");
            int modelNumber = 1;
            for (Map.Entry<String, List<double[]>> entry : contexts.entrySet()) {
                String modelName = entry.getKey();
                double[][] samples = entry.getValue().toArray(new double[0][]);
                int count = samples.length;
                double[] mean;
                double[] variance;
                if (count == 1) {
                    mean = samples[0];
                    variance = MathUtils.zeros(mean.length);
                } else {
                    mean = MathUtils.mean(samples, true);
                    variance = MathUtils.variance(samples, mean, true);
                }
                assert mean.length == numberOfFeatures
                        : "expected to have " + numberOfFeatures + " features, got " + mean.length;

                full.write(modelName + "\n");
                stats.write(modelNumber + " \"" + modelName + "\"    " + count + "    " + count + "\n");
                mmf.write("~h \"" + modelName + "\"\n");
                mmf.write("<BEGINHMM>\n<NUMSTATES> 3\n<STATE> 2\n");
                mmf.write("<MEAN> " + numberOfFeatures + "\n");
                writeRow(mmf, mean);
                mmf.write("\n<VARIANCE> " + numberOfFeatures + "\n");
                writeRow(mmf, variance);
                mmf.write("\n~t \"trP_1\"\n<ENDHMM>\n");
                modelNumber++;
            }
        }
    }

    /** Writes every value followed by a single space. */
    private static void writeRow(FileWriter out, double[] values) throws IOException {
        for (double value : values) {
            out.write(value + " ");
        }
    }

    /**
     * Writes the HHEd script for decision-tree clustering.
     *
     * <p>The script loads the stats file and emits one {@code QS} question per possible value of
     * every feature in {@link #featureList}. It then builds the trees and stores them in the
     * join-tree file.
     */
    private void writeClusteringScript(FeatureDefinition featureDef, PhoneTranslator translator) throws IOException {
        String scriptPath = getProp(CXCHEDFILE);
        try (PrintWriter out = new PrintWriter(new File(scriptPath))) {
            out.println("// load stats file");
            out.println("RO 000 \"" + getProp(STATSFILE) + "\"");
            out.println();
            out.println("TR 0");
            out.println();
            out.println("// questions for decision tree-based context clustering");
            for (String feature : featureList) {
                for (String value : featureDef.getPossibleValues(featureDef.getFeatureIndex(feature))) {
                    String label = value;
                    // Phones and punctuation are rewritten so that HHEd can handle them.
                    if (feature.endsWith(PhoneUnitFeatureComputer.PHONEFEATURE)) {
                        label = translator.replaceTrickyPhones(value);
                    } else if (feature.endsWith("sentence_punc") || feature.endsWith("punctuation")) {
                        label = translator.replacePunc(value);
                    }
                    out.println("QS \"" + feature + "=" + label + "\" {*|" + feature + "=" + label + "|*}");
                }
                out.println();
            }
            out.println("TR 3");
            out.println();
            out.println("// construct decision trees");
            out.println("TB 000 join_ {*.state[2]}");
            out.println();
            out.println("TR 1");
            out.println();
            out.println("// output constructed tree");
            out.println("ST \"" + getProp(JOINTREEFILE) + "\"");
            // PrintWriter swallows I/O errors; surface them instead of feeding HHEd a truncated script.
            if (out.checkError()) {
                throw new IOException("Error writing HHEd clustering script " + scriptPath);
            }
        }
    }

    /**
     * Writes the HHEd script that loads the join tree and converts the trees ({@code CT}) and
     * the mmf ({@code CM}) to hts_engine format in {@code fileDir}.
     */
    private void writeConversionScript(String fileDir) throws IOException {
        String scriptPath = getProp(CNVHEDFILE);
        try (PrintWriter out = new PrintWriter(new File(scriptPath))) {
            out.println("TR 3");
            out.println();
            out.println("// load trees for joinModeller");
            out.println("LT \"" + getProp(JOINTREEFILE) + "\"");
            out.println();
            out.println("// convert loaded trees for hts_engine format");
            out.println("CT \"" + fileDir + "\"");
            out.println();
            out.println("// convert mmf for hts_engine format");
            out.println("CM \"" + fileDir + "\"");
            if (out.checkError()) {
                throw new IOException("Error writing HHEd conversion script " + scriptPath);
            }
        }
    }

    @Override // marytts.tools.voiceimport.VoiceImportComponent
    public int getProgress() {
        return percent;
    }

    /**
     * Reads one context-feature name per line into {@link #featureList}.
     *
     * <p>Lines are trimmed and blank lines are skipped. A missing file is reported on stdout and
     * leaves the list unchanged; this best-effort behaviour is intentional.
     *
     * @param str path of the feature list file
     * @param featureDefinition definition every listed feature must exist in
     * @throws Exception if a listed feature is not defined in {@code featureDefinition}
     */
    private void readFeatureList(String str, FeatureDefinition featureDefinition) throws Exception {
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            System.out.println("The following are other context features used for training Hmms: ");
            String line;
            while ((line = reader.readLine()) != null) {
                String feature = line.trim();
                if (feature.isEmpty()) {
                    continue;
                }
                if (!featureDefinition.hasFeature(feature)) {
                    throw new Exception("Error: feature \"" + feature + "\" in feature list file: " + str
                            + " does not exist in FeatureDefinition.");
                }
                featureList.add(feature);
            }
        } catch (FileNotFoundException e) {
            System.out.println("readFeatureList:  " + e.getMessage());
        }
        System.out.println("readFeatureList: loaded " + featureList.size() + " context features from " + str);
    }

    /**
     * Developer harness: loads a trained join tree and pdf and prints the mean and variance of
     * the leaf reached by each unit feature line of a pfeats file.
     *
     * <p>All paths are hard-coded absolute paths under {@code /project/mary/marcela/}. Failures
     * are reported on stderr.
     */
    public static void main(String[] strArr) throws IOException, InterruptedException {
        String pfeatsFile = "/project/mary/marcela/unitselection-halfphone.pfeats";
        String allophonesFile = "/project/mary/marcela/openmary/lib/modules/en/us/lexicon/allophones.en_US.xml";
        String voiceDir = "/project/mary/marcela/HMM-voices/DFKI_German_Poker/mary_files_old/";
        String trickyPhonesFile = voiceDir + "trickyPhones.txt";
        String treeFile = voiceDir + "tree-joinModeller.inf";
        String pdfFile = voiceDir + "joinModeller.pdf";
        try {
            BasicConfigurator.configure();
            FeatureDefinition featureDef = new FeatureDefinition(
                    new BufferedReader(new StringReader(readFileContents(pfeatsFile))), false);
            PhoneTranslator translator = new PhoneTranslator(
                    HMMVoiceMakeData.checkTrickyPhones(allophonesFile, trickyPhonesFile) ? trickyPhonesFile : "");
            HTSCARTReader cartReader = new HTSCARTReader();
            try {
                CART[] trees = cartReader.load(1, treeFile, pdfFile, featureDef, translator);
                printLeafModels(trees[0], featureDef, pfeatsFile, cartReader.getVectorSize());
            } catch (Exception e) {
                throw new IOException("Cannot load join model trees from " + treeFile, e);
            }
        } catch (Exception e) {
            System.err.println("Exception: " + e.getMessage());
            if (e.getCause() != null) {
                System.err.println("Caused by: " + e.getCause());
            }
        }
    }

    /** Returns the file's lines, each followed by a newline, stopping once only whitespace remains. */
    private static String readFileContents(String path) throws FileNotFoundException {
        try (Scanner scanner = new Scanner(new BufferedReader(new FileReader(path)))) {
            StringBuilder text = new StringBuilder();
            while (scanner.hasNext()) {
                text.append(scanner.nextLine()).append('\n');
            }
            return text.toString();
        }
    }

    /**
     * Prints the mean and variance of the leaf that each feature line of {@code pfeatsFile}
     * reaches in {@code joinTree}.
     *
     * <p>Feature lines are read after skipping the first two sections, each terminated by a
     * blank line (presumably the file header; confirm against the pfeats format).
     */
    private static void printLeafModels(CART joinTree, FeatureDefinition featureDef, String pfeatsFile,
            int vectorSize) throws Exception {
        try (Scanner lines = new Scanner(new BufferedReader(new FileReader(pfeatsFile)))) {
            skipPastBlankLine(lines);
            skipPastBlankLine(lines);
            while (lines.hasNext()) {
                Object node = joinTree.interpretToNode(featureDef.toFeatureVector(0, lines.nextLine()), 1);
                assert node instanceof LeafNode.PdfLeafNode : "The node must be a PdfLeafNode.";
                LeafNode.PdfLeafNode leaf = (LeafNode.PdfLeafNode) node;
                double[] mean = leaf.getMean();
                double[] variance = leaf.getVariance();
                System.out.print("mean: ");
                for (int i = 0; i < vectorSize; i++) {
                    System.out.print(mean[i] + " ");
                }
                System.out.print("\nvariance: ");
                for (int i = 0; i < vectorSize; i++) {
                    System.out.print(variance[i] + " ");
                }
                System.out.println();
            }
        }
    }

    /** Consumes lines up to and including the next line that is blank after trimming. */
    private static void skipPastBlankLine(Scanner lines) {
        while (lines.hasNext() && !lines.nextLine().trim().equals("")) {
            // keep consuming
        }
    }
}
