/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.standardscaler;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.window.EventTimeSessionWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.common.window.Windows;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerParams;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.WithParams;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

public class OnlineStandardScaler
implements Estimator<OnlineStandardScaler, OnlineStandardScalerModel>,
OnlineStandardScalerParams<OnlineStandardScaler> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public OnlineStandardScaler() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, (WithParams)this);
    }

    public OnlineStandardScalerModel fit(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        Windows windows = this.getWindows();
        boolean isEventTimeBasedTraining = false;
        if (windows instanceof EventTimeTumblingWindows || windows instanceof EventTimeSessionWindows) {
            isEventTimeBasedTraining = true;
        }
        SingleOutputStreamOperator modelData = DataStreamUtils.windowAllAndProcess((DataStream)tEnv.toDataStream(inputs[0]), (Windows)windows, new ComputeModelDataFunction(this.getInputCol(), isEventTimeBasedTraining));
        OnlineStandardScalerModel model = new OnlineStandardScalerModel().setModelData(tEnv.fromDataStream((DataStream)modelData));
        ParamUtils.updateExistingParams((WithParams)model, this.paramMap);
        return model;
    }

    private static StandardScalerModelData buildModelData(long numElements, DenseVector sum, DenseVector squaredSum, long modelVersion, long currentTimeStamp) {
        BLAS.scal((double)(1.0 / (double)numElements), (Vector)sum);
        double[] mean = sum.values;
        double[] std = squaredSum.values;
        if (numElements > 1L) {
            for (int i = 0; i < mean.length; ++i) {
                std[i] = Math.sqrt((squaredSum.values[i] - (double)numElements * mean[i] * mean[i]) / (double)(numElements - 1L));
            }
        } else {
            Arrays.fill(std, 0.0);
        }
        return new StandardScalerModelData(Vectors.dense((double[])mean), Vectors.dense((double[])std), modelVersion, currentTimeStamp);
    }

    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata((Stage)this, (String)path);
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public static OnlineStandardScaler load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (OnlineStandardScaler)ReadWriteUtils.loadStageParam((String)path);
    }

    private static class ComputeModelDataFunction<W extends Window>
    extends ProcessAllWindowFunction<Row, StandardScalerModelData, W> {
        private final String inputCol;
        private final boolean isEventTimeBasedTraining;

        public ComputeModelDataFunction(String inputCol, boolean isEventTimeBasedTraining) {
            this.inputCol = inputCol;
            this.isEventTimeBasedTraining = isEventTimeBasedTraining;
        }

        public void process(ProcessAllWindowFunction.Context context, Iterable<Row> iterable, Collector<StandardScalerModelData> collector) throws Exception {
            ListState sumState = context.globalState().getListState(new ListStateDescriptor("sumState", (TypeInformation)DenseVectorTypeInfo.INSTANCE));
            ListState squaredSumState = context.globalState().getListState(new ListStateDescriptor("squaredSumState", (TypeInformation)DenseVectorTypeInfo.INSTANCE));
            ListState numElementsState = context.globalState().getListState(new ListStateDescriptor("numElementsState", Types.LONG));
            ListState modelVersionState = context.globalState().getListState(new ListStateDescriptor("modelVersionState", Types.LONG));
            DenseVector sum = OperatorStateUtils.getUniqueElement((ListState)sumState, (String)"sumState").orElse(null);
            DenseVector squaredSum = OperatorStateUtils.getUniqueElement((ListState)squaredSumState, (String)"squaredSumState").orElse(null);
            long numElements = OperatorStateUtils.getUniqueElement((ListState)numElementsState, (String)"numElementsState").orElse(0L);
            long modelVersion = OperatorStateUtils.getUniqueElement((ListState)modelVersionState, (String)"modelVersionState").orElse(0L);
            long numElementsBefore = numElements;
            for (Row element : iterable) {
                Vector inputVec = ((Vector)Objects.requireNonNull(element.getField(this.inputCol))).clone();
                if (numElements == 0L) {
                    sum = new DenseVector(inputVec.size());
                    squaredSum = new DenseVector(inputVec.size());
                }
                BLAS.axpy((double)1.0, (Vector)inputVec, (DenseVector)sum);
                BLAS.hDot((Vector)inputVec, (Vector)inputVec);
                BLAS.axpy((double)1.0, (Vector)inputVec, (DenseVector)squaredSum);
                ++numElements;
            }
            if (numElements - numElementsBefore > 0L) {
                long currentEventTime = this.isEventTimeBasedTraining ? context.window().maxTimestamp() : Long.MAX_VALUE;
                collector.collect((Object)OnlineStandardScaler.buildModelData(numElements, sum.clone(), squaredSum.clone(), modelVersion, currentEventTime));
                sumState.update(Collections.singletonList(sum));
                squaredSumState.update(Collections.singletonList(squaredSum));
                numElementsState.update(Collections.singletonList(numElements));
                modelVersionState.update(Collections.singletonList(++modelVersion));
            }
        }
    }
}

