/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.clustering;

import smile.clustering.XMeans;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.util.DoubleArrays;

/**
 * Thin wrapper that fits a Smile {@link XMeans} clustering model to a set of
 * numeric columns and exposes the fitted model's results.
 *
 * <p>The model is built once, in the constructor; every accessor simply
 * delegates to the underlying Smile model.
 */
public class Xmeans {
    /** The fitted Smile model that all accessors delegate to. */
    private final XMeans model;

    /** Private snapshot of the columns the model was built from; used to name centroid columns. */
    private final NumericColumn[] inputColumns;

    /**
     * Builds an X-Means model from the given columns.
     *
     * @param maxK    passed straight through to the Smile {@code XMeans} constructor
     *                (presumably the upper bound on the number of clusters; see the Smile docs)
     * @param columns the numeric columns to cluster; converted to a {@code double[][]} via
     *                {@link DoubleArrays#to2dArray}
     */
    public Xmeans(int maxK, NumericColumn ... columns) {
        double[][] data = DoubleArrays.to2dArray(columns);
        this.model = new XMeans(data, maxK);
        // Defensive copy: a caller may pass an explicit array and mutate it later, which
        // would otherwise change the column names reported by labeledCentroids().
        this.inputColumns = columns.clone();
    }

    /**
     * Delegates to the underlying model's {@code predict}.
     *
     * @param x the observation to classify
     * @return the value returned by the model (presumably the cluster label for {@code x})
     */
    public int predict(double[] x) {
        return this.model.predict(x);
    }

    /**
     * @return the centroid array exactly as returned by the underlying model; it is not copied
     *         here, so callers should treat it as read-only
     */
    public double[][] centroids() {
        return this.model.centroids();
    }

    /**
     * @return the distortion value reported by the underlying model
     */
    public double distortion() {
        return this.model.distortion();
    }

    /**
     * @return the number of clusters reported by the underlying model
     */
    public int getClusterCount() {
        return this.model.getNumClusters();
    }

    /**
     * @return the cluster label array exactly as returned by the underlying model (not copied)
     */
    public int[] getClusterLabels() {
        return this.model.getClusterLabel();
    }

    /**
     * @return the cluster size array exactly as returned by the underlying model (not copied)
     */
    public int[] getClusterSizes() {
        return this.model.getClusterSize();
    }

    /**
     * Builds a table named {@code "Centroids"} describing the model's centroids.
     *
     * <p>The first column, {@code "Cluster"}, holds the centroid's index as a string. It is
     * followed by one {@link FloatColumn} per input column, named after that input column,
     * holding the centroid's coordinates narrowed to {@code float}.
     *
     * @return a new table with one row per centroid
     */
    public Table labeledCentroids() {
        Table table = Table.create("Centroids");
        CategoryColumn labelColumn = new CategoryColumn("Cluster");
        table.addColumn(labelColumn);
        for (NumericColumn inputColumn : this.inputColumns) {
            table.addColumn(new FloatColumn(inputColumn.name()));
        }

        double[][] centroids = this.model.centroids();
        for (int i = 0; i < centroids.length; ++i) {
            labelColumn.appendCell(String.valueOf(i));
            double[] values = centroids[i];
            for (int k = 0; k < values.length; ++k) {
                // Column 0 is the "Cluster" label, so coordinate k lives in column k + 1.
                table.floatColumn(k + 1).append((float)values[k]);
            }
        }
        return table;
    }
}

