001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.explanations.lime;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.CategoricalInfo;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.Model;
024import org.tribuo.MutableDataset;
025import org.tribuo.OutputFactory;
026import org.tribuo.Prediction;
027import org.tribuo.RealInfo;
028import org.tribuo.SparseModel;
029import org.tribuo.SparseTrainer;
030import org.tribuo.VariableIDInfo;
031import org.tribuo.VariableInfo;
032import org.tribuo.WeightedExamples;
033import org.tribuo.classification.Label;
034import org.tribuo.classification.LabelFactory;
035import org.tribuo.classification.explanations.TabularExplainer;
036import org.tribuo.impl.ArrayExample;
037import org.tribuo.interop.ExternalModel;
038import org.tribuo.math.la.SparseVector;
039import org.tribuo.math.la.VectorTuple;
040import org.tribuo.provenance.SimpleDataSourceProvenance;
041import org.tribuo.regression.RegressionFactory;
042import org.tribuo.regression.Regressor;
043import org.tribuo.regression.evaluation.RegressionEvaluation;
044import org.tribuo.regression.evaluation.RegressionEvaluator;
045import org.tribuo.util.Util;
046
047import java.time.OffsetDateTime;
048import java.util.ArrayList;
049import java.util.Iterator;
050import java.util.List;
051import java.util.Map;
052import java.util.Random;
053import java.util.SplittableRandom;
054import java.util.logging.Logger;
055
056/**
057 * LIMEBase merges the lime_base.py and lime_tabular.py implementations, and deals with simple
058 * matrices of numerical or categorical data. If you want a mixture of text, numerical
059 * and categorical data try {@link LIMEColumnar}. For plain text data use {@link LIMEText}.
060 * <p>
061 * See:
062 * <pre>
063 * Ribeiro MT, Singh S, Guestrin C.
064 * "Why should I trust you?: Explaining the predictions of any classifier"
065 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016.
066 * </pre>
067 */
068public class LIMEBase implements TabularExplainer<Regressor> {
069    private static final Logger logger = Logger.getLogger(LIMEBase.class.getName());
070
071    /**
072     * Width of the noise gaussian.
073     */
074    public static final double WIDTH_CONSTANT = 0.75;
075    /**
076     * Delta to consider two distances equal.
077     */
078    public static final double DISTANCE_DELTA = 1e-12;
079
080    protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory();
081    protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true);
082
083    protected final SplittableRandom rng;
084
085    protected final Model<Label> innerModel;
086
087    protected final SparseTrainer<Regressor> explanationTrainer;
088
089    protected final int numSamples;
090
091    protected final long numTrainingExamples;
092
093    protected final double kernelWidth;
094
095    private final ImmutableFeatureMap fMap;
096
097    /**
098     * Constructs a LIME explainer for a model which uses tabular data (i.e., no special treatment for text features).
099     * @param rng The rng to use for sampling.
100     * @param innerModel The model to explain.
101     * @param explanationTrainer The sparse trainer used to explain predictions.
102     * @param numSamples The number of samples to generate for an explanation.
103     */
104    public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) {
105        if (!(explanationTrainer instanceof WeightedExamples)) {
106            throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString());
107        }
108        if (!innerModel.generatesProbabilities()) {
109            throw new IllegalArgumentException("LIME requires the model generate probabilities.");
110        }
111        if (innerModel instanceof ExternalModel) {
112            throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model.");
113        }
114        this.rng = rng;
115        this.innerModel = innerModel;
116        this.explanationTrainer = explanationTrainer;
117        this.numSamples = numSamples;
118        this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations();
119        this.kernelWidth = Math.pow(innerModel.getFeatureIDMap().size() * WIDTH_CONSTANT, 2);
120        this.fMap = innerModel.getFeatureIDMap();
121    }
122
123    @Override
124    public LIMEExplanation explain(Example<Label> example) {
125        return explainWithSamples(example).getA();
126    }
127
128    protected Pair<LIMEExplanation,List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
129        // Predict using the full model, and generate a new example containing that prediction.
130        Prediction<Label> prediction = innerModel.predict(example);
131        Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction),example,1.0f);
132
133        // Sample a dataset.
134        List<Example<Regressor>> sample = sampleData(example);
135
136        // Generate a sparse model on the sampled data.
137        SparseModel<Regressor> model = trainExplainer(labelledExample,sample);
138
139        // Test the sparse model against the predictions of the real model.
140        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
141        predictions.add(model.predict(labelledExample));
142        RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEColumnar sampled data",regressionFactory));
143
144        return new Pair<>(new LIMEExplanation(model,prediction,evaluation),sample);
145    }
146
147    /**
148     * Sample a dataset based on the input example.
149     * <p>
150     * The sampled dataset uses the feature dimensions from the {@link Model}.
151     * <p>
152     * The outputs are the probability values of each class from the underlying Model,
153     * rather than ground truth outputs. The distance is measured using the
154     * {@link LIMEBase#measureDistance} function, transformed through a kernel and used
155     * as the sampled Example's weight.
156     * @param example The example to sample from.
157     * @return A sampled dataset.
158     */
159    private List<Example<Regressor>> sampleData(Example<Label> example) {
160        List<Example<Regressor>> output = new ArrayList<>();
161
162        SparseVector exampleVector = SparseVector.createSparseVector(example,fMap,false);
163
164        Random innerRNG = new Random(rng.nextLong());
165        for (int i = 0; i < numSamples; i++) {
166            // Sample a new Example.
167            Example<Label> sample = samplePoint(innerRNG,fMap,numTrainingExamples,exampleVector);
168
169            //logger.fine("Itr " + i + " sampled " + sample.toString());
170
171            // Label it using the full model.
172            Prediction<Label> samplePrediction = innerModel.predict(sample);
173
174            // Measure the distance between this point and the input, to be used as a weight.
175            double distance = measureDistance(fMap,numTrainingExamples,exampleVector, SparseVector.createSparseVector(sample,fMap,false));
176
177            // Transform distance through the kernel function.
178            distance = kernelDist(distance,kernelWidth);
179
180            // Generate the new sample with the appropriate label and weight.
181            Example<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction),sample,(float)distance);
182            output.add(labelledSample);
183        }
184
185        return output;
186    }
187
188    /**
189     * Samples a single example from the supplied feature map and input vector.
190     * @param rng The rng to use.
191     * @param fMap The feature map describing the domain of the features.
192     * @param numTrainingExamples The number of training examples the fMap has seen.
193     * @param input The input sparse vector to use.
194     * @return An Example sampled from the supplied feature map and input vector.
195     */
196    public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
197        ArrayList<String> names = new ArrayList<>();
198        ArrayList<Double> values = new ArrayList<>();
199
200        for (VariableInfo info : fMap) {
201            int id = ((VariableIDInfo)info).getID();
202            double inputValue = input.get(id);
203
204            if (info instanceof CategoricalInfo) {
205                // This one is tricksy as categorical info essentially implicitly includes a zero.
206                CategoricalInfo catInfo = (CategoricalInfo) info;
207                double sample = catInfo.frequencyBasedSample(rng,numTrainingExamples);
208                // If we didn't sample zero.
209                if (Math.abs(sample) > 1e-10) {
210                    names.add(info.getName());
211                    values.add(sample);
212                }
213            } else if (info instanceof RealInfo) {
214                RealInfo realInfo = (RealInfo) info;
215                // As realInfo is sparse we sample from the mixture distribution,
216                // either 0 or N(inputValue,variance).
217                // This assumes realInfo never observed a zero, which is enforced from v2.1
218                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
219                // If it's not zero do we want to?
220                int count = realInfo.getCount();
221                double threshold = count / ((double)numTrainingExamples);
222                if (rng.nextDouble() < threshold) {
223                    double variance = realInfo.getVariance();
224                    double sample = (rng.nextGaussian() * Math.sqrt(variance)) + inputValue;
225                    names.add(info.getName());
226                    values.add(sample);
227                }
228            } else {
229                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
230            }
231        }
232
233        return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL,names.toArray(new String[0]),Util.toPrimitiveDouble(values));
234    }
235
236    /**
237     * Trains the explanation model using the supplied sampled data and the input example.
238     * <p>
239     * The labels are usually the predicted probabilities from the real model.
240     * @param target The input example to explain.
241     * @param samples The sampled data around the input.
242     * @return An explanation model.
243     */
244    protected SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) {
245        MutableDataset<Regressor> explanationDataset = new MutableDataset<>(new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory);
246        explanationDataset.add(target);
247        explanationDataset.addAll(samples);
248
249        SparseModel<Regressor> explainer = explanationTrainer.train(explanationDataset);
250
251        return explainer;
252    }
253
254    /**
255     * Calculates an RBF kernel of a specific width.
256     * @param input The input value.
257     * @param width The width of the kernel.
258     * @return sqrt ( exp ( - input*input / width))
259     */
260    public static double kernelDist(double input, double width) {
261        return Math.sqrt(Math.exp(-(input*input) / width));
262    }
263
264    /**
265     * Measures the distance between an input point and a sampled point.
266     * <p>
267     * This distance function takes into account categorical and real values. It uses
268     * the hamming distance for categoricals and the euclidean distance for real values.
269     * @param fMap The feature map used to determine if a feature is categorical or real.
270     * @param numTrainingExamples The number of training examples the fMap has seen.
271     * @param input The input point.
272     * @param sample The sampled point.
273     * @return The distance between the two points.
274     */
275    public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) {
276        double score = 0.0;
277
278        Iterator<VectorTuple> itr = input.iterator();
279        Iterator<VectorTuple> otherItr = sample.iterator();
280        VectorTuple tuple;
281        VectorTuple otherTuple;
282        while (itr.hasNext() && otherItr.hasNext()) {
283            tuple = itr.next();
284            otherTuple = otherItr.next();
285            //after this loop, either itr is out or tuple.index >= otherTuple.index
286            while (itr.hasNext() && (tuple.index < otherTuple.index)) {
287                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
288                tuple = itr.next();
289            }
290            //after this loop, either otherItr is out or tuple.index <= otherTuple.index
291            while (otherItr.hasNext() && (tuple.index > otherTuple.index)) {
292                score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
293                otherTuple = otherItr.next();
294            }
295            if (tuple.index == otherTuple.index) {
296                //the indices line up, do the calculation.
297                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value,otherTuple.value);
298            } else {
299                // Now consume both the values as they'll be gone next iteration.
300                // Consume the value in tuple.
301                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
302                // Consume the value in otherTuple.
303                score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
304            }
305        }
306        while (itr.hasNext()) {
307            tuple = itr.next();
308            score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
309        }
310        while (otherItr.hasNext()) {
311            otherTuple = otherItr.next();
312            score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
313        }
314
315        return Math.sqrt(score);
316    }
317
318    /**
319     * Calculates the distance between two values for a single feature.
320     * <p>
321     * Assumes the other value is zero as the example is sparse.
322     * @param fMap The feature map which knows if a feature is categorical or real.
323     * @param numTrainingExamples The number of training examples this feature map observed.
324     * @param index The id number for this feature.
325     * @param value One feature value.
326     * @return The distance from zero to the supplied value.
327     */
328    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) {
329        VariableInfo info = fMap.get(index);
330        if (info instanceof CategoricalInfo) {
331            return 1.0;
332        } else if (info instanceof RealInfo) {
333            RealInfo rInfo = (RealInfo) info;
334            // Fudge the distance calculation so it doesn't overpower the categoricals.
335            double curScore = value * value;
336            double range;
337            // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not.
338            if (numTrainingExamples != info.getCount()) {
339                range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0);
340            } else {
341                range = rInfo.getMax() - rInfo.getMin();
342            }
343            return curScore / (range*range);
344        } else {
345            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
346        }
347    }
348
349    /**
350     * Calculates the distance between two values for a single feature.
351     *
352     * @param fMap The feature map which knows if a feature is categorical or real.
353     * @param numTrainingExamples The number of training examples this feature map observed.
354     * @param index The id number for this feature.
355     * @param firstValue The first feature value.
356     * @param secondValue The second feature value.
357     * @return The distance between the two values.
358     */
359    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double firstValue, double secondValue) {
360        VariableInfo info = fMap.get(index);
361        if (info instanceof CategoricalInfo) {
362            if (Math.abs(firstValue - secondValue) > DISTANCE_DELTA) {
363                return 1.0;
364            } else {
365                // else the values are the same so the hamming distance is zero.
366                return 0.0;
367            }
368        } else if (info instanceof RealInfo) {
369            RealInfo rInfo = (RealInfo) info;
370            // Fudge the distance calculation so it doesn't overpower the categoricals.
371            double tmp = firstValue - secondValue;
372            double range;
373            // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not.
374            if (numTrainingExamples != info.getCount()) {
375                range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0);
376            } else {
377                range = rInfo.getMax() - rInfo.getMin();
378            }
379            return (tmp*tmp) / (range*range);
380        } else {
381            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
382        }
383    }
384
385    /**
386     * Transforms a {@link Prediction} for a multiclass problem into a {@link Regressor}
387     * output which represents the probability for each class.
388     * <p>
389     * Used as the target for LIME Models.
390     * @param prediction A multiclass prediction object. Must contain probabilities.
391     * @return The n dimensional probability output.
392     */
393    public static Regressor transformOutput(Prediction<Label> prediction) {
394        Map<String,Label> outputs = prediction.getOutputScores();
395
396        String[] names = new String[outputs.size()];
397        double[] values = new double[outputs.size()];
398
399        int i = 0;
400        for (Map.Entry<String,Label> e : outputs.entrySet()) {
401            names[i] = e.getKey();
402            values[i] = e.getValue().getScore();
403            i++;
404        }
405
406        return new Regressor(names,values);
407    }
408
409}