metalearners.tlearner module
- class metalearners.tlearner.TLearner(is_classification, n_variants, nuisance_model_factory=None, treatment_model_factory=None, propensity_model_factory=None, nuisance_model_params=None, treatment_model_params=None, propensity_model_params=None, fitted_nuisance_models=None, fitted_propensity_model=None, feature_set=None, n_folds=10, random_state=None)
Bases: _ConditionalAverageOutcomeMetaLearner
T-Learner for CATE estimation as described by Kuenzel et al. (2019).
- Parameters:
is_classification (bool)
n_variants (int)
nuisance_model_factory (type[_ScikitModel] | dict[str, type[_ScikitModel]] | None)
treatment_model_factory (type[_ScikitModel] | dict[str, type[_ScikitModel]] | None)
propensity_model_factory (type[_ScikitModel] | None)
nuisance_model_params (Mapping[str, int | float | str] | dict[str, Mapping[str, int | float | str]] | None)
treatment_model_params (Mapping[str, int | float | str] | dict[str, Mapping[str, int | float | str]] | None)
propensity_model_params (Mapping[str, int | float | str] | None)
fitted_nuisance_models (dict[str, list[CrossFitEstimator]] | None)
fitted_propensity_model (CrossFitEstimator | None)
feature_set (Collection[str] | Collection[int] | dict[str, Collection[str] | Collection[int]] | None)
n_folds (int | dict[str, int])
random_state (int | None)
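As a conceptual illustration of what a T-Learner does, the sketch below fits one outcome model per treatment variant, using only the units that received that variant, and takes the difference of each variant's prediction and the control prediction as the CATE. It uses plain NumPy least squares as a stand-in base model and has no cross-fitting, propensity model, or feature_set handling, so it is not the library's implementation. The helper names (fit_linear, predict_linear, t_learner_cate) are hypothetical.

```python
import numpy as np


def fit_linear(X, y):
    """Ordinary least squares with an intercept; returns the coefficient vector."""
    Xb = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(Xb, y, rcond=None)
    return coef


def predict_linear(coef, X):
    """Predict with coefficients produced by fit_linear."""
    Xb = np.column_stack([np.ones(len(X)), X])
    return Xb @ coef


def t_learner_cate(X, y, w, n_variants):
    """Fit one outcome model per variant and return CATEs vs control.

    The output shape mirrors the regression case of TLearner.predict:
    (n_obs, n_variants - 1, 1).
    """
    # One outcome model per treatment variant, trained only on that variant's units.
    models = [fit_linear(X[w == k], y[w == k]) for k in range(n_variants)]
    # Predict every unit's outcome under every variant: shape (n_obs, n_variants).
    mu = np.stack([predict_linear(m, X) for m in models], axis=1)
    # CATE of variant k vs control (variant 0).
    cate = mu[:, 1:] - mu[:, [0]]
    return cate[:, :, np.newaxis]


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
w = rng.integers(0, 2, size=200)
# Synthetic, noise-free outcome whose true treatment effect is 1 + 2 * X[:, 0].
y = X[:, 1] + w * (1 + 2 * X[:, 0])
cate = t_learner_cate(X, y, w, n_variants=2)
print(cate.shape)  # (200, 1, 1)
```

Because the synthetic outcome is linear within each variant, the two linear models recover it exactly and the estimated CATE matches 1 + 2 * X[:, 0] up to floating-point error.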
- classmethod nuisance_model_specifications()
Return the specifications of all first-stage models.
- Return type:
dict[str, _ModelSpecifications]
- classmethod treatment_model_specifications()
Return the specifications of all second-stage models.
- Return type:
dict[str, _ModelSpecifications]
- fit(X, y, w, n_jobs_cross_fitting=None, fit_params=None, synchronize_cross_fitting=True, n_jobs_base_learners=None)
Fit all models of the MetaLearner.
If pre-fitted models were passed at instantiation, these are never refitted.
n_jobs_cross_fitting will be used at the cross-fitting level and n_jobs_base_learners will be used at the stage level. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. For more information about parallelism, check "What about parallelism?".
fit_params is an optional dict to be forwarded to base estimator fit calls. It supports two usage patterns:
fit_params={"parameter_of_interest": value_of_interest}
fit_params={
    "nuisance": {
        "nuisance_model_kind1": {"parameter_of_interest1": value_of_interest1},
        "nuisance_model_kind3": {"parameter_of_interest3": value_of_interest3},
    },
    "treatment": {"treatment_model_kind1": {"parameter_of_interest4": value_of_interest4}},
}
In the former approach, the parameter and value of interest are passed to all base models. In the latter approach, the explicitly qualified parameter-value pairs are passed to the respective base models, and no fitting parameters are passed to base models that are not explicitly listed. Note that in this pattern, propensity models are considered nuisance models.
synchronize_cross_fitting indicates whether the learning of different base models should use exactly the same data splits where possible. Note that if several of the models to be synchronized are classifiers, these cannot be split via stratification.
- Parameters:
X (DataFrame | ndarray)
y (Series | ndarray)
w (Series | ndarray)
n_jobs_cross_fitting (int | None)
fit_params (dict | None)
synchronize_cross_fitting (bool)
n_jobs_base_learners (int | None)
- Return type:
Self
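The two fit_params patterns described above can be sketched as a small lookup that returns the keyword arguments one base model would receive. The helper resolve_fit_params is hypothetical, and so is the rule it uses to tell the patterns apart (top-level keys that are all stage names mean the qualified pattern); the model kind names reuse the placeholders from the text.

```python
STAGES = ("nuisance", "treatment")


def resolve_fit_params(fit_params, stage, model_kind):
    """Return the fit kwargs a base model of `model_kind` in `stage` would receive.

    Hypothetical helper illustrating the two documented usage patterns.
    Propensity models count as nuisance models, so they use stage="nuisance".
    """
    if not fit_params:
        return {}
    # Assumption: the qualified pattern is recognized by its top-level keys
    # being stage names ("nuisance" and/or "treatment").
    if set(fit_params) <= set(STAGES):
        # Models not explicitly listed receive no fitting parameters.
        return dict(fit_params.get(stage, {}).get(model_kind, {}))
    # Flat pattern: the same parameters are forwarded to every base model.
    return dict(fit_params)


flat = {"verbose": 0}
qualified = {
    "nuisance": {"nuisance_model_kind1": {"verbose": 1}},
    "treatment": {"treatment_model_kind1": {"verbose": 2}},
}
print(resolve_fit_params(flat, "treatment", "treatment_model_kind1"))  # {'verbose': 0}
print(resolve_fit_params(qualified, "nuisance", "nuisance_model_kind1"))  # {'verbose': 1}
print(resolve_fit_params(qualified, "nuisance", "propensity_model"))  # {}
```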
- predict(X, is_oos, oos_method='overall')
Estimate the CATE.
If is_oos, an acronym for ‘is out of sample’, is False, the estimates will stem from cross-fitting. Otherwise, various approaches exist, specified via oos_method.
The returned ndarray is of shape:
\((n_{obs}, n_{variants} - 1, 1)\) if the outcome is a scalar, i.e. in case of a regression problem.
\((n_{obs}, n_{variants} - 1, n_{classes})\) if the outcome is a class, i.e. in case of a classification problem.
In the case of multiple treatment variants, the second dimension represents the CATE of the corresponding variant vs the control (variant 0).
- Parameters:
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'])
- Return type:
ndarray
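To read a single treatment effect out of the array returned by predict(), index the second dimension with variant - 1, since the control (variant 0) has no slot of its own. The sketch below builds a stand-in array with the documented regression shape; cate_vs_control is a hypothetical helper, not part of the library.

```python
import numpy as np

n_obs, n_variants = 4, 3
# Stand-in for the ndarray returned by predict() in the regression case:
# shape (n_obs, n_variants - 1, 1).
cate = np.arange(n_obs * (n_variants - 1), dtype=float).reshape(n_obs, n_variants - 1, 1)


def cate_vs_control(cate, variant):
    """Return the 1-D CATE of `variant` (>= 1) vs control for a regression outcome."""
    if not 1 <= variant <= cate.shape[1]:
        raise ValueError("variant must be in 1..n_variants-1")
    # Slot variant - 1 holds "variant vs variant 0". For classification, the last
    # axis would index classes instead of being a single entry.
    return cate[:, variant - 1, 0]


print(cate_vs_control(cate, 2))  # [1. 3. 5. 7.]
```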
- evaluate(X, y, w, is_oos, oos_method='overall', scoring=None)
Evaluate the MetaLearner.
Keys in scoring that are not the name of a model contained in the MetaLearner are ignored; for information about these names, check nuisance_model_specifications() and treatment_model_specifications(). The values must be a list whose elements are either a string representing a sklearn scoring method (see the scikit-learn documentation on scoring parameters for the possible values) or a Callable with signature scorer(estimator, X, y_true, **kwargs). We recommend using sklearn.metrics.make_scorer to create such a Callable.
If a model name is not present in the keys of scoring, the default metric is neg_log_loss if the model is a classifier and neg_root_mean_squared_error if it is a regressor.
The returned dictionary keys have the following structure:
For nuisance models:
If the cardinality is one: f"{model_kind}_{scorer}"
If there is one model for each treatment variant (including control): f"{model_kind}_{treatment_variant}_{scorer}"
For treatment models:
f"{model_kind}_{treatment_variant}_vs_0_{scorer}"
where scorer is the name of the scorer if it is a string and "custom_scorer_{idx}" if it is a callable, with idx being the index in the list of scorers.
- Parameters:
X (DataFrame | ndarray)
y (Series | ndarray)
w (Series | ndarray)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'])
scoring (Mapping[str, Sequence[str | Callable]] | None)
- Return type:
dict[str, float]
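The key patterns listed above can be expanded mechanically. The sketch below does so for one model kind at a time; the helper names and the model kind names outcome_model and cate_model are hypothetical, and the order of keys in the dictionary actually returned by evaluate() may differ.

```python
def scorer_names(scorers):
    """Strings keep their name; callables become custom_scorer_{idx}."""
    return [s if isinstance(s, str) else f"custom_scorer_{idx}" for idx, s in enumerate(scorers)]


def nuisance_keys(model_kind, cardinality, n_variants, scorers):
    """Keys for a nuisance model with cardinality 1 or one model per variant."""
    names = scorer_names(scorers)
    if cardinality == 1:
        return [f"{model_kind}_{name}" for name in names]
    # One model per treatment variant, including the control (variant 0).
    return [f"{model_kind}_{v}_{name}" for v in range(n_variants) for name in names]


def treatment_keys(model_kind, n_variants, scorers):
    """Keys for treatment models: one per non-control variant, compared with variant 0."""
    names = scorer_names(scorers)
    return [f"{model_kind}_{v}_vs_0_{name}" for v in range(1, n_variants) for name in names]


def my_scorer(estimator, X, y_true, **kwargs):
    """Placeholder custom scorer with the documented signature."""
    return 0.0


print(nuisance_keys("outcome_model", cardinality=2, n_variants=2, scorers=["r2", my_scorer]))
print(treatment_keys("cate_model", n_variants=3, scorers=["neg_root_mean_squared_error"]))
```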
- explainer(X=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)
Create an Explainer which can be used in feature_importances().
This function can be used in two distinct ways, depending on the provided parameters:
When X, cate_estimates, and cate_model_factory are all set to None, the function creates an Explainer using the pre-existing treatment models. If these models do not exist, it raises a ValueError.
Conversely, if X, cate_estimates, and cate_model_factory are not None, the function initializes an instance of the Explainer class using these parameters. This instance then fits new models for each treatment variant, and these models are used to calculate the feature importances.
- Parameters:
X (DataFrame | ndarray | None)
cate_estimates (ndarray | None)
cate_model_factory (type[_ScikitModel] | None)
cate_model_params (Mapping[str, int | float | str] | None)
- Return type:
Explainer
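The two construction paths of explainer() can be summarized as a small decision function. This is a hypothetical sketch of the documented behaviour only; in particular, rejecting a partially specified set of arguments is an assumption, since the text above does not cover that case.

```python
def explainer_mode(X=None, cate_estimates=None, cate_model_factory=None, has_treatment_models=True):
    """Return which documented construction path explainer() would take (hypothetical)."""
    given = [arg is not None for arg in (X, cate_estimates, cate_model_factory)]
    if not any(given):
        # Path 1: reuse the existing treatment models, if there are any.
        if not has_treatment_models:
            raise ValueError("No treatment models available; pass X, cate_estimates and cate_model_factory.")
        return "reuse_treatment_models"
    if all(given):
        # Path 2: fit new CATE models, one per treatment variant.
        return "fit_new_cate_models"
    # Assumption: mixing None and non-None arguments is treated as an error.
    raise ValueError("X, cate_estimates and cate_model_factory must be passed together.")


print(explainer_mode())  # reuse_treatment_models
print(explainer_mode(X=[[0.0]], cate_estimates=[0.0], cate_model_factory=object))  # fit_new_cate_models
```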
- feature_importances(feature_names=None, normalize=False, sort_values=False, explainer=None, X=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)
Calculates the feature importance for each treatment group.
If explainer is None, a new Explainer is created using explainer() with the passed parameters. If explainer is not None, the parameters X, cate_estimates, cate_model_factory and cate_model_params are ignored.
If normalize=True, the feature importances for each treatment variant are normalized so that they sum to 1.
feature_names is optional; if it is not passed, the feature names default to f"Feature {i}", where i is the corresponding feature index.
The returned list contains the feature importances for each treatment variant in ascending order.
- Parameters:
feature_names (Collection[str] | None)
normalize (bool)
sort_values (bool)
explainer (Explainer | None)
X (DataFrame | ndarray | None)
cate_estimates (ndarray | None)
cate_model_factory (type[_ScikitModel] | None)
cate_model_params (Mapping[str, int | float | str] | None)
- Return type:
list[Series]
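The normalization and default-naming rules above can be sketched in plain Python. The helper format_importances is hypothetical and returns one dict per treatment variant instead of the pandas Series the method returns; it assumes each variant's raw importances sum to a nonzero value when normalize=True.

```python
def format_importances(importances, feature_names=None, normalize=False):
    """Return one {feature_name: importance} dict per treatment variant.

    importances: list with one sequence of per-feature scores per treatment variant.
    """
    result = []
    for scores in importances:
        # Default names follow the documented f"Feature {i}" pattern.
        if feature_names is not None:
            names = list(feature_names)
        else:
            names = [f"Feature {i}" for i in range(len(scores))]
        values = list(scores)
        if normalize:
            # Rescale so this variant's importances sum to 1 (assumes a nonzero total).
            total = sum(values)
            values = [v / total for v in values]
        result.append(dict(zip(names, values)))
    return result


print(format_importances([[1.0, 3.0]], normalize=True))  # [{'Feature 0': 0.25, 'Feature 1': 0.75}]
```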
- fit_nuisance(X, y, model_kind, model_ord, fit_params=None, n_jobs_cross_fitting=None, cv=None)
Fit a given nuisance model of a MetaLearner.
y represents the objective of the given nuisance model, not necessarily the outcome of the experiment. If pre-fitted models were passed at instantiation, these are never refitted.
- Parameters:
X (DataFrame | ndarray)
y (Series | ndarray)
model_kind (str)
model_ord (int)
fit_params (dict | None)
n_jobs_cross_fitting (int | None)
cv (list[tuple[ndarray, ndarray]] | None)
- Return type:
Self
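The cv parameter is typed as a list of pairs of index arrays. The sketch below builds such a list by hand, assuming the pairs follow the scikit-learn convention of (train_indices, test_indices); make_cv is a hypothetical helper. If scikit-learn is installed, list(KFold(n_splits=k).split(X)) produces the same structure.

```python
import numpy as np


def make_cv(n_obs, n_folds, seed=0):
    """Build a list of (train_indices, test_indices) pairs, one per fold."""
    rng = np.random.default_rng(seed)
    # Shuffle the row indices, then cut them into n_folds nearly equal test folds.
    folds = np.array_split(rng.permutation(n_obs), n_folds)
    cv = []
    for k, test_idx in enumerate(folds):
        # Every row outside the current test fold is used for training.
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])
        cv.append((np.sort(train_idx), np.sort(test_idx)))
    return cv


cv = make_cv(n_obs=10, n_folds=3)
print([len(test) for _, test in cv])  # [4, 3, 3]
```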
- fit_treatment(X, y, model_kind, model_ord, fit_params=None, n_jobs_cross_fitting=None, cv=None)
Fit the treatment model of a MetaLearner.
y represents the objective of the given treatment model, not necessarily the outcome of the experiment.
- Parameters:
X (DataFrame | ndarray)
y (Series | ndarray)
model_kind (str)
model_ord (int)
fit_params (dict | None)
n_jobs_cross_fitting (int | None)
cv (list[tuple[ndarray, ndarray]] | None)
- Return type:
Self
- predict_conditional_average_outcomes(X, is_oos, oos_method='overall')
Predict the vectors of conditional average outcomes.
These are defined as \(\mathbb{E}[Y_i(w) | X]\) for each treatment variant \(w\).
If is_oos, an acronym for ‘is out of sample’, is False, the estimates will stem from cross-fitting. Otherwise, various approaches exist, specified via oos_method.
The returned ndarray is of shape:
\((n_{obs}, n_{variants}, 1)\) if the outcome is a scalar, i.e. in case of a regression problem.
\((n_{obs}, n_{variants}, n_{classes})\) if the outcome is a class, i.e. in case of a classification problem.
- Parameters:
X (DataFrame | ndarray)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'])
- Return type:
ndarray
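A T-Learner's CATE for variant w is the difference between its conditional average outcome and that of the control, so an array with the shape returned here can be turned into one with the shape returned by predict() by subtracting the variant 0 slice. The sketch assumes this relationship holds for the library's outputs; cate_from_cao is a hypothetical helper.

```python
import numpy as np


def cate_from_cao(cao):
    """Map conditional average outcomes (n_obs, n_variants, k) to CATEs (n_obs, n_variants - 1, k).

    Each non-control variant is compared with variant 0; k is 1 for regression
    and n_classes for classification.
    """
    # cao[:, :1, :] keeps the variant axis so it broadcasts against cao[:, 1:, :].
    return cao[:, 1:, :] - cao[:, :1, :]


cao = np.array([[[1.0], [3.0], [6.0]]])  # one observation, three variants, regression
print(cate_from_cao(cao).shape)  # (1, 2, 1)
```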
- predict_nuisance(X, model_kind, model_ord, is_oos, oos_method='overall')
Estimate based on a given nuisance model.
Importantly, this method needs to implement the subselection of X based on the feature_set field of MetaLearner.
- Parameters:
X (DataFrame | ndarray)
model_kind (str)
model_ord (int)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'])
- Return type:
ndarray
- predict_treatment(X, model_kind, model_ord, is_oos, oos_method='overall')
Estimate based on a given treatment model.
Importantly, this method needs to implement the subselection of X based on the feature_set field of MetaLearner.
- Parameters:
X (DataFrame | ndarray)
model_kind (str)
model_ord (int)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'])
- Return type:
ndarray
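The feature_set subselection required by predict_nuisance() and predict_treatment() can be sketched for an ndarray input with integer column indices. The helper select_features is hypothetical, and treating a model kind missing from a dict-valued feature_set as "use all features" is an assumption. For a DataFrame with string column names, X[list(feature_set)] would be the analogous selection.

```python
import numpy as np


def select_features(X, feature_set, model_kind):
    """Subselect the columns of an ndarray X that a given model kind should see.

    feature_set may be None (all columns), a collection of column indices,
    or a dict mapping model_kind to such a collection.
    """
    if isinstance(feature_set, dict):
        # Assumption: a model kind absent from the dict sees all features.
        feature_set = feature_set.get(model_kind)
    if feature_set is None:
        return X
    return X[:, list(feature_set)]


X = np.arange(12).reshape(3, 4)
print(select_features(X, {"outcome_model": [0, 2]}, "outcome_model"))
```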
- shap_values(X, shap_explainer_factory, shap_explainer_params=None, explainer=None, cate_estimates=None, cate_model_factory=None, cate_model_params=None)
Calculates the shap values for each treatment group.
If explainer is None, a new Explainer is created using explainer() with the passed parameters. If explainer is not None, the parameters cate_estimates, cate_model_factory and cate_model_params are ignored.
The parameter shap_explainer_factory can be used to specify the type of shap explainer; for the different options, see the shap library documentation.
The returned list contains the shap values for each treatment variant in ascending order.
- Parameters:
X (DataFrame | ndarray)
shap_explainer_factory (type[Explainer])
shap_explainer_params (dict | None)
explainer (Explainer | None)
cate_estimates (ndarray | None)
cate_model_factory (type[_ScikitModel] | None)
cate_model_params (Mapping[str, int | float | str] | None)
- Return type:
list[ndarray]