001/*
002 * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.explanations.lime;
018
019import com.oracle.labs.mlrg.olcut.command.Command;
020import com.oracle.labs.mlrg.olcut.command.CommandGroup;
021import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
022import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
023import com.oracle.labs.mlrg.olcut.config.Option;
024import com.oracle.labs.mlrg.olcut.config.Options;
025import com.oracle.labs.mlrg.olcut.config.UsageException;
026import org.tribuo.Model;
027import org.tribuo.Prediction;
028import org.tribuo.SparseModel;
029import org.tribuo.SparseTrainer;
030import org.tribuo.VariableInfo;
031import org.tribuo.classification.Label;
032import org.tribuo.classification.LabelFactory;
033import org.tribuo.data.text.TextFeatureExtractor;
034import org.tribuo.data.text.impl.BasicPipeline;
035import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
036import org.tribuo.regression.Regressor;
037import org.tribuo.regression.rtree.CARTJointRegressionTrainer;
038import org.jline.builtins.Completers;
039import org.jline.reader.Completer;
040import org.jline.reader.impl.completer.NullCompleter;
041import org.tribuo.util.tokens.Tokenizer;
042import org.tribuo.util.tokens.universal.UniversalTokenizer;
043
044import java.io.BufferedInputStream;
045import java.io.File;
046import java.io.FileInputStream;
047import java.io.FileNotFoundException;
048import java.io.IOException;
049import java.io.ObjectInputStream;
050import java.util.SplittableRandom;
051import java.util.logging.Level;
052import java.util.logging.Logger;
053
054/**
055 * A CLI for interacting with {@link LIMEText}. Uses a simple tokenisation and text extraction pipeline.
056 */
057public class LIMETextCLI implements CommandGroup {
058    private static final Logger logger = Logger.getLogger(LIMETextCLI.class.getName());
059
060    private final CommandInterpreter shell;
061
062    private Model<Label> model;
063
064    private int numSamples = 100;
065
066    private int numFeatures = 10;
067
068    //private SparseTrainer<Regressor> limeTrainer = new LARSLassoTrainer(numFeatures);
069    private SparseTrainer<Regressor> limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true);
070
071    private Tokenizer tokenizer = new UniversalTokenizer();
072
073    private TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(tokenizer,2));
074
075    private LIMEText limeText = null;
076
077    /**
078     * Constructs a LIME CLI.
079     */
080    public LIMETextCLI() {
081        shell = new CommandInterpreter();
082        shell.setPrompt("lime-text sh% ");
083    }
084
085    @Override
086    public String getName() {
087        return "LIME Text CLI";
088    }
089
090    @Override
091    public String getDescription() {
092        return "Commands for experimenting with LIME Text.";
093    }
094
095    /**
096     * Completers for filenames.
097     * @return The filename completers.
098     */
099    public Completer[] fileCompleter() {
100        return new Completer[]{
101                new Completers.FileNameCompleter(),
102                new NullCompleter()
103        };
104    }
105
106    /**
107     * Start the command shell
108     */
109    public void startShell() {
110        shell.add(this);
111        shell.start();
112    }
113
114    /**
115     * Loads a model in from disk.
116     * @param ci The command interpreter.
117     * @param path The path to load the model from.
118     * @param protobuf Load the model from protobuf?
119     * @return A status message.
120     */
121    @Command(usage = "<filename> <load-protobuf> - Load a model from disk.", completers="fileCompleter")
122    public String loadModel(CommandInterpreter ci, File path, boolean protobuf) {
123        String output = "Failed to load model";
124        if (protobuf) {
125            try {
126                Model<?> tmpModel = Model.deserializeFromFile(path.toPath());
127                model = tmpModel.castModel(Label.class);
128                output = "Loaded model from path " + path.getAbsolutePath();
129            } catch (IllegalStateException e) {
130                logger.log(Level.SEVERE, "Failed to deserialize protobuf when reading from file " + path.getAbsolutePath(), e);
131            } catch (IOException e) {
132                logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
133            }
134        } else {
135            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
136                Model<?> tmpModel = (Model<?>) ois.readObject();
137                model = tmpModel.castModel(Label.class);
138                output = "Loaded model from path " + path.getAbsolutePath();
139            } catch (ClassNotFoundException e) {
140                logger.log(Level.SEVERE, "Failed to load class from stream " + path.getAbsolutePath(), e);
141            } catch (FileNotFoundException e) {
142                logger.log(Level.SEVERE, "Failed to open file " + path.getAbsolutePath(), e);
143            } catch (IOException e) {
144                logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
145            }
146        }
147
148        limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor,tokenizer);
149
150        return output;
151    }
152
153    /**
154     * Does the model generate probabilities.
155     * @param ci The command interpreter.
156     * @return True if the model generates probabilities.
157     */
158    @Command(usage="Does the model generate probabilities")
159    public String generatesProbabilities(CommandInterpreter ci) {
160        return ""+model.generatesProbabilities();
161    }
162
163    /**
164     * Shows the model description.
165     * @param ci The command interpreter.
166     * @return The model description.
167     */
168    @Command(usage="Shows the model description")
169    public String modelDescription(CommandInterpreter ci) {
170        return model.toString();
171    }
172
173    /**
174     * Shows information on a particular feature.
175     * @param ci The command interpreter.
176     * @param featureName The feature to show.
177     * @return Feature information.
178     */
179    @Command(usage="Shows the information on a particular feature")
180    public String featureInfo(CommandInterpreter ci, String featureName) {
181        VariableInfo f = model.getFeatureIDMap().get(featureName);
182        if (f != null) {
183            return "" + f.toString();
184        } else {
185            return "Feature " + featureName + " not found.";
186        }
187    }
188
189    /**
190     * Shows the top features of the loaded model.
191     * @param ci The command interpeter.
192     * @param numFeatures The number of features to show.
193     * @return The top features of the model.
194     */
195    @Command(usage="<int> - Shows the top N features in the model")
196    public String topFeatures(CommandInterpreter ci, int numFeatures) {
197        return ""+ model.getTopFeatures(numFeatures);
198    }
199
200    /**
201     * Shows the number of features.
202     * @param ci The command interpreter.
203     * @return The number of features in the model.
204     */
205    @Command(usage="Shows the number of features in the model")
206    public String numFeatures(CommandInterpreter ci) {
207        return ""+ model.getFeatureIDMap().size();
208    }
209
210    /**
211     * Shows the number of features that occurred more than minCount times.
212     * @param ci The command interpreter.
213     * @param minCount The minimum feature occurrence.
214     * @return The number of features with more than minCount occurrences.
215     */
216    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
217    public String minCount(CommandInterpreter ci, int minCount) {
218        int counter = 0;
219        for (VariableInfo f : model.getFeatureIDMap()) {
220            if (f.getCount() > minCount) {
221                counter++;
222            }
223        }
224        return counter + " features occurred more than " + minCount + " times.";
225    }
226
227    /**
228     * Shows the output statistics.
229     * @param ci The command interpreter.
230     * @return The output statistics.
231     */
232    @Command(usage="Shows the output statistics")
233    public String showLabelStats(CommandInterpreter ci) {
234        return "Label histogram : \n" + model.getOutputIDInfo().toReadableString();
235    }
236
237    /**
238     * Sets the number of samples to use in LIME.
239     * @param ci The command interpreter.
240     * @param newNumSamples The number of samples to use in LIME.
241     * @return A status message.
242     */
243    @Command(usage="Sets the number of samples to use in LIME")
244    public String setNumSamples(CommandInterpreter ci, int newNumSamples) {
245        numSamples = newNumSamples;
246        return "Set number of samples to " + numSamples;
247    }
248
249    /**
250     * Explains a text classification.
251     * @param ci The command interpreter.
252     * @param tokens A space separated token stream.
253     * @return An explanation.
254     */
255    @Command(usage="Explain a text classification")
256    public String explain(CommandInterpreter ci, String[] tokens) {
257        String text = String.join(" ",tokens);
258
259        LIMEExplanation explanation = limeText.explain(text);
260
261        SparseModel<Regressor> model = explanation.getModel();
262
263        ci.out.println("Active features of the predicted class = " + model.getActiveFeatures().get(explanation.getPrediction().getOutput().getLabel()));
264
265        return "Explanation = " + explanation.toString();
266    }
267
268    /**
269     * Sets the number of features LIME should use in an explanation.
270     * @param ci The command interpreter.
271     * @param newNumFeatures The number of features.
272     * @return A status message.
273     */
274    @Command(usage="Sets the number of features LIME should use in an explanation")
275    public String setNumFeatures(CommandInterpreter ci, int newNumFeatures) {
276        numFeatures = newNumFeatures;
277        //limeTrainer = new LARSLassoTrainer(numFeatures);
278        limeTrainer = new CARTJointRegressionTrainer((int)Math.log(numFeatures),true);
279        limeText = new LIMEText(new SplittableRandom(1),model,limeTrainer,numSamples,extractor, tokenizer);
280        return "Set the number of features in LIME to " + numFeatures;
281    }
282
283    /**
284     * Makes a prediction using the loaded model.
285     * @param ci The command interpreter.
286     * @param tokens A space separated token stream.
287     * @return The prediction.
288     */
289    @Command(usage="Make a prediction")
290    public String predict(CommandInterpreter ci, String[] tokens) {
291        String text = String.join(" ",tokens);
292
293        Prediction<Label> prediction = model.predict(extractor.extract(LabelFactory.UNKNOWN_LABEL,text));
294
295        return "Prediction = " + prediction.toString();
296    }
297
298    /**
299     * Command line options.
300     */
301    public static class LIMETextCLIOptions implements Options {
302        /**
303         * Model file to load. Optional.
304         */
305        @Option(charName = 'f', longName = "filename", usage = "Model file to load. Optional.")
306        public String modelFilename;
307
308        /**
309         * Load the model from a protobuf. Optional.
310         */
311        @Option(charName = 'p', longName = "protobuf-model", usage = "Load the model from a protobuf. Optional")
312        public boolean protobufFormat;
313    }
314
315    /**
316     * Runs a LIMETextCLI.
317     * @param args The CLI arguments.
318     */
319    public static void main(String[] args) {
320        LIMETextCLI.LIMETextCLIOptions options = new LIMETextCLI.LIMETextCLIOptions();
321        try {
322            ConfigurationManager cm = new ConfigurationManager(args, options, false);
323            LIMETextCLI driver = new LIMETextCLI();
324            if (options.modelFilename != null) {
325                logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename), options.protobufFormat));
326            }
327            driver.startShell();
328        } catch (UsageException e) {
329            System.out.println("Usage: " + e.getUsage());
330        }
331    }
332}