metalearners.explainer module

class metalearners.explainer.Explainer(cate_models)

Bases: object

Class responsible for all functionality related to feature explanation and interpretation.

The cate_models parameter should be a list of length \(n_{variants} - 1\) containing one model per treatment variant, each of which estimates \(\tau_k\). The models should not be CrossFitEstimator instances but plain scikit-learn BaseEstimator instances. If you have a CrossFitEstimator, a suggested option is to use its _overall_estimator. These models should already be fitted on the data.

Parameters:

cate_models (list[_ScikitModel])
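A minimal sketch of the preconditions described above, assuming a setting with one control and two treatment variants, so that \(n_{variants} - 1 = 2\) models are needed. The helper `check_cate_models` and the synthetic CATE targets are hypothetical and not part of metalearners; the helper only relies on scikit-learn's `check_is_fitted`, which raises `NotFittedError` for an unfitted estimator. The resulting list is the kind of input you would pass as `Explainer(cate_models=...)`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_is_fitted


def check_cate_models(cate_models, n_variants):
    """Hypothetical helper: validate the list Explainer(cate_models) expects.

    Checks that there is one model per treatment variant (n_variants - 1
    in total) and that every model is already fitted.
    """
    if len(cate_models) != n_variants - 1:
        raise ValueError(
            f"Expected {n_variants - 1} models, got {len(cate_models)}."
        )
    for model in cate_models:
        # Raises sklearn.exceptions.NotFittedError if the model is unfitted.
        check_is_fitted(model)
    return cate_models


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
# Hypothetical CATE targets: one column per treatment variant.
tau = np.column_stack([X[:, 0], 2 * X[:, 1]])

# One plain, already-fitted scikit-learn regressor per treatment variant.
cate_models = [LinearRegression().fit(X, tau[:, k]) for k in range(tau.shape[1])]
check_cate_models(cate_models, n_variants=3)
# With metalearners installed, this list could be passed as:
# Explainer(cate_models=cate_models)
```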

classmethod from_estimates(X, cate_estimates, cate_model_factory, cate_model_params=None)

Create an Explainer object from CATE estimates.

For each treatment variant, this method fits one model with X as its input and the corresponding CATE estimates as its target.

cate_estimates should be the raw, 3-dimensional output of a MetaLearner and should not be simplified.

Parameters:
  • X (DataFrame | ndarray)

  • cate_estimates (ndarray)

  • cate_model_factory (type[_ScikitModel])

  • cate_model_params (Mapping[str, int | float | str] | None)

Return type:

Explainer
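The per-variant fitting described for from_estimates can be sketched as follows. This is not the library's implementation: the function `fit_cate_models` is hypothetical, and it assumes that cate_estimates has shape (n_obs, n_variants - 1, n_outputs) and that a single output column is flattened to 1-D before fitting. The model factory is instantiated as `cate_model_factory(**cate_model_params)` for every variant.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def fit_cate_models(X, cate_estimates, cate_model_factory, cate_model_params=None):
    """Hypothetical sketch: fit one model per treatment variant.

    Assumes cate_estimates has shape (n_obs, n_variants - 1, n_outputs),
    i.e. the raw, unsimplified MetaLearner output.
    """
    if cate_estimates.ndim != 3:
        raise ValueError("cate_estimates must be 3-dimensional; do not simplify it.")
    params = dict(cate_model_params or {})
    models = []
    for k in range(cate_estimates.shape[1]):
        target = cate_estimates[:, k, :]
        # Flatten a single output column to 1-D, as most regressors expect.
        if target.shape[1] == 1:
            target = target.ravel()
        models.append(cate_model_factory(**params).fit(X, target))
    return models


rng = np.random.default_rng(1)
X = rng.normal(size=(100, 4))
# Hypothetical estimates for 2 treatment variants with 1 output each: shape (100, 2, 1).
cate_estimates = np.stack([X[:, :1], -X[:, 1:2]], axis=1)
models = fit_cate_models(X, cate_estimates, LinearRegression, {"fit_intercept": True})
```

With metalearners, the documented call would be `Explainer.from_estimates(X, cate_estimates, LinearRegression, {"fit_intercept": True})`, which returns an Explainer wrapping models fitted in this fashion.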

feature_importances(normalize=False, feature_names=None, sort_values=False)

Calculates the feature importance for each treatment group.

If normalize=True, the feature importances of each treatment variant are normalized so that they sum to 1.

feature_names is optional. If it is not passed, the feature names default to f"Feature {i}", where i is the corresponding feature index.

Parameters:
  • normalize (bool)

  • feature_names (Collection[str] | None)

  • sort_values (bool)

Return type:

list[Series]
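A minimal sketch of the documented return value, one pandas Series per treatment variant, assuming each CATE model exposes scikit-learn's `feature_importances_` attribute (this section does not say which attribute the library reads). The function `feature_importances_per_variant` is hypothetical, and sort_values is omitted because its ordering is not documented here. With normalization, variant k's importances become \(\tilde{I}_{k,j} = I_{k,j} / \sum_{j'} I_{k,j'}\). scikit-learn's tree ensembles already return importances that sum to 1, so normalize mainly matters for models such as LightGBM, whose default split-count importances do not.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def feature_importances_per_variant(cate_models, normalize=False, feature_names=None):
    """Hypothetical sketch: one Series of importances per treatment variant.

    Assumes every model has a feature_importances_ attribute.
    """
    importances = [np.asarray(m.feature_importances_, dtype=float) for m in cate_models]
    if normalize:
        # Rescale each variant's importances so that they sum to 1.
        importances = [imp / imp.sum() for imp in importances]
    if feature_names is None:
        # Default names follow the f"Feature {i}" convention described above.
        feature_names = [f"Feature {i}" for i in range(len(importances[0]))]
    return [pd.Series(imp, index=list(feature_names)) for imp in importances]


rng = np.random.default_rng(2)
X = rng.normal(size=(300, 3))
# Hypothetical CATE targets: variant 0 depends on feature 0, variant 1 on feature 2.
targets = [X[:, 0] + 0.1 * rng.normal(size=300), X[:, 2] ** 2]
cate_models = [
    RandomForestRegressor(n_estimators=20, random_state=0).fit(X, y) for y in targets
]
result = feature_importances_per_variant(cate_models, normalize=True)
```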

shap_values(X, shap_explainer_factory, shap_explainer_params=None)

Calculates the SHAP values for each treatment group.

The parameter shap_explainer_factory specifies the type of shap explainer to use; for the available options, see the shap documentation on explainers.

Parameters:
  • X (DataFrame | ndarray)

  • shap_explainer_factory (type[Explainer])

  • shap_explainer_params (dict | None)

Return type:

list[ndarray]
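The per-variant loop behind shap_values can be sketched as below, assuming the factory is called as `shap_explainer_factory(model, **shap_explainer_params)` and queried through a `shap_values(X)` method, as shap's `TreeExplainer` and `LinearExplainer` classes are. The function `shap_values_per_variant` is hypothetical. So that the sketch runs without shap installed, `LinearShapStandIn` is a hypothetical stand-in explainer that computes exact SHAP values for a linear model with a mean background, \(\phi_j = \beta_j (x_j - \bar{x}_j)\); in practice you would pass a real shap explainer class instead.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def shap_values_per_variant(cate_models, X, shap_explainer_factory, shap_explainer_params=None):
    """Hypothetical sketch: one array of SHAP values per treatment-variant model."""
    params = dict(shap_explainer_params or {})
    values = []
    for model in cate_models:
        explainer = shap_explainer_factory(model, **params)
        values.append(np.asarray(explainer.shap_values(X)))
    return values


class LinearShapStandIn:
    """Hypothetical stand-in for a shap explainer; linear models only."""

    def __init__(self, model, background):
        self.model = model
        self.mean = np.asarray(background).mean(axis=0)
        # Baseline prediction at the background mean.
        self.expected_value = float(model.predict(self.mean.reshape(1, -1))[0])

    def shap_values(self, X):
        # Exact SHAP values of a linear model: coef_j * (x_j - mean_j).
        return (np.asarray(X) - self.mean) * self.model.coef_


rng = np.random.default_rng(3)
X = rng.normal(size=(50, 3))
weights = ([1.0, 0.0, 2.0], [0.0, -1.0, 0.5])
cate_models = [LinearRegression().fit(X, X @ np.array(w)) for w in weights]
shap_values = shap_values_per_variant(
    cate_models, X, LinearShapStandIn, shap_explainer_params={"background": X}
)
```

Each entry of the returned list has shape (n_obs, n_features), and for every row the SHAP values plus the explainer's baseline add up to that variant's model prediction.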