{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(example-lime)=\n", "\n", "Example: Explainability LIME plots for MetaLearners\n", "====================================================\n", "\n", "Motivation\n", "----------\n", "\n", "LIME -- short for *local interpretable model-agnostic explanations* -- is a method developed by [Ribeiro et al. (2016)](https://arxiv.org/abs/1602.04938).\n", "LIME falls under the umbrella term of explainability methods in Machine Learning. At a high level,\n", "it is meant to provide explanations, intuitions\n", "or examples as to how a model or estimator works.\n", "\n", "The authors argue that\n", "\n", "> If the users do not trust a model or prediction, they will not use it.\n", "\n", "While LIME is typically used in supervised learning scenarios, the key\n", "motivation of better understanding a model's behaviour applies just as\n", "well to CATE estimation. Therefore, we illustrate how it can be used\n", "with a MetaLearner from ``metalearners``.\n", "\n", "Background\n", "----------\n", "\n", "As the name suggests, LIME is model-agnostic and can be used for any\n", "black-box model which receives features or covariates and maps those\n", "to a one-dimensional vector with the same number of rows.\n", "\n", "As the name also suggests, the explanations provided by LIME are\n", "local. The authors state the following:\n", "\n", "> [...] for an explanation to be meaningful it must at least be locally faithful, i.e. 
must correspond to how the model behaves in the vicinity of the instance being predicted.\n", "\n", "Concretely, this means that LIME focuses on one sample -- or its\n", "locality/vicinity/neighborhood -- at a time and tries to imitate the\n", "true model behaviour around that sample with a simpler model.\n", "\n", "In other words, LIME's objective is to choose a substitute model for\n", "our complex model, simultaneously considering two concerns:\n", "\n", "* the interpretability of our new, simple model (let's call this the surrogate)\n", "* the approximation error between the surrogate and the original,\n", "  complex model\n", "\n", "More formally, the authors define:\n", "\n", "* {math}`f`, the original model -- in our case the MetaLearner\n", "* {math}`G`, the class of possible, interpretable surrogate models\n", "* {math}`\\Omega(g)`, a measure of complexity for {math}`g \\in G`\n", "* {math}`\\pi_x(z)`, a proximity measure of an instance {math}`z` with respect to data point {math}`x`\n", "* {math}`\\mathcal{L}(f, g, \\pi_x)`, a measure of how unfaithful a {math}`g \\in G` is to {math}`f` in the locality defined by {math}`\\pi_x`\n", "\n", "Given all of these objects as well as a data point {math}`x` to be explained, the authors suggest that the most appropriate surrogate {math}`g`, also referred to as the explanation for {math}`x`, {math}`\\xi(x)`, can be expressed as follows:\n", "\n", "```{math}\n", "    \\xi(x) = \\operatorname*{argmin}_{g' \\in G} \\mathcal{L}(f, g', \\pi_x) + \\Omega(g')\n", "```\n", "\n", "The authors suggest a mechanism to solve this optimization problem, i.e. to\n", "find suitable local explanations.\n", "\n", "Moreover, they suggest a systematic approach to selecting a set of samples, such that\n", "their respective local explanations are as telling of the overall model\n", "behaviour as possible. 
Intuitively, the authors suggest selecting a pool of explanations which\n", "\n", "* have little redundancy between each other\n", "* showcase the features with the highest global importance\n", "\n", "In line with this ambition, they define a notion of 'coverage' which specifies how well a set\n", "of candidate data points {math}`V` is explained by features that are relevant for\n", "many observed data points. The goal is to find a {math}`V` that is not larger than some\n", "pre-specified size such that this coverage is maximal.\n", "\n", "```{math}\n", "    c(V, W, \\mathcal{I}) = \\sum_{j=1}^{d} \\mathbb{I}\\{\\exists i \\in V: W_{i,j} > 0\\} \\mathcal{I}_j\n", "```\n", "\n", "where\n", "\n", "* {math}`d` is the number of features\n", "* {math}`V` is the candidate set of explanations to be shown to\n", "  humans, within a fixed budget -- this is the variable to be optimized\n", "* {math}`W` is an {math}`n \\times d` local feature importance matrix that represents\n", "  the local importance of each feature for each instance, and\n", "* {math}`\\mathcal{I}` is a {math}`d`-dimensional vector of global\n", "  feature importances\n", "\n", "Implicitly, the authors suppose that each local model {math}`\\xi(x_i)` has a\n", "canonical way of determining feature importances for {math}`W` --\n", "e.g. 
weights in a linear model --\n", "and that the global model {math}`f` does so, too, for\n", "{math}`\\mathcal{I}`.\n", "\n", "Picking data points to optimize this notion of coverage is reflected\n", "in ``lime``'s ``SubmodularPick`` class, which we use below.\n", "\n", "Installation\n", "------------\n", "\n", "In order to generate LIME plots, we first need to install the [lime package](https://github.com/marcotcr/lime).\n", "We can do so either via conda and conda-forge:\n", "\n", "```console\n", "$ conda install lime -c conda-forge\n", "```\n", "\n", "or via pip and PyPI:\n", "\n", "```console\n", "$ pip install lime\n", "```\n", "\n", "Usage\n", "-----\n", "\n", "### Loading the data\n", "\n", "Just like in our {ref}`example on estimating CATEs with a MetaLearner\n", "`, we will first load some experiment data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "import pandas as pd\n", "from pathlib import Path\n", "from git_root import git_root\n", "\n", "df = pd.read_csv(git_root(\"data/learning_mindset.zip\"))\n", "outcome_column = \"achievement_score\"\n", "treatment_column = \"intervention\"\n", "feature_columns = [\n", "    column\n", "    for column in df.columns\n", "    if column not in [outcome_column, treatment_column]\n", "]\n", "categorical_feature_columns = [\n", "    \"ethnicity\",\n", "    \"gender\",\n", "    \"frst_in_family\",\n", "    \"school_urbanicity\",\n", "    \"schoolid\",\n", "]\n", "# Note that explicitly setting the dtype of these features to category\n", "# allows both lightgbm as well as shap plots to\n", "# 1. Operate on features which are not of type int, bool or float\n", "# 2. 
Correctly interpret categoricals with int values\n", "# as categoricals rather than as ordinals/numericals.\n", "for categorical_feature_column in categorical_feature_columns:\n", "    df[categorical_feature_column] = df[categorical_feature_column].astype(\n", "        \"category\"\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've loaded the experiment data, we can train a MetaLearner.\n", "\n", "\n", "### Training a MetaLearner\n", "\n", "Again, mirroring our {ref}`example on estimating CATEs with a MetaLearner\n", "`, we can train an\n", "{class}`~metalearners.rlearner.RLearner` as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "from metalearners import RLearner\n", "from lightgbm import LGBMRegressor, LGBMClassifier\n", "\n", "rlearner = RLearner(\n", "    nuisance_model_factory=LGBMRegressor,\n", "    propensity_model_factory=LGBMClassifier,\n", "    treatment_model_factory=LGBMRegressor,\n", "    is_classification=False,\n", "    n_variants=2,\n", "    nuisance_model_params={\"verbose\": -1},\n", "    propensity_model_params={\"verbose\": -1},\n", "    treatment_model_params={\"verbose\": -1},\n", ")\n", "\n", "rlearner.fit(\n", "    X=df[feature_columns],\n", "    y=df[outcome_column],\n", "    w=df[treatment_column],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating lime plots\n", "\n", "``lime`` expects a function which consumes an ``X`` and returns\n", "a one-dimensional vector of the same length as ``X``. We'll have to\n", "adapt the {meth}`~metalearners.rlearner.RLearner.predict` method of\n", "our {class}`~metalearners.rlearner.RLearner` in two ways:\n", "\n", "* We need to pass a value for the required parameter ``is_oos`` to {meth}`~metalearners.rlearner.RLearner.predict`.\n", "\n", "* We need to reshape the output of\n", "  {meth}`~metalearners.rlearner.RLearner.predict` to be one-dimensional. 
This\n", "  can be achieved via {func}`metalearners.utils.simplify_output`.\n", "\n", "We can do both as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "from metalearners.utils import simplify_output\n", "\n", "def predict(X):\n", "    return simplify_output(rlearner.predict(X, is_oos=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "where we set ``is_oos=True`` since ``lime`` will call\n", "{meth}`~metalearners.rlearner.RLearner.predict`\n", "with various inputs which cannot be recognized as\n", "in-sample data.\n", "\n", "Since ``lime`` expects ``numpy`` data structures, we'll have to\n", "manually encode the categorical features of our ``pandas`` data\n", "structure; see [this issue](https://github.com/microsoft/LightGBM/issues/5162) for more context." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "X = df[feature_columns].copy()\n", "for categorical_feature_column in categorical_feature_columns:\n", "    X[categorical_feature_column] = X[categorical_feature_column].cat.codes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Moreover, we need to manually prepare the mapping of categorical codes\n", "to categorical values as well as the indices of categorical features:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "categorical_names: list[list] = []\n", "for i, column in enumerate(feature_columns):\n", "    categorical_names.append([])\n", "    if column in categorical_feature_columns:\n", "        categorical_names[i] = list(df[column].cat.categories)\n", "\n", "categorical_feature_indices = [\n", "    i for i, name in enumerate(feature_columns) if name in categorical_feature_columns\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can 
now create the necessary ``lime`` objects:\n", "``LimeTabularExplainer``, which explains a given sample, and\n", "``SubmodularPick``, which chooses the samples to be\n", "locally explained.\n", "\n", "In the following we can see the three explanations which have been chosen. We find the\n", "locally most relevant features on the vertical axis and the\n", "outcome dimension on the horizontal axis." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "from lime.lime_tabular import LimeTabularExplainer\n", "from lime.submodular_pick import SubmodularPick\n", "\n", "X = X.to_numpy()\n", "\n", "explainer = LimeTabularExplainer(\n", "    X,\n", "    feature_names=feature_columns,\n", "    categorical_features=categorical_feature_indices,\n", "    categorical_names=categorical_names,\n", "    verbose=False,\n", "    mode=\"regression\",\n", "    discretize_continuous=True,\n", ")\n", "\n", "sp = SubmodularPick(\n", "    data=X,\n", "    explainer=explainer,\n", "    predict_fn=predict,\n", "    method=\"sample\",\n", "    sample_size=1_000,\n", "    num_exps_desired=3,\n", "    num_features=5,\n", ")\n", "\n", "for explanation in sp.sp_explanations:\n", "    explanation.as_pyplot_figure()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In these plots, a green bar signifies that the presence of the corresponding feature\n", "referenced on the y-axis increases the CATE estimate for that observation, whereas a\n", "red bar signifies that the presence of the feature in the observation reduces the CATE estimate.\n", "Furthermore, the length of each colored bar corresponds to the magnitude of that feature's\n", "contribution towards the model prediction: the longer the bar, the larger\n", "the impact of that feature on the model prediction.\n", "\n", "For more guidelines on how to interpret such lime plots please see the [lime documentation](https://github.com/marcotcr/lime)." 
] } ], "metadata": { "language_info": { "name": "python" }, "mystnb": { "execution_timeout": 240 } }, "nbformat": 4, "nbformat_minor": 2 }