/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.iterative.aggregators;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Integration tests for iteration aggregators.
 *
 * <p>Each test runs a connected-components computation over a small fixed graph with two
 * connected groups, {1, 2, 3, 4, 5, 6} and {7, 8, 9}. In every superstep a vertex adopts the
 * smallest component id among itself and its neighbours. The tests check:
 *
 * <ul>
 *   <li>a {@link LongSumAggregator}-backed convergence criterion in a bulk iteration,
 *   <li>the same criterion in a delta iteration,
 *   <li>a custom aggregator subclass carrying a parameter, whose per-superstep aggregates are
 *       read back through {@code getPreviousIterationAggregate}.
 * </ul>
 */
@RunWith(value = Parameterized.class)
public class AggregatorConvergenceITCase extends MultipleProgramsTestBase {

    /** Initial solution set: every vertex starts with its own id as its component id. */
    final List<Tuple2<Long, Long>> verticesInput =
            Arrays.asList(
                    new Tuple2<>(1L, 1L),
                    new Tuple2<>(2L, 2L),
                    new Tuple2<>(3L, 3L),
                    new Tuple2<>(4L, 4L),
                    new Tuple2<>(5L, 5L),
                    new Tuple2<>(6L, 6L),
                    new Tuple2<>(7L, 7L),
                    new Tuple2<>(8L, 8L),
                    new Tuple2<>(9L, 9L));

    /** Edges as (source, target) pairs; every edge is listed in both directions. */
    final List<Tuple2<Long, Long>> edgesInput =
            Arrays.asList(
                    new Tuple2<>(1L, 2L),
                    new Tuple2<>(1L, 3L),
                    new Tuple2<>(2L, 3L),
                    new Tuple2<>(2L, 4L),
                    new Tuple2<>(2L, 1L),
                    new Tuple2<>(3L, 1L),
                    new Tuple2<>(3L, 2L),
                    new Tuple2<>(4L, 2L),
                    new Tuple2<>(4L, 6L),
                    new Tuple2<>(5L, 6L),
                    new Tuple2<>(6L, 4L),
                    new Tuple2<>(6L, 5L),
                    new Tuple2<>(7L, 8L),
                    new Tuple2<>(7L, 9L),
                    new Tuple2<>(8L, 7L),
                    new Tuple2<>(8L, 9L),
                    new Tuple2<>(9L, 7L),
                    new Tuple2<>(9L, 8L));

    /**
     * Expected output of the two convergence tests.
     *
     * <p>Vertex 5 still carries component 2. The convergence criterion ends the iteration
     * before the minimum id reaches it; compare the fully propagated result asserted in
     * {@link #testParameterizableAggregator()}.
     */
    final List<Tuple2<Long, Long>> expectedResult =
            Arrays.asList(
                    new Tuple2<>(1L, 1L),
                    new Tuple2<>(2L, 1L),
                    new Tuple2<>(3L, 1L),
                    new Tuple2<>(4L, 1L),
                    new Tuple2<>(5L, 2L),
                    new Tuple2<>(6L, 1L),
                    new Tuple2<>(7L, 7L),
                    new Tuple2<>(8L, 7L),
                    new Tuple2<>(9L, 7L));

    public AggregatorConvergenceITCase(MultipleProgramsTestBase.TestExecutionMode mode) {
        super(mode);
    }

    /**
     * Bulk iteration (at most 10 supersteps) that stops once fewer than {@code
     * convergenceThreshold} vertices changed their component id in a superstep.
     */
    @Test
    public void testConnectedComponentsWithParametrizableConvergence() throws Exception {
        // name of the aggregator that counts vertices whose component id changed
        final String updatedElements = "updated.elements.aggr";
        // the iteration converges once fewer than this many vertices changed in a superstep
        final long convergenceThreshold = 3L;

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(this.verticesInput);
        DataSet<Tuple2<Long, Long>> edges = env.fromCollection(this.edgesInput);

        IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(10);
        iteration.registerAggregationConvergenceCriterion(
                updatedElements,
                new LongSumAggregator(),
                new UpdatedElementsConvergenceCriterion(convergenceThreshold));

        DataSet<Tuple2<Long, Long>> verticesWithNewComponents =
                iteration
                        .join(edges)
                        .where(0)
                        .equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0)
                        .min(1);

        DataSet<Tuple2<Long, Long>> updatedComponentId =
                verticesWithNewComponents
                        .join(iteration)
                        .where(0)
                        .equalTo(0)
                        .flatMap(new MinimumIdFilter(updatedElements));

        List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        Assert.assertEquals(this.expectedResult, result);
    }

    /**
     * Same as {@link #testConnectedComponentsWithParametrizableConvergence()}, but the
     * convergence criterion is registered on a delta iteration. The vertex set is both the
     * initial solution set and the initial workset.
     */
    @Test
    public void testDeltaConnectedComponentsWithParametrizableConvergence() throws Exception {
        // name of the aggregator that counts vertices whose component id changed
        final String updatedElements = "updated.elements.aggr";
        // the iteration converges once fewer than this many vertices changed in a superstep
        final long convergenceThreshold = 3L;

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(this.verticesInput);
        DataSet<Tuple2<Long, Long>> edges = env.fromCollection(this.edgesInput);

        DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
                initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0);
        iteration.registerAggregationConvergenceCriterion(
                updatedElements,
                new LongSumAggregator(),
                new UpdatedElementsConvergenceCriterion(convergenceThreshold));

        DataSet<Tuple2<Long, Long>> verticesWithNewComponents =
                iteration
                        .getWorkset()
                        .join(edges)
                        .where(0)
                        .equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0)
                        .min(1);

        DataSet<Tuple2<Long, Long>> updatedComponentId =
                verticesWithNewComponents
                        .join(iteration.getSolutionSet())
                        .where(0)
                        .equalTo(0)
                        .flatMap(new MinimumIdFilter(updatedElements));

        // the emitted records are both the solution-set delta and the next workset
        List<Tuple2<Long, Long>> result =
                iteration.closeWith(updatedComponentId, updatedComponentId).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        Assert.assertEquals(this.expectedResult, result);
    }

    /**
     * Runs {@code maxIterations} supersteps without a convergence criterion and registers a
     * {@link LongSumAggregatorWithParameter} that counts vertices belonging to component
     * {@code componentId}.
     *
     * <p>The per-superstep totals are recorded by {@link MinimumIdFilterCounting} and checked
     * at the end.
     */
    @Test
    public void testParameterizableAggregator() throws Exception {
        final int maxIterations = 5;
        final String aggregatorName = "elements.in.component.aggregator";
        final long componentId = 1L;

        // The recorded aggregates live in a static array that is shared across the
        // parameterized runs. Clear it so values left by an earlier run cannot satisfy the
        // assertions below.
        Arrays.fill(MinimumIdFilterCounting.aggrValues, 0L);

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(this.verticesInput);
        DataSet<Tuple2<Long, Long>> edges = env.fromCollection(this.edgesInput);

        IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(maxIterations);
        iteration.registerAggregator(
                aggregatorName, new LongSumAggregatorWithParameter(componentId));

        DataSet<Tuple2<Long, Long>> verticesWithNewComponents =
                iteration
                        .join(edges)
                        .where(0)
                        .equalTo(0)
                        .with(new NeighborWithComponentIDJoin())
                        .groupBy(0)
                        .min(1);

        DataSet<Tuple2<Long, Long>> updatedComponentId =
                verticesWithNewComponents
                        .join(iteration)
                        .where(0)
                        .equalTo(0)
                        .flatMap(new MinimumIdFilterCounting(aggregatorName));

        List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect();
        Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>());

        // without a convergence criterion the minimum id reaches every vertex of each group
        List<Tuple2<Long, Long>> expectedFullPropagation =
                Arrays.asList(
                        new Tuple2<>(1L, 1L),
                        new Tuple2<>(2L, 1L),
                        new Tuple2<>(3L, 1L),
                        new Tuple2<>(4L, 1L),
                        new Tuple2<>(5L, 1L),
                        new Tuple2<>(6L, 1L),
                        new Tuple2<>(7L, 7L),
                        new Tuple2<>(8L, 7L),
                        new Tuple2<>(9L, 7L));
        Assert.assertEquals(expectedFullPropagation, result);

        // aggrValues[i] holds the number of vertices in component 1 after superstep i + 1
        long[] aggrValues = MinimumIdFilterCounting.aggrValues;
        Assert.assertEquals(3L, aggrValues[0]);
        Assert.assertEquals(4L, aggrValues[1]);
        Assert.assertEquals(5L, aggrValues[2]);
        Assert.assertEquals(6L, aggrValues[3]);
    }

    /** A {@link LongSumAggregator} that additionally carries the component id to count. */
    private static final class LongSumAggregatorWithParameter extends LongSumAggregator {
        private static final long serialVersionUID = 1L;

        private final long componentId;

        public LongSumAggregatorWithParameter(long compId) {
            this.componentId = compId;
        }

        /** Returns the component id whose members this aggregator is meant to count. */
        public long getComponentId() {
            return this.componentId;
        }
    }

    /**
     * Declares convergence when the aggregated number of updated elements of a superstep is
     * strictly below {@code threshold}.
     */
    private static final class UpdatedElementsConvergenceCriterion
            implements ConvergenceCriterion<LongValue> {
        private static final long serialVersionUID = 1L;

        private final long threshold;

        public UpdatedElementsConvergenceCriterion(long uThreshold) {
            this.threshold = uThreshold;
        }

        @Override
        public boolean isConverged(int iteration, LongValue value) {
            return value.getValue() < this.threshold;
        }
    }

    /**
     * Emits, for each (candidate, current) pair of the same vertex, the record with the smaller
     * component id.
     *
     * <p>Every emitted record whose component id equals the aggregator's {@code componentId}
     * adds 1 to the aggregator. Subtask 0 copies the previous superstep's aggregate into
     * {@link #aggrValues} at index {@code superstep - 2}.
     */
    private static final class MinimumIdFilterCounting
            extends RichFlatMapFunction<
                    Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        /**
         * Aggregates of previous supersteps, written only by subtask 0. The array is static so
         * the test method can read it after execution; the test clears it before each run.
         * Supersteps 2..5 write indices 0..3.
         */
        private static final long[] aggrValues = new long[5];

        private final String aggName;
        private LongSumAggregatorWithParameter aggr;

        public MinimumIdFilterCounting(String aggName) {
            this.aggName = aggName;
        }

        @Override
        public void open(OpenContext openContext) {
            int superstep = getIterationRuntimeContext().getSuperstepNumber();
            this.aggr = getIterationRuntimeContext().getIterationAggregator(this.aggName);

            // the first superstep has no previous aggregate to record
            if (superstep > 1
                    && getIterationRuntimeContext().getTaskInfo().getIndexOfThisSubtask() == 0) {
                LongValue previous =
                        getIterationRuntimeContext().getPreviousIterationAggregate(this.aggName);
                aggrValues[superstep - 2] = previous.getValue();
            }
        }

        @Override
        public void flatMap(
                Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
                Collector<Tuple2<Long, Long>> out) {
            Tuple2<Long, Long> candidate = vertexWithNewAndOldId.f0;
            Tuple2<Long, Long> current = vertexWithNewAndOldId.f1;
            // ties keep the current record, as in the strict '<' comparison
            Tuple2<Long, Long> winner = candidate.f1 < current.f1 ? candidate : current;
            // read the id before handing the record downstream
            long winnerComponent = winner.f1;

            out.collect(winner);
            if (winnerComponent == this.aggr.getComponentId()) {
                this.aggr.aggregate(1L);
            }
        }
    }

    /**
     * Emits, for each (candidate, current) pair of the same vertex, the record with the smaller
     * component id. The aggregator is incremented by 1 whenever the candidate wins, i.e. the
     * vertex's component id changed.
     */
    private static final class MinimumIdFilter
            extends RichFlatMapFunction<
                    Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        private final String aggName;
        private LongSumAggregator aggr;

        public MinimumIdFilter(String aggName) {
            this.aggName = aggName;
        }

        @Override
        public void open(OpenContext openContext) {
            this.aggr = getIterationRuntimeContext().getIterationAggregator(this.aggName);
        }

        @Override
        public void flatMap(
                Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
                Collector<Tuple2<Long, Long>> out) {
            Tuple2<Long, Long> candidate = vertexWithNewAndOldId.f0;
            Tuple2<Long, Long> current = vertexWithNewAndOldId.f1;

            if (candidate.f1 < current.f1) {
                out.collect(candidate);
                this.aggr.aggregate(1L);
            } else {
                out.collect(current);
            }
        }
    }

    /**
     * Joins a (vertex, component) record with an edge (source, target) and returns (target,
     * component). In other words, the vertex offers its component id to the edge's target.
     * The input vertex record is reused as the output record.
     */
    private static final class NeighborWithComponentIDJoin
            extends RichJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithCompId, Tuple2<Long, Long> edge) {
            vertexWithCompId.f0 = edge.f1;
            return vertexWithCompId;
        }
    }
}

