metalearners.explainer module
- class metalearners.explainer.Explainer(cate_models)
Bases: object
Class responsible for managing all functions related to feature explanation and interpretation.
The cate_models parameter should be a list of length \(n_{variants} - 1\) containing, for each treatment variant \(k\), a model that estimates \(\tau_k\). The models should not be CrossFitEstimator instances but plain sklearn BaseEstimator instances. If a CrossFitEstimator is available, a suggested option is to use its _overall_estimator. These models should already be fitted on the data.
- Parameters:
cate_models (list[_ScikitModel])
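A minimal sketch of what the cate_models argument is expected to look like, assuming three treatment variants (a control plus two treatments, so \(n_{variants} - 1 = 2\) models) and using scikit-learn's LinearRegression as an arbitrary choice of model. The synthetic CATE targets and all variable names are illustrative assumptions; the final call to Explainer is left as a comment because it requires metalearners to be installed.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)

n_obs, n_features = 200, 3
n_variants = 3  # control + 2 treatment variants (illustrative assumption)

X = rng.normal(size=(n_obs, n_features))

# Hypothetical per-variant CATE targets tau_k(x); in practice these would be
# estimates produced by a MetaLearner.
true_coefs = [np.array([1.0, 0.0, -0.5]), np.array([0.0, 2.0, 0.0])]
tau = [X @ coef for coef in true_coefs]

# One plain, already fitted scikit-learn estimator per treatment variant,
# i.e. a list of length n_variants - 1.
cate_models = [LinearRegression().fit(X, tau_k) for tau_k in tau]

# With metalearners installed, the list would then be passed directly:
# from metalearners.explainer import Explainer
# explainer = Explainer(cate_models)
```

Any fitted estimator exposing the usual scikit-learn interface could take the place of LinearRegression here.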
- classmethod from_estimates(X, cate_estimates, cate_model_factory, cate_model_params=None)
Create an Explainer object from CATE estimates.
This function fits a model for each treatment variant with X as its input and the corresponding CATE estimates as its output.
cate_estimates should be the raw, 3-dimensional output of a MetaLearner and should not be simplified.
- Parameters:
X (DataFrame | ndarray)
cate_estimates (ndarray)
cate_model_factory (type[_ScikitModel])
cate_model_params (Mapping[str, int | float | str] | None)
- Return type:
Explainer
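The fitting step described for from_estimates can be sketched without metalearners as a loop over the treatment-variant axis of the estimates. This is a hedged illustration, not the library's implementation: the helper name fit_cate_models_from_estimates is hypothetical, and the sketch assumes the array has shape (n_obs, n_variants - 1, 1), i.e. a single-output (regression) outcome.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

import numpy as np
from sklearn.tree import DecisionTreeRegressor


def fit_cate_models_from_estimates(
    X: np.ndarray,
    cate_estimates: np.ndarray,
    cate_model_factory: type,
    cate_model_params: Mapping[str, Any] | None = None,
) -> list:
    """Hypothetical helper: fit one model per treatment variant on CATE estimates.

    Assumes cate_estimates has shape (n_obs, n_variants - 1, 1), i.e. the raw
    3-dimensional output for a single-output (regression) outcome.
    """
    if cate_estimates.ndim != 3:
        raise ValueError("cate_estimates must be 3-dimensional (not simplified).")
    params = dict(cate_model_params or {})
    models = []
    for k in range(cate_estimates.shape[1]):
        # X is the input; the CATE estimates of variant k are the target.
        model = cate_model_factory(**params)
        model.fit(X, cate_estimates[:, k, 0])
        models.append(model)
    return models


rng = np.random.default_rng(1)
X = rng.normal(size=(100, 4))
# Illustrative estimates for 2 treatment variants with a single output.
cate_estimates = np.stack([X[:, 0], 2 * X[:, 1]], axis=1)[:, :, np.newaxis]

models = fit_cate_models_from_estimates(
    X, cate_estimates, DecisionTreeRegressor, {"max_depth": 3, "random_state": 0}
)
```

Passing the model class plus a parameter mapping, rather than a model instance, lets the helper create a fresh, independent model for every variant.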
- feature_importances(normalize=False, feature_names=None, sort_values=False)
Calculates the feature importance for each treatment group.
If normalize=True, the feature importances of each treatment variant are normalized so that they sum to 1. feature_names is optional; if it is not passed, the feature names default to f"Feature {i}", where i is the corresponding feature index.
- Parameters:
normalize (bool)
feature_names (Collection[str] | None)
sort_values (bool)
- Return type:
list[Series]
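A hedged sketch of the feature-importance computation described above. It assumes the importances are read from each fitted model's feature_importances_ attribute (as exposed by scikit-learn tree ensembles, for example) and that sort_values=True sorts in descending order; this page confirms neither detail. The helper name feature_importances_sketch is hypothetical.

```python
from __future__ import annotations

from collections.abc import Collection

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def feature_importances_sketch(
    cate_models: list,
    normalize: bool = False,
    feature_names: Collection[str] | None = None,
    sort_values: bool = False,
) -> list[pd.Series]:
    """Hypothetical helper returning one Series of importances per treatment variant."""
    results = []
    for model in cate_models:
        importances = np.asarray(model.feature_importances_, dtype=float)
        if normalize:
            # Rescale so that the importances of this variant sum to 1.
            importances = importances / importances.sum()
        if feature_names is None:
            index = [f"Feature {i}" for i in range(len(importances))]
        else:
            index = list(feature_names)
        series = pd.Series(importances, index=index)
        if sort_values:
            series = series.sort_values(ascending=False)
        results.append(series)
    return results


rng = np.random.default_rng(2)
X = rng.normal(size=(300, 3))
# Illustrative CATE targets for 2 treatment variants.
targets = [3 * X[:, 0] + 0.1 * X[:, 2], X[:, 1] ** 2]
cate_models = [
    RandomForestRegressor(n_estimators=20, random_state=0).fit(X, y) for y in targets
]

importances = feature_importances_sketch(cate_models, normalize=True, sort_values=True)
```

scikit-learn's impurity-based importances already sum to 1, so normalization matters mainly for models whose importances are raw counts or gains.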
- shap_values(X, shap_explainer_factory, shap_explainer_params=None)
Calculates the SHAP values for each treatment group.
The parameter shap_explainer_factory can be used to specify the type of SHAP explainer; for the available options, see the explainer documentation of the shap library.
- Parameters:
X (DataFrame | ndarray)
shap_explainer_factory (type[shap.Explainer])
shap_explainer_params (dict | None)
- Return type:
list[ndarray]
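To illustrate the shape of the result without depending on the shap package, the sketch below uses a deliberate stand-in. For a linear model with independent features, the SHAP value of feature j is \(w_j (x_j - E[x_j])\), with the mean taken over a background dataset. The sketch returns a list with one array of shape (n_obs, n_features) per treatment variant, mirroring the list[ndarray] return type. The helper name linear_shap_values is hypothetical; with shap installed, an explainer class such as shap.LinearExplainer or shap.TreeExplainer would be passed as shap_explainer_factory instead.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def linear_shap_values(cate_models: list, X: np.ndarray, background: np.ndarray) -> list:
    """Hypothetical helper: exact SHAP values of fitted linear CATE models.

    Assumes independent features, where phi_j(x) = w_j * (x_j - E[x_j]).
    """
    feature_means = background.mean(axis=0)
    # One (n_obs, n_features) array per treatment variant.
    return [model.coef_ * (X - feature_means) for model in cate_models]


rng = np.random.default_rng(3)
X = rng.normal(size=(150, 3))
# Illustrative CATE targets for 2 treatment variants.
targets = [X @ np.array([1.0, -2.0, 0.5]), X @ np.array([0.0, 1.0, 3.0]) + 4.0]
cate_models = [LinearRegression().fit(X, y) for y in targets]

shap_values = linear_shap_values(cate_models, X, background=X)

# Local accuracy: the base value plus the SHAP values recovers each prediction.
for model, values in zip(cate_models, shap_values):
    base_value = model.predict(X).mean()
    assert np.allclose(base_value + values.sum(axis=1), model.predict(X))
```

The closed form only holds for linear models under the feature-independence assumption; for other model families a general-purpose SHAP explainer is needed.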