metalearners.xlearner module

class metalearners.xlearner.XLearner(is_classification, n_variants, nuisance_model_factory=None, treatment_model_factory=None, propensity_model_factory=None, nuisance_model_params=None, treatment_model_params=None, propensity_model_params=None, fitted_nuisance_models=None, fitted_propensity_model=None, feature_set=None, n_folds=10, random_state=None)

Bases: _ConditionalAverageOutcomeMetaLearner

X-Learner for CATE estimation as described by Kuenzel et al. (2019).

Importantly, the current X-Learner implementation only supports:

  • binary classes in the case of a classification outcome
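The two-stage procedure behind the X-Learner can be sketched in a few lines. The following is a minimal NumPy illustration for a binary treatment and a scalar outcome, not the metalearners implementation: ordinary least squares stands in for the configurable base models, cross-fitting is omitted, and the helper names fit_ols and x_learner_cate are hypothetical. The final step weights the two effect models by the propensity score, which is the weighting suggested by Kuenzel et al. (2019).

```python
import numpy as np


def fit_ols(X, y):
    """Fit ordinary least squares with an intercept and return a predict function."""
    design = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return lambda X_new: np.column_stack([np.ones(len(X_new)), X_new]) @ coef


def x_learner_cate(X, y, w, propensity):
    """Sketch of the X-Learner for a binary treatment w in {0, 1} and a scalar outcome y."""
    treated, control = w == 1, w == 0
    # First stage: one outcome model per variant.
    mu_0 = fit_ols(X[control], y[control])
    mu_1 = fit_ols(X[treated], y[treated])
    # Imputed individual effects, using the outcome model of the *other* group.
    d_1 = y[treated] - mu_0(X[treated])
    d_0 = mu_1(X[control]) - y[control]
    # Second stage: regress the imputed effects on the covariates within each group.
    tau_1 = fit_ols(X[treated], d_1)
    tau_0 = fit_ols(X[control], d_0)
    # Combine both effect models, weighting by the propensity score e(x).
    e = propensity(X)
    return e * tau_0(X) + (1 - e) * tau_1(X)


rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 2))
w = rng.integers(0, 2, size=2000)
true_cate = 1.0 + 2.0 * X[:, 0]
y = X[:, 1] + w * true_cate + rng.normal(scale=0.1, size=2000)

# Treatment is randomized with probability 0.5, so a constant propensity is correct here.
cate = x_learner_cate(X, y, w, propensity=lambda X_new: np.full(len(X_new), 0.5))
```

Because the simulated data are linear and the propensity is known, the sketch should track the true effect closely; with real data, the outcome, effect, and propensity models would be the estimators supplied through the factory parameters.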

Parameters:
  • is_classification (bool)

  • n_variants (int)

  • nuisance_model_factory (type[_ScikitModel] | dict[str, type[_ScikitModel]] | None)

  • treatment_model_factory (type[_ScikitModel] | dict[str, type[_ScikitModel]] | None)

  • propensity_model_factory (type[_ScikitModel] | None)

  • nuisance_model_params (Mapping[str, int | float | str] | dict[str, Mapping[str, int | float | str]] | None)

  • treatment_model_params (Mapping[str, int | float | str] | dict[str, Mapping[str, int | float | str]] | None)

  • propensity_model_params (Mapping[str, int | float | str] | None)

  • fitted_nuisance_models (dict[str, list[CrossFitEstimator]] | None)

  • fitted_propensity_model (CrossFitEstimator | None)

  • feature_set (Collection[str] | Collection[int] | dict[str, Collection[str] | Collection[int]] | None)

  • n_folds (int | dict[str, int])

  • random_state (int | None)

classmethod nuisance_model_specifications()

Return the specifications of all first-stage models.

Return type:

dict[str, _ModelSpecifications]

classmethod treatment_model_specifications()

Return the specifications of all second-stage models.

Return type:

dict[str, _ModelSpecifications]

fit(X, y, w, n_jobs_cross_fitting=None, fit_params=None, synchronize_cross_fitting=True, n_jobs_base_learners=None)

Fit all models of the MetaLearner.

If pre-fitted models were passed at instantiation, these are never refitted.

n_jobs_cross_fitting is used at the cross-fitting level and n_jobs_base_learners at the stage level. None means 1 unless in a joblib.parallel_backend context; -1 means using all processors. For more information about parallelism, see the section "What about parallelism?" of the documentation.

fit_params is an optional dict to be forwarded to the fit calls of the base estimators. It supports two usage patterns:

  • fit_params={"parameter_of_interest": value_of_interest}
    
  • fit_params={
        "nuisance": {
            "nuisance_model_kind1": {"parameter_of_interest1": value_of_interest1},
            "nuisance_model_kind3": {"parameter_of_interest3": value_of_interest3},
        },
        "treatment": {"treatment_model_kind1": {"parameter_of_interest4": value_of_interest4}}
    }
    

In the former approach, the parameter and value of interest are passed to all base models. In the latter approach, the explicitly qualified parameter-value pairs are passed to the respective base models, and no fitting parameters are passed to base models that are not explicitly listed. Note that in this pattern, propensity models are considered nuisance models.
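The routing of the two fit_params patterns can be summarized by a small lookup. This is a hedged sketch of the documented behavior, not the library's code: the function name resolve_fit_params is hypothetical, the rule "top-level keys are only 'nuisance' and/or 'treatment'" is an assumed way of telling the patterns apart, and the model kind names are assumptions for illustration.

```python
def resolve_fit_params(fit_params, stage, model_kind):
    """Return the fit parameters a single base model would receive.

    stage is "nuisance" or "treatment"; model_kind names one model of that stage.
    Propensity models count as nuisance models in the qualified pattern.
    """
    if not fit_params:
        return {}
    # Qualified pattern (assumed detection rule): top-level keys name stages.
    if set(fit_params) <= {"nuisance", "treatment"}:
        # Models that are not explicitly listed receive no parameters.
        return dict(fit_params.get(stage, {}).get(model_kind, {}))
    # Broadcast pattern: every base model receives the same parameters.
    return dict(fit_params)


qualified = {
    "nuisance": {"propensity_model": {"parameter_of_interest1": 1}},
    "treatment": {"treatment_effect_model": {"parameter_of_interest4": 4}},
}
```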

synchronize_cross_fitting indicates whether the different base models should be learned on exactly the same data splits where possible. Note that if several of the models to be synchronized are classifiers, their splits cannot be stratified.
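Synchronized cross-fitting amounts to drawing one set of folds and reusing it for every base model. The sketch below, with the hypothetical helper make_shared_splits, builds plain (unstratified) K-fold index pairs in the list[tuple[ndarray, ndarray]] form that the cv argument of fit_nuisance and fit_treatment accepts. It also shows why stratification is lost: a stratified split depends on one target's labels, so a single shared split cannot be stratified for several classifiers with different targets. Whether fit uses exactly this scheme internally is an assumption.

```python
import numpy as np


def make_shared_splits(n_obs, n_folds, random_state=None):
    """Build one list of (train_idx, test_idx) pairs to reuse for every base model."""
    rng = np.random.default_rng(random_state)
    permutation = rng.permutation(n_obs)
    # Each chunk of the shuffled indices serves once as the held-out fold.
    test_folds = np.array_split(permutation, n_folds)
    return [
        (np.setdiff1d(permutation, test_idx), np.sort(test_idx))
        for test_idx in test_folds
    ]


splits = make_shared_splits(n_obs=10, n_folds=3, random_state=42)
```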

Parameters:
  • X (DataFrame | ndarray)

  • y (Series | ndarray)

  • w (Series | ndarray)

  • n_jobs_cross_fitting (int | None)

  • fit_params (dict | None)

  • synchronize_cross_fitting (bool)

  • n_jobs_base_learners (int | None)

Return type:

Self

predict(X, is_oos, oos_method='overall')

Estimate the CATE.

If is_oos, an acronym for ‘is out of sample’, is False, the estimates will stem from cross-fitting. Otherwise, various approaches exist, specified via oos_method.
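The difference between in-sample and out-of-sample prediction can be illustrated with per-fold models. In this minimal sketch (hypothetical helper names, toy models instead of fitted estimators), in-sample predictions take each row from the model whose training folds excluded it, while out-of-sample predictions aggregate all fold models with the mean or the median. The 'overall' method is not covered; we assume it refers to a single model trained on all data.

```python
import numpy as np


def in_sample_predictions(fold_models, splits, X):
    """Cross-fitted predictions: each row is predicted by the model that did not train on it."""
    predictions = np.empty(len(X))
    for model, (_, test_idx) in zip(fold_models, splits):
        predictions[test_idx] = model(X[test_idx])
    return predictions


def out_of_sample_predictions(fold_models, X, oos_method):
    """Aggregate the predictions of all fold models on new data."""
    stacked = np.stack([model(X) for model in fold_models])  # shape (n_folds, n_obs)
    if oos_method == "mean":
        return stacked.mean(axis=0)
    if oos_method == "median":
        return np.median(stacked, axis=0)
    raise ValueError(f"oos_method {oos_method!r} is not covered by this sketch")


# Three toy "fold models" that differ only by a constant offset.
fold_models = [lambda X, c=c: X[:, 0] + c for c in (0.0, 1.0, 5.0)]
splits = [
    (np.array([2, 3]), np.array([0, 1])),
    (np.array([0, 1, 3]), np.array([2])),
    (np.array([0, 1, 2]), np.array([3])),
]
X = np.arange(4.0).reshape(4, 1)
```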

The returned ndarray is of shape:

  • \((n_{obs}, n_{variants} - 1, 1)\) if the outcome is a scalar, i.e. in case of a regression problem.

  • \((n_{obs}, n_{variants} - 1, n_{classes})\) if the outcome is a class, i.e. in case of a classification problem.

In the case of multiple treatment variants, the second dimension represents the CATE of the corresponding variant vs the control (variant 0).
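Given this layout, the effect of variant k vs. the control sits at index k - 1 of the second axis. The hypothetical helper below only illustrates that indexing on arrays filled by hand, not on actual output of predict; for a classification problem, the output argument selects the class column.

```python
import numpy as np


def effect_of_variant(cate, variant, output=0):
    """Select the CATE of `variant` vs. the control (variant 0) for one output column."""
    if variant < 1:
        raise ValueError("variant 0 is the control; CATEs exist only for variants >= 1")
    return cate[:, variant - 1, output]


n_obs, n_variants = 4, 3

# Regression layout: (n_obs, n_variants - 1, 1).
regression_cate = np.zeros((n_obs, n_variants - 1, 1))
regression_cate[:, 1, 0] = 2.5  # pretend variant 2 raises the outcome by 2.5

# Binary classification layout: (n_obs, n_variants - 1, 2).
classification_cate = np.zeros((n_obs, n_variants - 1, 2))
classification_cate[:, 0, 1] = 0.1  # pretend variant 1 raises the class-1 value by 0.1
```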

Parameters:
  • X (DataFrame | ndarray)

  • is_oos (bool)

  • oos_method (Literal['overall', 'median', 'mean'])

Return type:

ndarray

evaluate(X, y, w, is_oos, oos_method='overall', scoring=None)

Evaluate the MetaLearner.

Keys in scoring that are not the name of a model contained in the MetaLearner are ignored; for the available model names, see nuisance_model_specifications() and treatment_model_specifications(). Each value must be a list whose elements are either:

  • a string naming a scikit-learn scoring method (see the scikit-learn documentation on scoring parameters for the possible values), or

  • a Callable with signature scorer(estimator, X, y_true, **kwargs). We recommend using sklearn.metrics.make_scorer to create such a Callable.

If a model name is not among the keys of scoring, the default metric is neg_log_loss for a classifier and neg_root_mean_squared_error for a regressor.
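A scoring mapping can mix scikit-learn scoring strings with custom callables. The sketch below writes such a callable by hand to make the expected signature visible; as recommended above, sklearn.metrics.make_scorer(mean_absolute_error, greater_is_better=False) would be the usual way to build it. The model kind names used as keys are assumptions; check nuisance_model_specifications() for the real ones.

```python
import numpy as np


def neg_mean_absolute_error_scorer(estimator, X, y_true, **kwargs):
    """Custom scorer with the signature scorer(estimator, X, y_true, **kwargs).

    Like scikit-learn scorers, it returns "greater is better" values, hence the minus sign.
    """
    y_pred = estimator.predict(X)
    return -float(np.mean(np.abs(np.asarray(y_true) - y_pred)))


scoring = {
    # Keys must be model names of the MetaLearner; these names are assumptions.
    "variant_outcome_model": ["neg_mean_squared_error", neg_mean_absolute_error_scorer],
    "propensity_model": ["roc_auc"],
}
```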

The returned dictionary keys have the following structure:

  • For nuisance models:

    • If the cardinality is one: f"{model_kind}_{scorer}"

    • If there is one model for each treatment variant (including control): f"{model_kind}_{treatment_variant}_{scorer}"

  • For treatment models: f"{model_kind}_{treatment_variant}_vs_0_{scorer}"

Here, scorer is the scorer's name if it is a string, or "custom_scorer_{idx}" if it is a callable, where idx is its index in the list of scorers.
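The key layout above can be reproduced mechanically. This sketch (hypothetical function evaluation_keys) builds the key names following the documented rules; it does not claim that evaluate() returns keys in this order, and the model kind names in the usage lines are assumptions.

```python
def evaluation_keys(model_kind, scorers, n_variants, stage, one_per_variant=False):
    """Build the result keys for one model kind following the documented layout."""
    # Strings keep their name; callables become "custom_scorer_{idx}".
    names = [
        scorer if isinstance(scorer, str) else f"custom_scorer_{idx}"
        for idx, scorer in enumerate(scorers)
    ]
    if stage == "treatment":
        # One treatment model per non-control variant, each compared to variant 0.
        return [
            f"{model_kind}_{variant}_vs_0_{name}"
            for variant in range(1, n_variants)
            for name in names
        ]
    if one_per_variant:
        # Nuisance models fitted once per variant, including the control.
        return [
            f"{model_kind}_{variant}_{name}"
            for variant in range(n_variants)
            for name in names
        ]
    return [f"{model_kind}_{name}" for name in names]


keys = evaluation_keys("treatment_effect_model", ["neg_root_mean_squared_error"], 3, "treatment")
```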

Parameters:
  • X (DataFrame | ndarray)

  • y (Series | ndarray)

  • w (Series | ndarray)

  • is_oos (bool)

  • oos_method (Literal['overall', 'median', 'mean'])

  • scoring (Mapping[str, Sequence[str | Callable]] | None)

Return type:

dict[str, float]

explainer(X=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)

Create an Explainer which can be used in feature_importances().

This function can be used in two distinct ways, depending on the provided parameters:

  • When X, cate_estimates, and cate_model_factory are all None, the function creates an Explainer from the existing treatment models. If these models do not exist, a ValueError is raised.

  • When X, cate_estimates, and cate_model_factory are all provided, the function instantiates an Explainer with these parameters. The Explainer then fits a new model for each treatment variant, and these models are used to calculate feature importances.
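In the second mode, one surrogate model per treatment variant is fitted from X to that variant's CATE estimates. The sketch below illustrates the idea with ordinary least squares standing in for the model that cate_model_factory would build; the helper name fit_cate_surrogates and the regression layout (n_obs, n_variants - 1, 1) of cate_estimates are assumptions for illustration.

```python
import numpy as np


def fit_cate_surrogates(X, cate_estimates):
    """Fit one linear surrogate per treatment variant on the CATE estimates."""
    design = np.column_stack([np.ones(len(X)), X])
    coefficients = []
    for variant_idx in range(cate_estimates.shape[1]):
        target = cate_estimates[:, variant_idx, 0]  # regression layout: last axis has size 1
        coef, *_ = np.linalg.lstsq(design, target, rcond=None)
        coefficients.append(coef)  # [intercept, one slope per feature]
    return coefficients


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
# Hand-made CATE estimates for two treatment variants, shape (200, 2, 1).
cate_estimates = np.stack([0.5 + 2.0 * X[:, 0], -1.0 * X[:, 2]], axis=1)[:, :, np.newaxis]
surrogates = fit_cate_surrogates(X, cate_estimates)
```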

Parameters:
  • X (DataFrame | ndarray | None)

  • cate_estimates (ndarray | None)

  • cate_model_factory (type[_ScikitModel] | None)

  • cate_model_params (Mapping[str, int | float | str] | None)

Return type:

Explainer

feature_importances(feature_names=None, normalize=False, sort_values=False, explainer=None, X=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)

Calculates the feature importance for each treatment group.

If explainer is None, a new Explainer is created using explainer() with the passed parameters. If explainer is not None, then the parameters X, cate_estimates, cate_model_factory and cate_model_params are ignored.

If normalize=True, the feature importances of each treatment variant are normalized so that they sum to 1.

feature_names is optional; if it is not passed, the feature names default to f"Feature {i}", where i is the corresponding feature index.

The returned list contains the feature importances of each treatment variant, ordered by ascending treatment variant.
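The default naming and the normalize option can be illustrated on raw importance values. This sketch (hypothetical helper name_and_normalize) returns plain dicts instead of the pandas Series the method returns, assumes non-negative importances with a positive sum, and leaves sort_values out.

```python
def name_and_normalize(importances, feature_names=None, normalize=False):
    """Attach feature names to raw importances and optionally rescale them to sum to 1."""
    if feature_names is None:
        feature_names = [f"Feature {i}" for i in range(len(importances))]
    values = list(importances)
    if normalize:
        total = sum(values)  # assumed positive
        values = [value / total for value in values]
    return dict(zip(feature_names, values))


# Raw importances for two treatment variants, in ascending variant order.
raw_per_variant = [[3.0, 1.0], [0.0, 2.0]]
result = [name_and_normalize(raw, normalize=True) for raw in raw_per_variant]
```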

Parameters:
  • feature_names (Collection[str] | None)

  • normalize (bool)

  • sort_values (bool)

  • explainer (Explainer | None)

  • X (DataFrame | ndarray | None)

  • cate_estimates (ndarray | None)

  • cate_model_factory (type[_ScikitModel] | None)

  • cate_model_params (Mapping[str, int | float | str] | None)

Return type:

list[Series]

fit_nuisance(X, y, model_kind, model_ord, fit_params=None, n_jobs_cross_fitting=None, cv=None)

Fit a given nuisance model of a MetaLearner.

y represents the objective of the given nuisance model, not necessarily the outcome of the experiment. If pre-fitted models were passed at instantiation, these are never refitted.

Parameters:
  • X (DataFrame | ndarray)

  • y (Series | ndarray)

  • model_kind (str)

  • model_ord (int)

  • fit_params (dict | None)

  • n_jobs_cross_fitting (int | None)

  • cv (list[tuple[ndarray, ndarray]] | None)

Return type:

Self

fit_treatment(X, y, model_kind, model_ord, fit_params=None, n_jobs_cross_fitting=None, cv=None)

Fit the treatment model of a MetaLearner.

y represents the objective of the given treatment model, not necessarily the outcome of the experiment.

Parameters:
  • X (DataFrame | ndarray)

  • y (Series | ndarray)

  • model_kind (str)

  • model_ord (int)

  • fit_params (dict | None)

  • n_jobs_cross_fitting (int | None)

  • cv (list[tuple[ndarray, ndarray]] | None)

Return type:

Self

predict_conditional_average_outcomes(X, is_oos, oos_method='overall')

Predict the vectors of conditional average outcomes.

These are defined as \(\mathbb{E}[Y_i(w) | X]\) for each treatment variant \(w\).

If is_oos, an acronym for ‘is out of sample’, is False, the estimates will stem from cross-fitting. Otherwise, various approaches exist, specified via oos_method.

The returned ndarray is of shape:

  • \((n_{obs}, n_{variants}, 1)\) if the outcome is a scalar, i.e. in case of a regression problem.

  • \((n_{obs}, n_{variants}, n_{classes})\) if the outcome is a class, i.e. in case of a classification problem.

Parameters:
  • X (DataFrame | ndarray)

  • is_oos (bool)

  • oos_method (Literal['overall', 'median', 'mean'])

Return type:

ndarray

predict_nuisance(X, model_kind, model_ord, is_oos, oos_method='overall')

Estimate based on a given nuisance model.

Importantly, this method needs to implement the subselection of X based on the feature_set field of MetaLearner.
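The subselection of X by feature_set can be sketched as follows. The helper select_features is hypothetical; it assumes that column names are used for DataFrames and integer positions for arrays, and that a model kind missing from a dict-valued feature_set uses all features. The model kind name in the usage line is also an assumption.

```python
import numpy as np


def select_features(X, feature_set, model_kind):
    """Restrict X to the columns configured for `model_kind`."""
    if isinstance(feature_set, dict):
        # Assumption: a model kind missing from the dict uses all features.
        feature_set = feature_set.get(model_kind)
    if feature_set is None:
        return X
    columns = list(feature_set)
    if hasattr(X, "columns"):  # pandas DataFrame: select by column name
        return X[columns]
    return np.asarray(X)[:, columns]  # ndarray: select by integer position


X = np.arange(12).reshape(3, 4)
outcome_X = select_features(X, {"variant_outcome_model": [0, 2]}, "variant_outcome_model")
```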

Parameters:
  • X (DataFrame | ndarray)

  • model_kind (str)

  • model_ord (int)

  • is_oos (bool)

  • oos_method (Literal['overall', 'median', 'mean'])

Return type:

ndarray

predict_treatment(X, model_kind, model_ord, is_oos, oos_method='overall')

Estimate based on a given treatment model.

Importantly, this method needs to implement the subselection of X based on the feature_set field of MetaLearner.

Parameters:
  • X (DataFrame | ndarray)

  • model_kind (str)

  • model_ord (int)

  • is_oos (bool)

  • oos_method (Literal['overall', 'median', 'mean'])

Return type:

ndarray

shap_values(X, shap_explainer_factory, shap_explainer_params=None, explainer=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)

Calculates the shap values for each treatment group.

If explainer is None, a new Explainer is created using explainer() with the passed parameters. If explainer is not None, then the parameters cate_estimates, cate_model_factory and cate_model_params are ignored.

The parameter shap_explainer_factory specifies the type of shap explainer; for the available options, see the shap documentation.

The returned list contains the shap values of each treatment variant, ordered by ascending treatment variant.

Parameters:
  • X (DataFrame | ndarray)

  • shap_explainer_factory (type[Explainer])

  • shap_explainer_params (dict | None)

  • explainer (Explainer | None)

  • cate_estimates (ndarray | None)

  • cate_model_factory (type[_ScikitModel] | None)

  • cate_model_params (Mapping[str, int | float | str] | None)

Return type:

list[ndarray]