/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.TabularExplainer;
import org.tribuo.impl.ArrayExample;
import org.tribuo.interop.ExternalModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.evaluation.RegressionEvaluator;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
050import java.util.List; 051import java.util.Map; 052import java.util.Random; 053import java.util.SplittableRandom; 054import java.util.logging.Logger; 055 056/** 057 * LIMEBase merges the lime_base.py and lime_tabular.py implementations, and deals with simple 058 * matrices of numerical or categorical data. If you want a mixture of text, numerical 059 * and categorical data try {@link LIMEColumnar}. For plain text data use {@link LIMEText}. 060 * <p> 061 * See: 062 * <pre> 063 * Ribeiro MT, Singh S, Guestrin C. 064 * "Why should I trust you?: Explaining the predictions of any classifier" 065 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016. 066 * </pre> 067 */ 068public class LIMEBase implements TabularExplainer<Regressor> { 069 private static final Logger logger = Logger.getLogger(LIMEBase.class.getName()); 070 071 /** 072 * Width of the noise gaussian. 073 */ 074 public static final double WIDTH_CONSTANT = 0.75; 075 /** 076 * Delta to consider two distances equal. 077 */ 078 public static final double DISTANCE_DELTA = 1e-12; 079 080 protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory(); 081 protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true); 082 083 protected final SplittableRandom rng; 084 085 protected final Model<Label> innerModel; 086 087 protected final SparseTrainer<Regressor> explanationTrainer; 088 089 protected final int numSamples; 090 091 protected final long numTrainingExamples; 092 093 protected final double kernelWidth; 094 095 private final ImmutableFeatureMap fMap; 096 097 /** 098 * Constructs a LIME explainer for a model which uses tabular data (i.e., no special treatment for text features). 099 * @param rng The rng to use for sampling. 100 * @param innerModel The model to explain. 101 * @param explanationTrainer The sparse trainer used to explain predictions. 
102 * @param numSamples The number of samples to generate for an explanation. 103 */ 104 public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) { 105 if (!(explanationTrainer instanceof WeightedExamples)) { 106 throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString()); 107 } 108 if (!innerModel.generatesProbabilities()) { 109 throw new IllegalArgumentException("LIME requires the model generate probabilities."); 110 } 111 if (innerModel instanceof ExternalModel) { 112 throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model."); 113 } 114 this.rng = rng; 115 this.innerModel = innerModel; 116 this.explanationTrainer = explanationTrainer; 117 this.numSamples = numSamples; 118 this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations(); 119 this.kernelWidth = Math.pow(innerModel.getFeatureIDMap().size() * WIDTH_CONSTANT, 2); 120 this.fMap = innerModel.getFeatureIDMap(); 121 } 122 123 @Override 124 public LIMEExplanation explain(Example<Label> example) { 125 return explainWithSamples(example).getA(); 126 } 127 128 protected Pair<LIMEExplanation,List<Example<Regressor>>> explainWithSamples(Example<Label> example) { 129 // Predict using the full model, and generate a new example containing that prediction. 130 Prediction<Label> prediction = innerModel.predict(example); 131 Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction),example,1.0f); 132 133 // Sample a dataset. 134 List<Example<Regressor>> sample = sampleData(example); 135 136 // Generate a sparse model on the sampled data. 137 SparseModel<Regressor> model = trainExplainer(labelledExample,sample); 138 139 // Test the sparse model against the predictions of the real model. 
140 List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample)); 141 predictions.add(model.predict(labelledExample)); 142 RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEColumnar sampled data",regressionFactory)); 143 144 return new Pair<>(new LIMEExplanation(model,prediction,evaluation),sample); 145 } 146 147 /** 148 * Sample a dataset based on the input example. 149 * <p> 150 * The sampled dataset uses the feature dimensions from the {@link Model}. 151 * <p> 152 * The outputs are the probability values of each class from the underlying Model, 153 * rather than ground truth outputs. The distance is measured using the 154 * {@link LIMEBase#measureDistance} function, transformed through a kernel and used 155 * as the sampled Example's weight. 156 * @param example The example to sample from. 157 * @return A sampled dataset. 158 */ 159 private List<Example<Regressor>> sampleData(Example<Label> example) { 160 List<Example<Regressor>> output = new ArrayList<>(); 161 162 SparseVector exampleVector = SparseVector.createSparseVector(example,fMap,false); 163 164 Random innerRNG = new Random(rng.nextLong()); 165 for (int i = 0; i < numSamples; i++) { 166 // Sample a new Example. 167 Example<Label> sample = samplePoint(innerRNG,fMap,numTrainingExamples,exampleVector); 168 169 //logger.fine("Itr " + i + " sampled " + sample.toString()); 170 171 // Label it using the full model. 172 Prediction<Label> samplePrediction = innerModel.predict(sample); 173 174 // Measure the distance between this point and the input, to be used as a weight. 175 double distance = measureDistance(fMap,numTrainingExamples,exampleVector, SparseVector.createSparseVector(sample,fMap,false)); 176 177 // Transform distance through the kernel function. 178 distance = kernelDist(distance,kernelWidth); 179 180 // Generate the new sample with the appropriate label and weight. 
181 Example<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction),sample,(float)distance); 182 output.add(labelledSample); 183 } 184 185 return output; 186 } 187 188 /** 189 * Samples a single example from the supplied feature map and input vector. 190 * @param rng The rng to use. 191 * @param fMap The feature map describing the domain of the features. 192 * @param numTrainingExamples The number of training examples the fMap has seen. 193 * @param input The input sparse vector to use. 194 * @return An Example sampled from the supplied feature map and input vector. 195 */ 196 public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) { 197 ArrayList<String> names = new ArrayList<>(); 198 ArrayList<Double> values = new ArrayList<>(); 199 200 for (VariableInfo info : fMap) { 201 int id = ((VariableIDInfo)info).getID(); 202 double inputValue = input.get(id); 203 204 if (info instanceof CategoricalInfo) { 205 // This one is tricksy as categorical info essentially implicitly includes a zero. 206 CategoricalInfo catInfo = (CategoricalInfo) info; 207 double sample = catInfo.frequencyBasedSample(rng,numTrainingExamples); 208 // If we didn't sample zero. 209 if (Math.abs(sample) > 1e-10) { 210 names.add(info.getName()); 211 values.add(sample); 212 } 213 } else if (info instanceof RealInfo) { 214 RealInfo realInfo = (RealInfo) info; 215 // As realInfo is sparse we sample from the mixture distribution, 216 // either 0 or N(inputValue,variance). 217 // This assumes realInfo never observed a zero, which is enforced from v2.1 218 // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab? 219 // If it's not zero do we want to? 
220 int count = realInfo.getCount(); 221 double threshold = count / ((double)numTrainingExamples); 222 if (rng.nextDouble() < threshold) { 223 double variance = realInfo.getVariance(); 224 double sample = (rng.nextGaussian() * Math.sqrt(variance)) + inputValue; 225 names.add(info.getName()); 226 values.add(sample); 227 } 228 } else { 229 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 230 } 231 } 232 233 return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL,names.toArray(new String[0]),Util.toPrimitiveDouble(values)); 234 } 235 236 /** 237 * Trains the explanation model using the supplied sampled data and the input example. 238 * <p> 239 * The labels are usually the predicted probabilities from the real model. 240 * @param target The input example to explain. 241 * @param samples The sampled data around the input. 242 * @return An explanation model. 243 */ 244 protected SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) { 245 MutableDataset<Regressor> explanationDataset = new MutableDataset<>(new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory); 246 explanationDataset.add(target); 247 explanationDataset.addAll(samples); 248 249 SparseModel<Regressor> explainer = explanationTrainer.train(explanationDataset); 250 251 return explainer; 252 } 253 254 /** 255 * Calculates an RBF kernel of a specific width. 256 * @param input The input value. 257 * @param width The width of the kernel. 258 * @return sqrt ( exp ( - input*input / width)) 259 */ 260 public static double kernelDist(double input, double width) { 261 return Math.sqrt(Math.exp(-(input*input) / width)); 262 } 263 264 /** 265 * Measures the distance between an input point and a sampled point. 266 * <p> 267 * This distance function takes into account categorical and real values. 
It uses 268 * the hamming distance for categoricals and the euclidean distance for real values. 269 * @param fMap The feature map used to determine if a feature is categorical or real. 270 * @param numTrainingExamples The number of training examples the fMap has seen. 271 * @param input The input point. 272 * @param sample The sampled point. 273 * @return The distance between the two points. 274 */ 275 public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) { 276 double score = 0.0; 277 278 Iterator<VectorTuple> itr = input.iterator(); 279 Iterator<VectorTuple> otherItr = sample.iterator(); 280 VectorTuple tuple; 281 VectorTuple otherTuple; 282 while (itr.hasNext() && otherItr.hasNext()) { 283 tuple = itr.next(); 284 otherTuple = otherItr.next(); 285 //after this loop, either itr is out or tuple.index >= otherTuple.index 286 while (itr.hasNext() && (tuple.index < otherTuple.index)) { 287 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 288 tuple = itr.next(); 289 } 290 //after this loop, either otherItr is out or tuple.index <= otherTuple.index 291 while (otherItr.hasNext() && (tuple.index > otherTuple.index)) { 292 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 293 otherTuple = otherItr.next(); 294 } 295 if (tuple.index == otherTuple.index) { 296 //the indices line up, do the calculation. 297 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value,otherTuple.value); 298 } else { 299 // Now consume both the values as they'll be gone next iteration. 300 // Consume the value in tuple. 301 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 302 // Consume the value in otherTuple. 
303 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 304 } 305 } 306 while (itr.hasNext()) { 307 tuple = itr.next(); 308 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 309 } 310 while (otherItr.hasNext()) { 311 otherTuple = otherItr.next(); 312 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 313 } 314 315 return Math.sqrt(score); 316 } 317 318 /** 319 * Calculates the distance between two values for a single feature. 320 * <p> 321 * Assumes the other value is zero as the example is sparse. 322 * @param fMap The feature map which knows if a feature is categorical or real. 323 * @param numTrainingExamples The number of training examples this feature map observed. 324 * @param index The id number for this feature. 325 * @param value One feature value. 326 * @return The distance from zero to the supplied value. 327 */ 328 private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) { 329 VariableInfo info = fMap.get(index); 330 if (info instanceof CategoricalInfo) { 331 return 1.0; 332 } else if (info instanceof RealInfo) { 333 RealInfo rInfo = (RealInfo) info; 334 // Fudge the distance calculation so it doesn't overpower the categoricals. 335 double curScore = value * value; 336 double range; 337 // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not. 338 if (numTrainingExamples != info.getCount()) { 339 range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0); 340 } else { 341 range = rInfo.getMax() - rInfo.getMin(); 342 } 343 return curScore / (range*range); 344 } else { 345 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 346 } 347 } 348 349 /** 350 * Calculates the distance between two values for a single feature. 
351 * 352 * @param fMap The feature map which knows if a feature is categorical or real. 353 * @param numTrainingExamples The number of training examples this feature map observed. 354 * @param index The id number for this feature. 355 * @param firstValue The first feature value. 356 * @param secondValue The second feature value. 357 * @return The distance between the two values. 358 */ 359 private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double firstValue, double secondValue) { 360 VariableInfo info = fMap.get(index); 361 if (info instanceof CategoricalInfo) { 362 if (Math.abs(firstValue - secondValue) > DISTANCE_DELTA) { 363 return 1.0; 364 } else { 365 // else the values are the same so the hamming distance is zero. 366 return 0.0; 367 } 368 } else if (info instanceof RealInfo) { 369 RealInfo rInfo = (RealInfo) info; 370 // Fudge the distance calculation so it doesn't overpower the categoricals. 371 double tmp = firstValue - secondValue; 372 double range; 373 // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not. 374 if (numTrainingExamples != info.getCount()) { 375 range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0); 376 } else { 377 range = rInfo.getMax() - rInfo.getMin(); 378 } 379 return (tmp*tmp) / (range*range); 380 } else { 381 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 382 } 383 } 384 385 /** 386 * Transforms a {@link Prediction} for a multiclass problem into a {@link Regressor} 387 * output which represents the probability for each class. 388 * <p> 389 * Used as the target for LIME Models. 390 * @param prediction A multiclass prediction object. Must contain probabilities. 391 * @return The n dimensional probability output. 
392 */ 393 public static Regressor transformOutput(Prediction<Label> prediction) { 394 Map<String,Label> outputs = prediction.getOutputScores(); 395 396 String[] names = new String[outputs.size()]; 397 double[] values = new double[outputs.size()]; 398 399 int i = 0; 400 for (Map.Entry<String,Label> e : outputs.entrySet()) { 401 names[i] = e.getKey(); 402 values[i] = e.getValue().getScore(); 403 i++; 404 } 405 406 return new Regressor(names,values); 407 } 408 409}