/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.image.model;

import Jama.Matrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.FImage;
import org.openimaj.image.feature.FImage2DoubleFV;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.algorithm.LinearDiscriminantAnalysis;
import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;
import org.openimaj.ml.training.BatchTrainer;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

@Reference(type=ReferenceType.Article, author={"Belhumeur, Peter N.", "Hespanha, Jo\\~{a}o P.", "Kriegman, David J."}, title="Eigenfaces vs. Fisherfaces: Recognition Using Class Specific Linear Projection", year="1997", journal="IEEE Trans. Pattern Anal. Mach. Intell.", pages={"711", "", "720"}, url="http://dx.doi.org/10.1109/34.598228", month="July", number="7", publisher="IEEE Computer Society", volume="19", customData={"issn", "0162-8828", "numpages", "10", "doi", "10.1109/34.598228", "acmid", "261512", "address", "Washington, DC, USA", "keywords", "Appearance-based vision, face recognition, illumination invariance, Fisher's linear discriminant."})
public class FisherImages
implements BatchTrainer<IndependentPair<?, FImage>>,
FeatureExtractor<DoubleFV, FImage>,
ReadWriteableBinary {
    private int numComponents;
    private int width;
    private int height;
    private Matrix basis;
    private double[] mean;

    public FisherImages(int numComponents) {
        this.numComponents = numComponents;
    }

    public void readBinary(DataInput in) throws IOException {
        this.width = in.readInt();
        this.height = in.readInt();
        this.numComponents = in.readInt();
    }

    public byte[] binaryHeader() {
        return "FisI".getBytes();
    }

    public void writeBinary(DataOutput out) throws IOException {
        out.writeInt(this.width);
        out.writeInt(this.height);
        out.writeInt(this.numComponents);
    }

    public void train(Map<?, ? extends List<FImage>> data) {
        ArrayList<IndependentPair> list = new ArrayList<IndependentPair>();
        for (Map.Entry<?, List<FImage>> e : data.entrySet()) {
            for (FImage i : e.getValue()) {
                list.add(IndependentPair.pair(e.getKey(), (Object)i));
            }
        }
        this.train(list);
    }

    public <KEY> void train(GroupedDataset<KEY, ? extends ListDataset<FImage>, FImage> data) {
        ArrayList<IndependentPair> list = new ArrayList<IndependentPair>();
        for (Object e : data.getGroups()) {
            for (FImage i : (ListDataset)data.getInstances(e)) {
                list.add(IndependentPair.pair(e, (Object)i));
            }
        }
        this.train(list);
    }

    public void train(List<? extends IndependentPair<?, FImage>> data) {
        this.width = ((FImage)data.get((int)0).secondObject()).width;
        this.height = ((FImage)data.get((int)0).secondObject()).height;
        HashMap mapData = new HashMap();
        ArrayList<double[]> listData = new ArrayList<double[]>();
        for (IndependentPair<?, FImage> item : data) {
            List<Object> fvs = (List)mapData.get(item.firstObject());
            if (fvs == null) {
                fvs = new ArrayList<double[]>();
                mapData.put(item.firstObject(), fvs);
            }
            double[] dArray = (double[])FImage2DoubleFV.INSTANCE.extractFeature((FImage)((FImage)item.getSecondObject())).values;
            fvs.add(dArray);
            listData.add(dArray);
        }
        ThinSvdPrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(this.numComponents);
        pca.learnBasis(listData);
        ArrayList<double[][]> ldaData = new ArrayList<double[][]>(mapData.size());
        for (Map.Entry entry : mapData.entrySet()) {
            List vecs = (List)entry.getValue();
            double[][] classData = new double[vecs.size()][];
            for (int i = 0; i < classData.length; ++i) {
                classData[i] = pca.project((double[])vecs.get(i));
            }
            ldaData.add(classData);
        }
        LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(this.numComponents);
        lda.learnBasis(ldaData);
        this.basis = pca.getBasis().times(lda.getBasis());
        this.mean = pca.getMean();
    }

    private double[] project(double[] vector) {
        Matrix vec = new Matrix(1, vector.length);
        double[][] vecarr = vec.getArray();
        for (int i = 0; i < vector.length; ++i) {
            vecarr[0][i] = vector[i] - this.mean[i];
        }
        return vec.times(this.basis).getColumnPackedCopy();
    }

    public DoubleFV extractFeature(FImage object) {
        return new DoubleFV(this.project((double[])FImage2DoubleFV.INSTANCE.extractFeature((FImage)object).values));
    }

    public double[] getBasisVector(int index) {
        double[] pc = new double[this.basis.getRowDimension()];
        double[][] data = this.basis.getArray();
        for (int r = 0; r < pc.length; ++r) {
            pc[r] = data[r][index];
        }
        return pc;
    }

    public FImage visualise(int num) {
        return new FImage(ArrayUtils.reshapeFloat((double[])this.getBasisVector(num), (int)this.width, (int)this.height));
    }
}

