001/* 002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.explanations.lime; 018 019import com.oracle.labs.mlrg.olcut.command.Command; 020import com.oracle.labs.mlrg.olcut.command.CommandGroup; 021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter; 022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 023import com.oracle.labs.mlrg.olcut.config.Option; 024import com.oracle.labs.mlrg.olcut.config.Options; 025import com.oracle.labs.mlrg.olcut.config.UsageException; 026import org.tribuo.Model; 027import org.tribuo.Prediction; 028import org.tribuo.SparseModel; 029import org.tribuo.SparseTrainer; 030import org.tribuo.VariableInfo; 031import org.tribuo.classification.Label; 032import org.tribuo.classification.LabelFactory; 033import org.tribuo.data.text.TextFeatureExtractor; 034import org.tribuo.data.text.impl.BasicPipeline; 035import org.tribuo.data.text.impl.TextFeatureExtractorImpl; 036import org.tribuo.regression.Regressor; 037import org.tribuo.regression.rtree.CARTJointRegressionTrainer; 038import org.jline.builtins.Completers; 039import org.jline.reader.Completer; 040import org.jline.reader.impl.completer.NullCompleter; 041import org.tribuo.util.tokens.Tokenizer; 042import org.tribuo.util.tokens.universal.UniversalTokenizer; 043 044import java.io.BufferedInputStream; 045import java.io.File; 046import java.io.FileInputStream; 047import java.io.FileNotFoundException; 048import java.io.IOException; 049import java.io.ObjectInputStream; 050import java.util.SplittableRandom; 051import java.util.logging.Level; 052import java.util.logging.Logger; 053 054/** 055 * A CLI for interacting with {@link LIMEText}. Uses a simple tokenisation and text extraction pipeline. 056 */ 057public class LIMETextCLI implements CommandGroup { 058 private static final Logger logger = Logger.getLogger(LIMETextCLI.class.getName()); 059 060 private final CommandInterpreter shell; 061 062 private Model<Label> model; 063 064 private int numSamples = 100; 065 066 private int numFeatures = 10; 067 068 //private SparseTrainer<Regressor> limeTrainer = new LARSLassoTrainer(numFeatures); 069 private SparseTrainer<Regressor> limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true); 070 071 private Tokenizer tokenizer = new UniversalTokenizer(); 072 073 private TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(tokenizer,2)); 074 075 private LIMEText limeText = null; 076 077 /** 078 * Constructs a LIME CLI. 079 */ 080 public LIMETextCLI() { 081 shell = new CommandInterpreter(); 082 shell.setPrompt("lime-text sh% "); 083 } 084 085 @Override 086 public String getName() { 087 return "LIME Text CLI"; 088 } 089 090 @Override 091 public String getDescription() { 092 return "Commands for experimenting with LIME Text."; 093 } 094 095 /** 096 * Completers for filenames. 097 * @return The filename completers. 098 */ 099 public Completer[] fileCompleter() { 100 return new Completer[]{ 101 new Completers.FileNameCompleter(), 102 new NullCompleter() 103 }; 104 } 105 106 /** 107 * Start the command shell 108 */ 109 public void startShell() { 110 shell.add(this); 111 shell.start(); 112 } 113 114 /** 115 * Loads a model in from disk. 116 * @param ci The command interpreter. 117 * @param path The path to load the model from. 118 * @param protobuf Load the model from protobuf? 119 * @return A status message. 120 */ 121 @Command(usage = "<filename> <load-protobuf> - Load a model from disk.", completers="fileCompleter") 122 public String loadModel(CommandInterpreter ci, File path, boolean protobuf) { 123 String output = "Failed to load model"; 124 if (protobuf) { 125 try { 126 Model<?> tmpModel = Model.deserializeFromFile(path.toPath()); 127 model = tmpModel.castModel(Label.class); 128 output = "Loaded model from path " + path.getAbsolutePath(); 129 } catch (IllegalStateException e) { 130 logger.log(Level.SEVERE, "Failed to deserialize protobuf when reading from file " + path.getAbsolutePath(), e); 131 } catch (IOException e) { 132 logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e); 133 } 134 } else { 135 try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { 136 Model<?> tmpModel = (Model<?>) ois.readObject(); 137 model = tmpModel.castModel(Label.class); 138 output = "Loaded model from path " + path.getAbsolutePath(); 139 } catch (ClassNotFoundException e) { 140 logger.log(Level.SEVERE, "Failed to load class from stream " + path.getAbsolutePath(), e); 141 } catch (FileNotFoundException e) { 142 logger.log(Level.SEVERE, "Failed to open file " + path.getAbsolutePath(), e); 143 } catch (IOException e) { 144 logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e); 145 } 146 } 147 148 limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor,tokenizer); 149 150 return output; 151 } 152 153 /** 154 * Does the model generate probabilities. 155 * @param ci The command interpreter. 156 * @return True if the model generates probabilities. 157 */ 158 @Command(usage="Does the model generate probabilities") 159 public String generatesProbabilities(CommandInterpreter ci) { 160 return ""+model.generatesProbabilities(); 161 } 162 163 /** 164 * Shows the model description. 165 * @param ci The command interpreter. 166 * @return The model description. 167 */ 168 @Command(usage="Shows the model description") 169 public String modelDescription(CommandInterpreter ci) { 170 return model.toString(); 171 } 172 173 /** 174 * Shows information on a particular feature. 175 * @param ci The command interpreter. 176 * @param featureName The feature to show. 177 * @return Feature information. 178 */ 179 @Command(usage="Shows the information on a particular feature") 180 public String featureInfo(CommandInterpreter ci, String featureName) { 181 VariableInfo f = model.getFeatureIDMap().get(featureName); 182 if (f != null) { 183 return "" + f.toString(); 184 } else { 185 return "Feature " + featureName + " not found."; 186 } 187 } 188 189 /** 190 * Shows the top features of the loaded model. 191 * @param ci The command interpeter. 192 * @param numFeatures The number of features to show. 193 * @return The top features of the model. 194 */ 195 @Command(usage="<int> - Shows the top N features in the model") 196 public String topFeatures(CommandInterpreter ci, int numFeatures) { 197 return ""+ model.getTopFeatures(numFeatures); 198 } 199 200 /** 201 * Shows the number of features. 202 * @param ci The command interpreter. 203 * @return The number of features in the model. 204 */ 205 @Command(usage="Shows the number of features in the model") 206 public String numFeatures(CommandInterpreter ci) { 207 return ""+ model.getFeatureIDMap().size(); 208 } 209 210 /** 211 * Shows the number of features that occurred more than minCount times. 212 * @param ci The command interpreter. 213 * @param minCount The minimum feature occurrence. 214 * @return The number of features with more than minCount occurrences. 215 */ 216 @Command(usage="<min count> - Shows the number of features that occurred more than min count times.") 217 public String minCount(CommandInterpreter ci, int minCount) { 218 int counter = 0; 219 for (VariableInfo f : model.getFeatureIDMap()) { 220 if (f.getCount() > minCount) { 221 counter++; 222 } 223 } 224 return counter + " features occurred more than " + minCount + " times."; 225 } 226 227 /** 228 * Shows the output statistics. 229 * @param ci The command interpreter. 230 * @return The output statistics. 231 */ 232 @Command(usage="Shows the output statistics") 233 public String showLabelStats(CommandInterpreter ci) { 234 return "Label histogram : \n" + model.getOutputIDInfo().toReadableString(); 235 } 236 237 /** 238 * Sets the number of samples to use in LIME. 239 * @param ci The command interpreter. 240 * @param newNumSamples The number of samples to use in LIME. 241 * @return A status message. 242 */ 243 @Command(usage="Sets the number of samples to use in LIME") 244 public String setNumSamples(CommandInterpreter ci, int newNumSamples) { 245 numSamples = newNumSamples; 246 return "Set number of samples to " + numSamples; 247 } 248 249 /** 250 * Explains a text classification. 251 * @param ci The command interpreter. 252 * @param tokens A space separated token stream. 253 * @return An explanation. 254 */ 255 @Command(usage="Explain a text classification") 256 public String explain(CommandInterpreter ci, String[] tokens) { 257 String text = String.join(" ",tokens); 258 259 LIMEExplanation explanation = limeText.explain(text); 260 261 SparseModel<Regressor> model = explanation.getModel(); 262 263 ci.out.println("Active features of the predicted class = " + model.getActiveFeatures().get(explanation.getPrediction().getOutput().getLabel())); 264 265 return "Explanation = " + explanation.toString(); 266 } 267 268 /** 269 * Sets the number of features LIME should use in an explanation. 270 * @param ci The command interpreter. 271 * @param newNumFeatures The number of features. 272 * @return A status message. 273 */ 274 @Command(usage="Sets the number of features LIME should use in an explanation") 275 public String setNumFeatures(CommandInterpreter ci, int newNumFeatures) { 276 numFeatures = newNumFeatures; 277 //limeTrainer = new LARSLassoTrainer(numFeatures); 278 limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true); 279 limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor, tokenizer); 280 return "Set the number of features in LIME to " + numFeatures; 281 } 282 283 /** 284 * Makes a prediction using the loaded model. 285 * @param ci The command interpreter. 286 * @param tokens A space separated token stream. 287 * @return The prediction. 288 */ 289 @Command(usage="Make a prediction") 290 public String predict(CommandInterpreter ci, String[] tokens) { 291 String text = String.join(" ",tokens); 292 293 Prediction<Label> prediction = model.predict(extractor.extract(LabelFactory.UNKNOWN_LABEL,text)); 294 295 return "Prediction = " + prediction.toString(); 296 } 297 298 /** 299 * Command line options. 300 */ 301 public static class LIMETextCLIOptions implements Options { 302 /** 303 * Model file to load. Optional. 304 */ 305 @Option(charName = 'f', longName = "filename", usage = "Model file to load. Optional.") 306 public String modelFilename; 307 308 /** 309 * Load the model from a protobuf. Optional. 310 */ 311 @Option(charName = 'p', longName = "protobuf-model", usage = "Load the model from a protobuf. Optional") 312 public boolean protobufFormat; 313 } 314 315 /** 316 * Runs a LIMETextCLI. 317 * @param args The CLI arguments. 318 */ 319 public static void main(String[] args) { 320 LIMETextCLI.LIMETextCLIOptions options = new LIMETextCLI.LIMETextCLIOptions(); 321 try { 322 ConfigurationManager cm = new ConfigurationManager(args, options, false); 323 LIMETextCLI driver = new LIMETextCLI(); 324 if (options.modelFilename != null) { 325 logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename), options.protobufFormat)); 326 } 327 driver.startShell(); 328 } catch (UsageException e) { 329 System.out.println("Usage: " + e.getUsage()); 330 } 331 } 332}