package com.github.habernal.confusionmatrix;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

/* loaded from: input_file:com/github/habernal/confusionmatrix/ConfusionMatrix.class */
/**
 * Confusion matrix over string labels.
 *
 * <p>Counts are stored per (gold label, predicted label) pair: rows are gold labels and
 * columns are predicted labels. From these counts the class derives accuracy, per-label
 * precision / recall / F-measure, macro and micro averages, Cohen's kappa and
 * normal-approximation confidence intervals, and renders the matrix as plain text or LaTeX.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class ConfusionMatrix {

    /** Text of the top-left cell of the rendered table; also recognised by {@link #parseFromText}. */
    private static final String CORNER_LABEL = "↓gold\\pred→";

    /** Sum of all counts added so far. */
    int total = 0;

    /** Sum of all counts whose gold and predicted labels are equal (the diagonal). */
    int correct = 0;

    /** Number of decimal places used when formatting real-valued results. */
    private int numberOfDecimalPlaces = 3;

    /** Every label ever passed as a gold label, sorted. */
    private final TreeSet<String> allGoldLabels = new TreeSet<>();

    /** Every label ever passed as a predicted label, sorted. */
    private final TreeSet<String> allPredictedLabels = new TreeSet<>();

    /** Predicted labels, each repeated once per unit of (positive) count, in insertion order. */
    private final List<String> labelSeries = new ArrayList<>();

    /** gold label -> (predicted label -> count). */
    private final Map<String, Map<String, Integer>> map = new TreeMap<>();

    /**
     * Sets the number of decimal places used by the formatting methods.
     *
     * @param i number of decimal places, in the range 1-100
     * @throws IllegalArgumentException if {@code i} is outside 1-100
     */
    public void setNumberOfDecimalPlaces(int i) throws IllegalArgumentException {
        if (i < 1 || i > 100) {
            throw new IllegalArgumentException("Argument must be in range 1-100");
        }
        this.numberOfDecimalPlaces = i;
    }

    /** Returns a {@code String.format} pattern for a fixed-point number with the configured precision. */
    private String getFormat() {
        return "%." + numberOfDecimalPlaces + "f";
    }

    /**
     * Records one observation with the given gold and predicted label.
     *
     * @param gold      gold (reference) label
     * @param predicted predicted label
     */
    public void increaseValue(String gold, String predicted) {
        increaseValue(gold, predicted, 1);
    }

    /**
     * Returns the predicted labels in the order they were recorded, each repeated once per unit
     * of count. The returned list is the internal live list, not a copy.
     *
     * @return the recorded predicted-label series
     */
    public List<String> getLabelSeries() {
        return labelSeries;
    }

    /**
     * Adds {@code count} to the cell (gold, predicted). Negative counts are allowed (they are used
     * internally to subtract the diagonal) but add nothing to the label series.
     *
     * @param gold      gold (reference) label
     * @param predicted predicted label
     * @param count     amount to add to the cell
     */
    public void increaseValue(String gold, String predicted, int count) {
        allGoldLabels.add(gold);
        allPredictedLabels.add(predicted);
        for (int k = 0; k < count; k++) {
            labelSeries.add(predicted);
        }

        Map<String, Integer> row = map.get(gold);
        if (row == null) {
            row = new TreeMap<>();
            map.put(gold, row);
        }
        Integer current = row.get(predicted);
        row.put(predicted, (current == null ? 0 : current) + count);

        total += count;
        if (gold.equals(predicted)) {
            correct += count;
        }
    }

    /**
     * Returns the fraction of observations whose prediction equals the gold label.
     *
     * @return accuracy in [0, 1], or NaN if the matrix is empty
     */
    public double getAccuracy() {
        // Cast before dividing: int/int would truncate to 0 or 1.
        return (double) correct / total;
    }

    /** @return the sum of all counts */
    public int getTotalSum() {
        return total;
    }

    /**
     * Returns the sum of the row for a gold label.
     *
     * @param str gold label
     * @return number of observations with that gold label; 0 if the label has no row
     */
    public int getRowSum(String str) {
        Map<String, Integer> row = map.get(str);
        if (row == null) {
            return 0;
        }
        int sum = 0;
        for (int value : row.values()) {
            sum += value;
        }
        return sum;
    }

    /**
     * Returns the sum of the column for a predicted label.
     *
     * @param str predicted label
     * @return number of observations predicted as that label; 0 if never predicted
     */
    public int getColSum(String str) {
        int sum = 0;
        for (Map<String, Integer> row : map.values()) {
            Integer value = row.get(str);
            if (value != null) {
                sum += value;
            }
        }
        return sum;
    }

    /** Returns the count in cell (gold, predicted), or 0 when the cell was never touched. */
    private int getCount(String gold, String predicted) {
        Map<String, Integer> row = map.get(gold);
        if (row == null) {
            return 0;
        }
        Integer value = row.get(predicted);
        return value == null ? 0 : value;
    }

    /**
     * Returns precision for every gold label, in sorted label order. Labels that were only ever
     * predicted (never gold) are not included.
     *
     * @return label -> precision
     */
    public Map<String, Double> getPrecisionForLabels() {
        Map<String, Double> result = new LinkedHashMap<>();
        for (String label : allGoldLabels) {
            result.put(label, getPrecisionForLabel(label));
        }
        return result;
    }

    /**
     * Returns precision for one label: diagonal count / column sum.
     *
     * @param str label
     * @return precision, or 0 when the label was never predicted
     */
    public double getPrecisionForLabel(String str) {
        int truePositives = getCount(str, str);
        int predictedAsLabel = getColSum(str);
        if (predictedAsLabel > 0) {
            return (double) truePositives / predictedAsLabel;
        }
        return 0.0;
    }

    /**
     * Returns the micro-averaged F-measure. In this single-label setting, micro precision and
     * micro recall both equal (sum of diagonal) / (sum of all counts), so the result equals
     * accuracy.
     *
     * @return micro F-measure, or 0 when precision + recall is 0 (including an empty matrix)
     */
    public double getMicroFMeasure() {
        int truePositives = 0;
        int goldTotal = 0;
        for (String gold : map.keySet()) {
            truePositives += getCount(gold, gold);
            goldTotal += getRowSum(gold);
        }
        // Count every prediction, including labels that never occur as gold; otherwise those
        // false positives would be silently dropped from the precision denominator.
        int predictedTotal = 0;
        for (String predicted : allPredictedLabels) {
            predictedTotal += getColSum(predicted);
        }

        double precision = predictedTotal > 0 ? (double) truePositives / predictedTotal : 0.0;
        double recall = goldTotal > 0 ? (double) truePositives / goldTotal : 0.0;
        if (precision + recall > 0.0) {
            return (2.0 * precision * recall) / (precision + recall);
        }
        return 0.0;
    }

    /** Returns the arithmetic mean of the map's values; NaN for an empty map. */
    private static double meanOfValues(Map<String, Double> values) {
        double sum = 0.0;
        for (double value : values.values()) {
            sum += value;
        }
        return sum / values.size();
    }

    /**
     * Returns the macro-averaged F1 measure (unweighted mean of per-gold-label F1).
     *
     * @return macro F1, or NaN if there are no labels
     */
    public double getMacroFMeasure() {
        return meanOfValues(getFMeasureForLabels());
    }

    /**
     * Returns the macro-averaged F-beta measure (unweighted mean of per-gold-label F-beta).
     *
     * @param d beta weighting recall relative to precision
     * @return macro F-beta, or NaN if there are no labels
     */
    public double getMacroFMeasure(double d) {
        return meanOfValues(getFMeasureForLabels(d));
    }

    /**
     * Returns F1 for every gold label, in sorted label order.
     *
     * @return label -> F1 (0 when precision and recall are both 0)
     */
    public Map<String, Double> getFMeasureForLabels() {
        // F1 is F-beta with beta = 1; the two formulas coincide bit-for-bit.
        return getFMeasureForLabels(1.0);
    }

    /**
     * Returns F-beta = (1 + beta^2) * P * R / (beta^2 * P + R) for every gold label.
     *
     * @param d beta weighting recall relative to precision
     * @return label -> F-beta (0 when the denominator is not positive)
     */
    public Map<String, Double> getFMeasureForLabels(double d) {
        Map<String, Double> precisions = getPrecisionForLabels();
        Map<String, Double> recalls = getRecallForLabels();
        double betaSquared = d * d;

        Map<String, Double> result = new LinkedHashMap<>();
        for (String label : allGoldLabels) {
            double precision = precisions.get(label);
            double recall = recalls.get(label);
            double denominator = betaSquared * precision + recall;
            double fMeasure = 0.0;
            // Guard the actual denominator: with beta = 0 and recall = 0 it can vanish even
            // though precision + recall > 0.
            if (denominator > 0.0) {
                fMeasure = (1.0 + betaSquared) * ((precision * recall) / denominator);
            }
            result.put(label, fMeasure);
        }
        return result;
    }

    /**
     * Returns recall for every gold label, in sorted label order.
     *
     * @return label -> recall
     */
    public Map<String, Double> getRecallForLabels() {
        Map<String, Double> result = new LinkedHashMap<>();
        for (String label : allGoldLabels) {
            result.put(label, getRecallForLabel(label));
        }
        return result;
    }

    /**
     * Returns recall for one label: diagonal count / row sum.
     *
     * @param str label
     * @return recall, or 0 when the label has no gold observations
     */
    public double getRecallForLabel(String str) {
        int truePositives = getCount(str, str);
        int goldAsLabel = getRowSum(str);
        if (goldAsLabel > 0) {
            return (double) truePositives / goldAsLabel;
        }
        return 0.0;
    }

    /** Half-width of the 95% normal-approximation interval for accuracy: 1.96 * sqrt(p(1-p)/n). */
    public double getConfidence95Accuracy() {
        return 1.96d * Math.sqrt((getAccuracy() * (1.0d - getAccuracy())) / total);
    }

    /** Half-width of the 90% normal-approximation interval for accuracy: 1.645 * sqrt(p(1-p)/n). */
    public double getConfidence90Accuracy() {
        return 1.645d * Math.sqrt((getAccuracy() * (1.0d - getAccuracy())) / total);
    }

    /** Lower bound of the 90% interval for accuracy. */
    public double getConfidence90AccuracyLow() {
        return getAccuracy() - getConfidence90Accuracy();
    }

    /** Upper bound of the 90% interval for accuracy. */
    public double getConfidence90AccuracyHigh() {
        return getAccuracy() + getConfidence90Accuracy();
    }

    /** Lower bound of the 95% interval for accuracy. */
    public double getConfidence95AccuracyLow() {
        return getAccuracy() - getConfidence95Accuracy();
    }

    /** Upper bound of the 95% interval for accuracy. */
    public double getConfidence95AccuracyHigh() {
        return getAccuracy() + getConfidence95Accuracy();
    }

    /** Half-width of the 95% normal-approximation interval for macro F1. */
    public double getConfidence95MacroFM() {
        return 1.96d * Math.sqrt((getMacroFMeasure() * (1.0d - getMacroFMeasure())) / total);
    }

    /** Half-width of the 90% normal-approximation interval for macro F1 (z = 1.645). */
    public double getConfidence90MacroFM() {
        // 1.645 is the two-sided 90% z-value, matching getConfidence90Accuracy().
        return 1.645d * Math.sqrt((getMacroFMeasure() * (1.0d - getMacroFMeasure())) / total);
    }

    /** Lower bound of the 95% interval for macro F1. */
    public double getConfidence95MacroFMLow() {
        return getMacroFMeasure() - getConfidence95MacroFM();
    }

    /** Upper bound of the 95% interval for macro F1. */
    public double getConfidence95MacroFMHigh() {
        return getMacroFMeasure() + getConfidence95MacroFM();
    }

    /**
     * Returns Cohen's kappa, (p0 - pe) / (1 - pe), where p0 is accuracy and pe is the chance
     * agreement computed from the row and column marginals of the gold labels.
     *
     * @return Cohen's kappa
     */
    public double getCohensKappa() {
        double observed = getAccuracy();
        double n = getTotalSum();
        double expected = 0.0;
        for (String label : allGoldLabels) {
            // Promote to double before dividing: int arithmetic would truncate each term.
            expected += ((double) getRowSum(label) * getColSum(label)) / n;
        }
        expected /= n;
        return (observed - expected) / (1.0d - expected);
    }

    /**
     * Returns the column labels: all gold labels (sorted), then labels that were only ever
     * predicted (sorted).
     */
    private List<String> getColumnLabels() {
        List<String> columns = new ArrayList<>(allGoldLabels);
        for (String label : allPredictedLabels) {
            if (!allGoldLabels.contains(label)) {
                columns.add(label);
            }
        }
        return columns;
    }

    /**
     * Builds the matrix as a table of strings: a header row (corner cell plus column labels) and
     * one row per gold label. Does not modify this matrix.
     *
     * @param probabilistic if true, cells hold count / row sum formatted with the configured
     *                      precision (0 for rows summing to 0); otherwise raw integer counts
     */
    private List<List<String>> buildTable(boolean probabilistic) {
        List<String> columns = getColumnLabels();
        List<List<String>> table = new ArrayList<>();

        List<String> header = new ArrayList<>();
        header.add(CORNER_LABEL);
        header.addAll(columns);
        table.add(header);

        for (String gold : allGoldLabels) {
            List<String> row = new ArrayList<>();
            row.add(gold);
            int rowSum = getRowSum(gold);
            for (String predicted : columns) {
                int count = getCount(gold, predicted);
                if (probabilistic) {
                    double ratio = rowSum != 0 ? count / (double) rowSum : 0.0;
                    // Locale.ENGLISH keeps '.' as the decimal separator, like the other formatters.
                    row.add(String.format(Locale.ENGLISH, getFormat(), ratio));
                } else {
                    row.add(Integer.toString(count));
                }
            }
            table.add(row);
        }
        return table;
    }

    /**
     * Renders a table with every cell right-aligned to the width of the widest cell plus one.
     *
     * @param list rows of cells
     * @return one line per row, each terminated by {@code "\n"}
     */
    protected String tableToString(List<List<String>> list) {
        int width = 0;
        for (List<String> row : list) {
            for (String cell : row) {
                width = Math.max(width, cell.length());
            }
        }
        String cellFormat = "%" + (width + 1) + "s";

        StringBuilder sb = new StringBuilder();
        for (List<String> row : list) {
            for (String cell : row) {
                sb.append(String.format(cellFormat, cell));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Returns the matrix of raw counts as an aligned plain-text table. */
    @Override
    public String toString() {
        return tableToString(buildTable(false));
    }

    /** Escapes characters that are special in LaTeX text mode. */
    private static String escapeLatex(String text) {
        StringBuilder sb = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            char ch = text.charAt(i);
            switch (ch) {
                case '\\':
                    sb.append("\\textbackslash{}");
                    break;
                case '~':
                    sb.append("\\textasciitilde{}");
                    break;
                case '^':
                    sb.append("\\textasciicircum{}");
                    break;
                case '&':
                case '%':
                case '$':
                case '#':
                case '_':
                case '{':
                case '}':
                    sb.append('\\').append(ch);
                    break;
                default:
                    sb.append(ch);
            }
        }
        return sb.toString();
    }

    /**
     * Returns the matrix of raw counts as LaTeX tabular rows (cells joined by {@code &}, rows
     * ended by {@code \\}). Header row and label column are bold; cell text is LaTeX-escaped.
     */
    public String toStringLatex() {
        List<List<String>> table = buildTable(false);
        StringBuilder sb = new StringBuilder();
        for (int r = 0; r < table.size(); r++) {
            List<String> row = table.get(r);
            for (int c = 0; c < row.size(); c++) {
                String cell = escapeLatex(row.get(c));
                if ((r == 0 || c == 0) && !cell.isEmpty()) {
                    sb.append("\\textbf{").append(cell).append("} ");
                } else {
                    sb.append(cell).append(" ");
                }
                if (c < row.size() - 1) {
                    sb.append("& ");
                }
            }
            sb.append("\\\\\n");
        }
        return sb.toString();
    }

    /** Returns a one-line summary of macro F1, its 95% interval half-width and micro F1. */
    public String printNiceResults() {
        return "Macro F-measure: " + String.format(Locale.ENGLISH, getFormat(), getMacroFMeasure())
                + ", (CI at .95: " + String.format(Locale.ENGLISH, getFormat(), getConfidence95MacroFM())
                + "), micro F-measure (acc): " + String.format(Locale.ENGLISH, getFormat(), getMicroFMeasure());
    }

    /** Returns precision/recall/F1 for every gold label on one line. */
    public String printLabelPrecRecFm() {
        Map<String, Double> precisions = getPrecisionForLabels();
        Map<String, Double> recalls = getRecallForLabels();
        Map<String, Double> fMeasures = getFMeasureForLabels();
        StringBuilder sb = new StringBuilder("P/R/Fm: ");
        for (Map.Entry<String, Double> entry : precisions.entrySet()) {
            String label = entry.getKey();
            sb.append(label);
            sb.append("=");
            sb.append(String.format(Locale.ENGLISH, getFormat(), entry.getValue()));
            sb.append("/");
            sb.append(String.format(Locale.ENGLISH, getFormat(), recalls.get(label)));
            sb.append("/");
            sb.append(String.format(Locale.ENGLISH, getFormat(), fMeasures.get(label)));
            sb.append(" ");
        }
        return sb.toString();
    }

    /** @return mean precision over gold labels, or NaN if there are none */
    public double getAvgPrecision() {
        return meanOfValues(getPrecisionForLabels());
    }

    /** @return mean recall over gold labels, or NaN if there are none */
    public double getAvgRecall() {
        return meanOfValues(getRecallForLabels());
    }

    /**
     * Returns a new matrix whose cells are the cell-wise sums of the given matrices.
     *
     * @param confusionMatrixArr matrices to add
     * @return the summed matrix (default formatting settings)
     */
    public static ConfusionMatrix createCumulativeMatrix(ConfusionMatrix... confusionMatrixArr) {
        ConfusionMatrix result = new ConfusionMatrix();
        for (ConfusionMatrix source : confusionMatrixArr) {
            for (Map.Entry<String, Map<String, Integer>> row : source.map.entrySet()) {
                for (Map.Entry<String, Integer> cell : row.getValue().entrySet()) {
                    result.increaseValue(row.getKey(), cell.getKey(), cell.getValue());
                }
            }
        }
        return result;
    }

    /** Returns this + transpose - diagonal, i.e. a matrix where cell (a, b) equals cell (b, a). */
    public ConfusionMatrix getSymmetricConfusionMatrix() {
        return createCumulativeMatrix(this, getTransposedMatrix(), getNegativeUnitMatrix());
    }

    /** Returns a new matrix with gold and predicted roles swapped. */
    public ConfusionMatrix getTransposedMatrix() {
        ConfusionMatrix result = new ConfusionMatrix();
        for (Map.Entry<String, Map<String, Integer>> row : map.entrySet()) {
            for (Map.Entry<String, Integer> cell : row.getValue().entrySet()) {
                result.increaseValue(cell.getKey(), row.getKey(), cell.getValue());
            }
        }
        return result;
    }

    /** Returns a matrix holding only the negated diagonal of this matrix. */
    private ConfusionMatrix getNegativeUnitMatrix() {
        ConfusionMatrix result = new ConfusionMatrix();
        for (Map.Entry<String, Map<String, Integer>> row : map.entrySet()) {
            String label = row.getKey();
            Integer diagonal = row.getValue().get(label);
            if (diagonal != null) {
                result.increaseValue(label, label, -diagonal);
            }
        }
        return result;
    }

    /** Splits a line on whitespace, dropping empty tokens. */
    private static List<String> tokenize(String line) {
        List<String> tokens = new ArrayList<>();
        for (String token : line.split("\\s+")) {
            if (!token.isEmpty()) {
                tokens.add(token);
            }
        }
        return tokens;
    }

    /**
     * Parses a whitespace-separated table. The first line lists the predicted (column) labels,
     * optionally preceded by the corner cell written by {@link #toString()}; each following
     * non-blank line is a gold label followed by one integer count per column.
     *
     * @param str table text
     * @return the parsed matrix
     * @throws IllegalArgumentException if the text is malformed
     */
    public static ConfusionMatrix parseFromText(String str) throws IllegalArgumentException {
        try {
            String[] lines = str.split("\n");
            List<String> columnLabels = tokenize(lines[0]);
            // Accept the header produced by toString(), whose first cell is not a label.
            if (!columnLabels.isEmpty() && CORNER_LABEL.equals(columnLabels.get(0))) {
                columnLabels.remove(0);
            }

            ConfusionMatrix result = new ConfusionMatrix();
            for (int i = 1; i < lines.length; i++) {
                List<String> cells = tokenize(lines[i]);
                if (cells.isEmpty()) {
                    continue; // tolerate blank lines
                }
                String gold = cells.get(0);
                for (int j = 1; j < cells.size(); j++) {
                    result.increaseValue(gold, columnLabels.get(j - 1), Integer.parseInt(cells.get(j)));
                }
            }
            return result;
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Wrong input format", e);
        }
    }

    /**
     * Returns, for every gold label, its row sum and column sum with their percentages of the
     * total, followed by the total.
     */
    public String printClassDistributionGold() {
        StringBuilder sb = new StringBuilder("Gold data distribution\t\t");
        sb.append("Predicted data distribution\n");
        double n = getTotalSum();
        for (String label : allGoldLabels) {
            int rowSum = getRowSum(label);
            int colSum = getColSum(label);
            // Divide in double: int division would print only 0.0% or 100.0%.
            sb.append(String.format(Locale.ENGLISH, "%s\t%d\t%.1f", label, rowSum, (rowSum / n) * 100.0d));
            sb.append("%\t");
            sb.append(String.format(Locale.ENGLISH, "%d\t%.1f", colSum, (colSum / n) * 100.0d));
            sb.append("%\n");
        }
        sb.append(String.format(Locale.ENGLISH, "Sum\t%d%n", getTotalSum()));
        return sb.toString().trim();
    }

    /**
     * Returns the symmetric matrix (see {@link #getSymmetricConfusionMatrix()}) with each row
     * normalised to proportions, formatted with this matrix's decimal-place setting.
     */
    public String toStringProbabilistic() {
        ConfusionMatrix symmetric = getSymmetricConfusionMatrix();
        // The cumulative matrix is fresh, so carry over the caller's formatting precision.
        symmetric.numberOfDecimalPlaces = numberOfDecimalPlaces;
        return tableToString(symmetric.buildTable(true));
    }
}
