001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.explanations.lime; 018 019import org.tribuo.Model; 020import org.tribuo.Prediction; 021import org.tribuo.SparseModel; 022import org.tribuo.classification.Label; 023import org.tribuo.classification.explanations.Explanation; 024import org.tribuo.regression.Regressor; 025import org.tribuo.regression.evaluation.RegressionEvaluation; 026 027import java.util.List; 028import java.util.Map; 029 030/** 031 * An {@link Explanation} using LIME. 032 * <p> 033 * Wraps a {@link SparseModel} {@link Regressor} which is trained to predict the probabilities 034 * generated by the true {@link Model}. 035 * <p> 036 * See: 037 * <pre> 038 * Ribeiro MT, Singh S, Guestrin C. 039 * "Why should I trust you?: Explaining the predictions of any classifier" 040 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016. 041 * </pre> 042 */ 043public class LIMEExplanation implements Explanation<Regressor> { 044 private static final long serialVersionUID = 1L; 045 046 private final SparseModel<Regressor> model; 047 048 private final Prediction<Label> prediction; 049 050 private final RegressionEvaluation evaluation; 051 052 /** 053 * Constructs a LIME explanation. 054 * @param model The explanation model. 055 * @param prediction The prediction being explained. 056 * @param evaluation The evaluation of the explanation model on the sampled data. 057 */ 058 public LIMEExplanation(SparseModel<Regressor> model, Prediction<Label> prediction, RegressionEvaluation evaluation) { 059 this.model = model; 060 this.prediction = prediction; 061 this.evaluation = evaluation; 062 } 063 064 @Override 065 public List<String> getActiveFeatures() { 066 Map<String,List<String>> features = model.getActiveFeatures(); 067 if (features.containsKey(Model.ALL_OUTPUTS)) { 068 return features.get(Model.ALL_OUTPUTS); 069 } else { 070 return features.get(prediction.getOutput().getLabel()); 071 } 072 } 073 074 @Override 075 public SparseModel<Regressor> getModel() { 076 return model; 077 } 078 079 @Override 080 public Prediction<Label> getPrediction() { 081 return prediction; 082 } 083 084 /** 085 * Gets the evaluator which scores how close the sparse model's 086 * predictions are to the complex model's predictions. 087 * @return The evaluation. 088 */ 089 public RegressionEvaluation getEvaluation() { 090 return evaluation; 091 } 092 093 /** 094 * Get the RMSE of a specific dimension of the explanation model. 095 * @param name The dimension to look at. 096 * @return The RMSE of the explanation model. 097 */ 098 public double getRMSE(String name) { 099 return evaluation.rmse().get(new Regressor.DimensionTuple(name,Double.NaN)); 100 } 101 102 @Override 103 public String toString() { 104 return "LIMEExplanation(linearRMSE="+evaluation.rmse()+",modelPrediction="+prediction+",activeFeatures="+getActiveFeatures().toString()+")"; 105 } 106}