/*
 * Decompiled with CFR 0.152.
 */
package org.lenskit.mf.funksvd;

import org.lenskit.data.ratings.PreferenceDomain;
import org.lenskit.mf.funksvd.FunkSVDUpdateRule;

/**
 * Computes per-rating stochastic-gradient updates for a single FunkSVD feature.
 *
 * <p>Typical use: call {@link #prepare} with a rating and the current user/item feature values,
 * then read {@link #getError()}, {@link #getUserFeatureUpdate()} and
 * {@link #getItemFeatureUpdate()}. Every {@code prepare} call also accumulates the squared
 * residual, which {@link #getRMSE()} reports until {@link #resetStatistics()} clears it.
 *
 * <p>Instances hold mutable state between calls and perform no synchronization.
 */
public final class FunkSVDUpdater {
    /** Supplies the learning rate, the regularization term and an optional clamping domain. */
    private final FunkSVDUpdateRule rule;

    /** Residual (rating minus prediction) from the most recent {@link #prepare} call. */
    private double lastError;
    /** User feature value captured by the most recent {@link #prepare} call. */
    private double userValue;
    /** Item feature value captured by the most recent {@link #prepare} call. */
    private double itemValue;

    /** Sum of squared residuals since the last {@link #resetStatistics()}. */
    private double sumSquaredError;
    /** Number of {@link #prepare} calls since the last {@link #resetStatistics()}. */
    private int updateCount;

    FunkSVDUpdater(FunkSVDUpdateRule rule) {
        this.rule = rule;
    }

    /**
     * Clears the accumulated squared error and update count.
     */
    public void resetStatistics() {
        updateCount = 0;
        sumSquaredError = 0.0;
    }

    /**
     * Returns how many times {@link #prepare} has been called since the last reset.
     *
     * @return the update count
     */
    public int getUpdateCount() {
        return updateCount;
    }

    /**
     * Returns the root-mean-square residual over all {@link #prepare} calls since the last reset.
     *
     * @return the RMSE, or {@link Double#NaN} if no update has been prepared yet
     */
    public double getRMSE() {
        return updateCount > 0 ? Math.sqrt(sumSquaredError / updateCount) : Double.NaN;
    }

    /**
     * Records the state for one rating so the update getters can be queried.
     *
     * <p>The prediction is {@code estimate + uv * iv}, clamped to the rule's preference domain
     * when one is configured, and then {@code trail} is added (after clamping). The residual
     * {@code rating - prediction} is stored and its square is added to the running statistics.
     *
     * @param feature index of the feature being trained; not used by this implementation
     * @param rating the observed rating
     * @param estimate the baseline estimate the feature product is added to
     * @param uv the current user feature value
     * @param iv the current item feature value
     * @param trail an extra term added to the prediction after clamping (presumably the
     *              contribution of not-yet-trained features — confirm with callers)
     */
    public void prepare(int feature, double rating, double estimate,
                        double uv, double iv, double trail) {
        // Only the estimate-plus-interaction part is clamped; trail is added afterwards.
        double prediction = clampToDomain(estimate + uv * iv) + trail;
        double residual = rating - prediction;

        lastError = residual;
        userValue = uv;
        itemValue = iv;

        updateCount += 1;
        sumSquaredError += residual * residual;
    }

    /**
     * Returns the residual computed by the most recent {@link #prepare} call.
     *
     * @return rating minus prediction
     */
    public double getError() {
        return lastError;
    }

    /**
     * Returns the amount to add to the user feature value for the prepared rating:
     * {@code learningRate * (error * itemValue - regularization * userValue)}.
     *
     * @return the user feature update
     */
    public double getUserFeatureUpdate() {
        return scaledStep(itemValue, userValue);
    }

    /**
     * Returns the amount to add to the item feature value for the prepared rating:
     * {@code learningRate * (error * userValue - regularization * itemValue)}.
     *
     * @return the item feature update
     */
    public double getItemFeatureUpdate() {
        return scaledStep(userValue, itemValue);
    }

    /**
     * Clamps a value to the rule's preference domain, or returns it unchanged when the rule
     * has no domain.
     */
    private double clampToDomain(double value) {
        PreferenceDomain domain = rule.getDomain();
        return domain == null ? value : domain.clampValue(value);
    }

    /**
     * Computes {@code (lastError * partner - regularization * own) * learningRate}.
     * The regularization is read before the learning rate, matching the original call order.
     */
    private double scaledStep(double partner, double own) {
        double lambda = rule.getTrainingRegularization();
        double step = lastError * partner - lambda * own;
        return step * rule.getLearningRate();
    }
}

