/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.optimizer.examples;

import java.util.Arrays;
import java.util.Collection;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.GenericDataSourceBase;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.optimizer.plan.NamedChannel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.optimizer.util.OperatorResolver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.util.LocalStrategy;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

public class KMeansSingleStepTest
extends CompilerTestBase {
    // Operator names assigned while building the program in kmeans(); the tests use the same
    // names to look the operators up again in the plan.

    /** Name given to the CSV source that provides the data points. */
    private static final String DATAPOINTS = "Data Points";
    /** Name given to the CSV source that provides the initial centroids. */
    private static final String CENTERS = "Centers";
    /** Name given to the map operator that uses SelectNearestCenter. */
    private static final String MAPPER_NAME = "Find Nearest Centers";
    /** Name given to the group-reduce operator that uses RecomputeClusterCenter. */
    private static final String REDUCER_NAME = "Recompute Center Positions";
    /** Name given to the CSV sink. */
    private static final String SINK = "New Center Positions";
    /**
     * Field list built from index 0; checkPlan() compares it against the keys reported by the
     * combiner and the reducer.
     */
    private final FieldList set0 = new FieldList(0);

    /**
     * Compiles the KMeans plan with explicit statistics attached to both CSV sources and checks
     * the strategies chosen by the optimizer.
     */
    @Test
    public void testCompileKMeansSingleStepWithStats() throws Exception {
        final Plan kMeansPlan = getKMeansPlan();
        kMeansPlan.setExecutionConfig(new ExecutionConfig());

        // Look the two sources up by the names assigned in kmeans().
        final OperatorResolver resolver = getContractResolver(kMeansPlan);
        final GenericDataSourceBase<?, ?> pointsInput =
                (GenericDataSourceBase<?, ?>) resolver.getNode(DATAPOINTS);
        final GenericDataSourceBase<?, ?> centersInput =
                (GenericDataSourceBase<?, ?>) resolver.getNode(CENTERS);

        // Sizes are 100 * 2^30 for the points and 2^20 for the centers; the last argument is 32
        // for both. NOTE(review): presumably bytes and average record width - confirm against
        // CompilerTestBase.setSourceStatistics.
        setSourceStatistics(pointsInput, 100L * 1024 * 1024 * 1024, 32.0f);
        setSourceStatistics(centersInput, 1024L * 1024, 32.0f);

        checkPlan(compileWithStats(kMeansPlan));
    }

    /**
     * Compiles the KMeans plan without attaching any source statistics and checks that the same
     * strategies are chosen.
     */
    @Test
    public void testCompileKMeansSingleStepWithOutStats() throws Exception {
        final Plan kMeansPlan = getKMeansPlan();
        kMeansPlan.setExecutionConfig(new ExecutionConfig());

        final OptimizedPlan optimized = compileNoStats(kMeansPlan);
        checkPlan(optimized);
    }

    /**
     * Asserts the ship, local and driver strategies of the mapper, combiner, reducer and sink of
     * an optimized KMeans plan.
     *
     * @param plan the optimized plan produced from {@link #getKMeansPlan()}
     */
    private void checkPlan(OptimizedPlan plan) {
        final CompilerTestBase.OptimizerPlanNodeResolver nodes = getOptimizerPlanNodeResolver(plan);

        // Resolve every node up front, before any assertion runs.
        final SinkPlanNode sinkNode = (SinkPlanNode) nodes.getNode(SINK);
        final SingleInputPlanNode reduceNode = (SingleInputPlanNode) nodes.getNode(REDUCER_NAME);
        final SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getPredecessor();
        final SingleInputPlanNode mapNode = (SingleInputPlanNode) nodes.getNode(MAPPER_NAME);

        assertMapperStrategies(mapNode);
        assertCombinerStrategies(combineNode);
        assertReducerStrategies(reduceNode);
        assertSinkStrategies(sinkNode);
    }

    /** Mapper: forwarded main input, one broadcast input, no local strategies, MAP driver. */
    private static void assertMapperStrategies(SingleInputPlanNode mapNode) {
        Assert.assertEquals(1, mapNode.getBroadcastInputs().size());
        // Safe to fetch only after the size assertion above has passed.
        final NamedChannel broadcast = (NamedChannel) mapNode.getBroadcastInputs().get(0);

        Assert.assertEquals(ShipStrategyType.FORWARD, mapNode.getInput().getShipStrategy());
        Assert.assertEquals(ShipStrategyType.BROADCAST, broadcast.getShipStrategy());

        Assert.assertEquals(LocalStrategy.NONE, mapNode.getInput().getLocalStrategy());
        Assert.assertEquals(LocalStrategy.NONE, broadcast.getLocalStrategy());

        Assert.assertEquals(DriverStrategy.MAP, mapNode.getDriverStrategy());

        Assert.assertNull(mapNode.getInput().getLocalStrategyKeys());
        Assert.assertNull(mapNode.getInput().getLocalStrategySortOrder());
        Assert.assertNull(broadcast.getLocalStrategyKeys());
        Assert.assertNull(broadcast.getLocalStrategySortOrder());
    }

    /** Combiner: must exist, forwarded input, sorted group combine keyed on {@code set0}. */
    private void assertCombinerStrategies(SingleInputPlanNode combineNode) {
        Assert.assertNotNull(combineNode);

        Assert.assertEquals(ShipStrategyType.FORWARD, combineNode.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, combineNode.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combineNode.getDriverStrategy());

        Assert.assertNull(combineNode.getInput().getLocalStrategyKeys());
        Assert.assertNull(combineNode.getInput().getLocalStrategySortOrder());

        Assert.assertEquals(set0, combineNode.getKeys(0));
        Assert.assertEquals(set0, combineNode.getKeys(1));
    }

    /** Reducer: hash-partitioned input with combining sort, sorted group reduce on {@code set0}. */
    private void assertReducerStrategies(SingleInputPlanNode reduceNode) {
        Assert.assertEquals(
                ShipStrategyType.PARTITION_HASH, reduceNode.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.COMBININGSORT, reduceNode.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reduceNode.getDriverStrategy());

        Assert.assertEquals(set0, reduceNode.getKeys(0));
        Assert.assertEquals(set0, reduceNode.getInput().getLocalStrategyKeys());

        // The sort order used on the input must match the one the driver reports.
        final boolean sameSortOrder =
                Arrays.equals(
                        reduceNode.getInput().getLocalStrategySortOrder(),
                        reduceNode.getSortOrders(0));
        Assert.assertTrue(sameSortOrder);
    }

    /** Sink: forwarded input without a local strategy. */
    private static void assertSinkStrategies(SinkPlanNode sinkNode) {
        Assert.assertEquals(ShipStrategyType.FORWARD, sinkNode.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, sinkNode.getInput().getLocalStrategy());
    }

    /**
     * Builds the KMeans plan used by the tests.
     *
     * <p>{@code IN_FILE} is passed as both the points and the centroids path and {@code OUT_FILE}
     * as the output path; the trailing {@code "20"} is passed along as a fourth argument.
     *
     * @return the plan returned by {@link #kmeans(String[])}
     * @throws Exception if building the plan fails
     */
    public static Plan getKMeansPlan() throws Exception {
        final String[] arguments = {IN_FILE, IN_FILE, OUT_FILE, "20"};
        return kmeans(arguments);
    }

    /**
     * Assembles the single-step KMeans program and returns its plan.
     *
     * <p>The program reads points and centroids from space-delimited CSV files, maps every point
     * to a centroid id with the centroids supplied as the broadcast set {@code "centroids"},
     * groups on field 0, applies {@link RecomputeClusterCenter}, and writes fields 0 and 1 as
     * CSV.
     *
     * @param args {@code args[0]} is the points file, {@code args[1]} the centroids file and
     *     {@code args[2]} the output path; any further elements are not read by this method
     * @return the program plan created under the name {@code "KMeans Example"}
     * @throws Exception if the plan cannot be created
     */
    public static Plan kmeans(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // Two double columns per line -> Point.
        DataSet<Point> points =
                env.readCsvFile(args[0])
                        .fieldDelimiter(" ")
                        .includeFields(true, true)
                        .types(Double.class, Double.class)
                        .name(DATAPOINTS)
                        .map(
                                new MapFunction<Tuple2<Double, Double>, Point>() {
                                    @Override
                                    public Point map(Tuple2<Double, Double> value)
                                            throws Exception {
                                        return new Point(value.f0, value.f1);
                                    }
                                });

        // One int id followed by two double columns per line -> Centroid.
        DataSet<Centroid> centroids =
                env.readCsvFile(args[1])
                        .fieldDelimiter(" ")
                        .includeFields(true, true, true)
                        .types(Integer.class, Double.class, Double.class)
                        .name(CENTERS)
                        .map(
                                new MapFunction<Tuple3<Integer, Double, Double>, Centroid>() {
                                    @Override
                                    public Centroid map(Tuple3<Integer, Double, Double> value)
                                            throws Exception {
                                        return new Centroid(value.f0, value.f1, value.f2);
                                    }
                                });

        // The broadcast set name must match the one SelectNearestCenter reads in open().
        DataSet<Tuple3<Integer, Point, Integer>> newCentroids =
                points.map(new SelectNearestCenter())
                        .name(MAPPER_NAME)
                        .withBroadcastSet(centroids, "centroids");

        DataSet<Tuple3<Integer, Point, Integer>> recomputeClusterCenter =
                newCentroids
                        .groupBy(0)
                        .reduceGroup(new RecomputeClusterCenter())
                        .name(REDUCER_NAME);

        // Only the id and the point are written; the count (field 2) is dropped.
        recomputeClusterCenter.project(0, 1).writeAsCsv(args[2], "\n", " ").name(SINK);

        return env.createProgramPlan("KMeans Example");
    }

    private static final class RecomputeClusterCenter
    implements GroupReduceFunction<Tuple3<Integer, Point, Integer>, Tuple3<Integer, Point, Integer>>,
    GroupCombineFunction<Tuple3<Integer, Point, Integer>, Tuple3<Integer, Point, Integer>> {
        private RecomputeClusterCenter() {
        }

        public void reduce(Iterable<Tuple3<Integer, Point, Integer>> values, Collector<Tuple3<Integer, Point, Integer>> out) throws Exception {
            int id = -1;
            double x = 0.0;
            double y = 0.0;
            int count = 0;
            for (Tuple3<Integer, Point, Integer> value : values) {
                id = (Integer)value.f0;
                x += ((Double)((Point)((Object)value.f1)).f0).doubleValue();
                y += ((Double)((Point)((Object)value.f1)).f1).doubleValue();
                count += ((Integer)value.f2).intValue();
            }
            out.collect((Object)new Tuple3((Object)id, (Object)new Point(x, y), (Object)count));
        }

        public void combine(Iterable<Tuple3<Integer, Point, Integer>> values, Collector<Tuple3<Integer, Point, Integer>> out) throws Exception {
            this.reduce(values, out);
        }
    }

    private static final class SelectNearestCenter
    extends RichMapFunction<Point, Tuple3<Integer, Point, Integer>> {
        private Collection<Centroid> centroids;

        private SelectNearestCenter() {
        }

        public void open(OpenContext openContext) throws Exception {
            this.centroids = this.getRuntimeContext().getBroadcastVariable("centroids");
        }

        public Tuple3<Integer, Point, Integer> map(Point p) throws Exception {
            double minDistance = Double.MAX_VALUE;
            int closestCentroidId = -1;
            for (Centroid centroid : this.centroids) {
                double distance = p.euclideanDistance(centroid);
                if (!(distance < minDistance)) continue;
                minDistance = distance;
                closestCentroidId = (Integer)centroid.f0;
            }
            return new Tuple3((Object)closestCentroidId, (Object)p, (Object)1);
        }
    }

    /** A cluster center: an integer id in {@code f0} paired with its position in {@code f1}. */
    public static class Centroid
            extends Tuple2<Integer, Point> {

        /**
         * Creates a centroid from an id and an existing point.
         *
         * @param id the centroid id, stored in {@code f0}
         * @param p the position, stored in {@code f1} without copying
         */
        public Centroid(int id, Point p) {
            this.f0 = id;
            this.f1 = p;
        }

        /**
         * Creates a centroid from an id and raw coordinates.
         *
         * @param id the centroid id
         * @param x the first coordinate of the position
         * @param y the second coordinate of the position
         */
        public Centroid(int id, double x, double y) {
            this(id, new Point(x, y));
        }
    }

    /** A two-dimensional point whose coordinates are stored in {@code f0} and {@code f1}. */
    public static class Point
            extends Tuple2<Double, Double> {

        /**
         * @param x value stored in {@code f0}
         * @param y value stored in {@code f1}
         */
        public Point(double x, double y) {
            this.f0 = x;
            this.f1 = y;
        }

        /**
         * Adds the coordinates of {@code other} to this point in place.
         *
         * @return this point, after modification
         */
        public Point add(Point other) {
            final double sumX = this.f0 + other.f0;
            this.f0 = sumX;
            final double sumY = this.f1 + other.f1;
            this.f1 = sumY;
            return this;
        }

        /**
         * Divides both coordinates by {@code val} in place, using floating-point division.
         *
         * @return this point, after modification
         */
        public Point div(long val) {
            final double quotientX = this.f0 / (double) val;
            this.f0 = quotientX;
            final double quotientY = this.f1 / (double) val;
            this.f1 = quotientY;
            return this;
        }

        /** Returns the Euclidean distance between this point and {@code other}. */
        public double euclideanDistance(Point other) {
            return distanceTo(other);
        }

        /** Returns the Euclidean distance between this point and the position of {@code other}. */
        public double euclideanDistance(Centroid other) {
            return distanceTo(other.f1);
        }

        /** Square root of the summed squared coordinate differences. */
        private double distanceTo(Point target) {
            final double deltaX = this.f0 - target.f0;
            final double deltaY = this.f1 - target.f1;
            return Math.sqrt(deltaX * deltaX + deltaY * deltaY);
        }
    }
}

