/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.maxabsscaler;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModelData;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerParams;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Preconditions;

public class MaxAbsScaler
implements Estimator<MaxAbsScaler, MaxAbsScalerModel>,
MaxAbsScalerParams<MaxAbsScaler> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public MaxAbsScaler() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public MaxAbsScalerModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        String inputCol = this.getInputCol();
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        SingleOutputStreamOperator inputData = tEnv.toDataStream(inputs[0]).map((MapFunction & Serializable)value -> (Vector)value.getField(inputCol), (TypeInformation)VectorTypeInfo.INSTANCE);
        SingleOutputStreamOperator maxAbsValues = inputData.transform("reduceInEachPartition", (TypeInformation)VectorTypeInfo.INSTANCE, (OneInputStreamOperator)new MaxAbsReduceFunctionOperator()).transform("reduceInFinalPartition", (TypeInformation)VectorTypeInfo.INSTANCE, (OneInputStreamOperator)new MaxAbsReduceFunctionOperator()).setParallelism(1);
        SingleOutputStreamOperator modelData = maxAbsValues.map((MapFunction & Serializable)vector -> new MaxAbsScalerModelData((DenseVector)vector));
        MaxAbsScalerModel model = new MaxAbsScalerModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.getParamMap());
        return model;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public static MaxAbsScaler load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (MaxAbsScaler)ReadWriteUtils.loadStageParam((String)path);
    }

    private static class MaxAbsReduceFunctionOperator
    extends AbstractStreamOperator<Vector>
    implements OneInputStreamOperator<Vector, Vector>,
    BoundedOneInput {
        private ListState<DenseVector> maxAbsState;
        private DenseVector maxAbsVector;

        private MaxAbsReduceFunctionOperator() {
        }

        public void endInput() {
            if (this.maxAbsVector != null) {
                this.output.collect((Object)new StreamRecord((Object)this.maxAbsVector));
            }
        }

        public void processElement(StreamRecord<Vector> streamRecord) {
            block3: {
                Vector currentValue;
                block2: {
                    currentValue = (Vector)streamRecord.getValue();
                    this.maxAbsVector = this.maxAbsVector == null ? new DenseVector(currentValue.size()) : this.maxAbsVector;
                    Preconditions.checkArgument((currentValue.size() == this.maxAbsVector.size() ? 1 : 0) != 0, (Object)"The training data should all have same dimensions.");
                    if (!(currentValue instanceof DenseVector)) break block2;
                    double[] values = ((DenseVector)currentValue).values;
                    for (int i = 0; i < currentValue.size(); ++i) {
                        this.maxAbsVector.values[i] = Math.max(this.maxAbsVector.values[i], Math.abs(values[i]));
                    }
                    break block3;
                }
                if (!(currentValue instanceof SparseVector)) break block3;
                int[] indices = ((SparseVector)currentValue).indices;
                double[] values = ((SparseVector)currentValue).values;
                for (int i = 0; i < indices.length; ++i) {
                    this.maxAbsVector.values[indices[i]] = Math.max(this.maxAbsVector.values[indices[i]], Math.abs(values[i]));
                }
            }
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.maxAbsState = context.getOperatorStateStore().getListState(new ListStateDescriptor("maxAbsState", (TypeInformation)DenseVectorTypeInfo.INSTANCE));
            OperatorStateUtils.getUniqueElement(this.maxAbsState, (String)"maxAbsState").ifPresent(x -> {
                this.maxAbsVector = x;
            });
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.maxAbsState.clear();
            if (this.maxAbsVector != null) {
                this.maxAbsState.add((Object)this.maxAbsVector);
            }
        }
    }
}

