metalearners.cross_fit_estimator module
- class metalearners.cross_fit_estimator.CrossFitEstimator(n_folds, estimator_factory, estimator_params=<factory>, enable_overall=True, random_state=None)
Bases: object
Helper class for cross-fitting estimators on data.
Conceptually, it allows for fitting n_folds or n_folds + 1 models on n_folds folds of the data.
estimator_factory is a class implementing an estimator with a scikit-learn interface. Instantiation parameters can be passed to estimator_params. An example argument for estimator_factory would be lightgbm.LGBMRegressor.
Importantly, the CrossFitEstimator can handle in-sample and out-of-sample (‘oos’) data for prediction. When doing in-sample prediction, each data point is predicted by the single model whose training set did not contain it. When doing oos prediction, different options exist. These options rely either on combining the n_folds models or on using a model trained on all of the data (enable_overall).
n_folds can be set to 1 if the user desires to deactivate cross-fitting. In that case, the CrossFitEstimator fits only one overall model, which is used for both in-sample and out-of-sample predictions. Note that this is not recommended, since it can lead to data leakage when doing in-sample predictions.
- Parameters:
n_folds (int)
estimator_factory (type[_ScikitModel])
estimator_params (dict)
enable_overall (bool)
random_state (int | None)
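The in-sample rule from the class description can be illustrated with a minimal, library-independent sketch. It does not use or reproduce the metalearners implementation: MeanRegressor, cross_fit and predict_in_sample are hypothetical names, the toy estimator merely predicts the mean of its training targets, and rows are assigned to folds round-robin rather than at random.

```python
from statistics import mean


class MeanRegressor:
    """Toy stand-in for an estimator: predicts the mean of its training targets."""

    def fit(self, X, y):
        self.mean_ = mean(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def cross_fit(X, y, n_folds):
    """Fit one model per fold; assumes n_folds >= 2 (hypothetical helper)."""
    n = len(X)
    # Round-robin fold assignment keeps the sketch deterministic.
    fold_of_row = [i % n_folds for i in range(n)]
    models = []
    for fold in range(n_folds):
        # The model for this fold trains on every row outside the fold.
        train = [i for i in range(n) if fold_of_row[i] != fold]
        model = MeanRegressor().fit([X[i] for i in train], [y[i] for i in train])
        models.append(model)
    return models, fold_of_row


def predict_in_sample(models, fold_of_row, X):
    """Predict each row with the one model that never trained on it."""
    return [models[fold_of_row[i]].predict([X[i]])[0] for i in range(len(X))]


X = [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]
y = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0]
models, fold_of_row = cross_fit(X, y, n_folds=3)
in_sample = predict_in_sample(models, fold_of_row, X)
```

Every row is predicted by the one model whose training rows exclude it, so no prediction is based on the row's own target value.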
- n_folds: int
- estimator_factory: type[_ScikitModel]
- estimator_params: dict
- enable_overall: bool = True
- random_state: int | None = None
- classes_: ndarray | None
- clone()
Construct a new unfitted CrossFitEstimator with the same init parameters.
- Return type:
- fit(X, y, fit_params=None, n_jobs_cross_fitting=None, cv=None)
Fit the underlying estimators.
One estimator is trained per fold, i.e. n_folds estimators in total.
If enable_overall is set, an additional estimator is trained on all data.
n_jobs_cross_fitting can be used to specify the number of jobs for cross-fitting. For more information see the sklearn glossary.
cv can optionally be passed. If passed, it is expected to be a list of (train_indices, test_indices) tuples indicating how to split the data at hand into train and test/estimation sets for the different folds.
- Parameters:
X (DataFrame | ndarray)
y (Series | ndarray | DataFrame)
fit_params (dict | None)
n_jobs_cross_fitting (int | None)
cv (list[tuple[ndarray, ndarray]] | None)
- Return type:
Self
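The shape expected for cv, a list of (train_indices, test_indices) tuples, can be sketched as follows. This is only an illustration of the structure: make_cv_splits is a hypothetical helper, and it uses plain Python lists where the signature above asks for ndarrays, so a conversion (for example with numpy.asarray) is assumed to be needed before passing the result to fit.

```python
def make_cv_splits(n_rows, n_folds):
    """Return one (train_indices, test_indices) tuple per fold (hypothetical helper)."""
    splits = []
    for fold in range(n_folds):
        # Rows whose index is congruent to `fold` form the test/estimation set.
        test = [i for i in range(n_rows) if i % n_folds == fold]
        train = [i for i in range(n_rows) if i % n_folds != fold]
        splits.append((train, test))
    return splits


cv = make_cv_splits(n_rows=6, n_folds=3)
first_train, first_test = cv[0]
```

In this sketch each row index lands in exactly one test set and never appears in the train set of the same fold.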
- predict(X, is_oos, oos_method=None, **kwargs)
Predict from X.
If is_oos, the oos_method will be used to generate predictions on ‘out of sample’ data. ‘Out of sample’ refers to this data not having been used in the fit method. The oos_method 'overall' can only be used if the CrossFitEstimator has been initialized with enable_overall=True.
- Parameters:
X (DataFrame | ndarray)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'] | None)
- Return type:
ndarray
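What the three oos_method values mean can be sketched row by row. This is an assumption-laden illustration and not the library's code: combine_oos is a hypothetical function, and the per-fold predictions are hard-coded numbers instead of coming from fitted models.

```python
from statistics import mean, median


def combine_oos(fold_predictions, oos_method, overall_predictions=None):
    """Combine per-fold predictions for out-of-sample rows (hypothetical helper).

    fold_predictions holds one list of predictions per fold model,
    all of the same length.
    """
    if oos_method == "overall":
        # Mirrors the documented constraint: 'overall' needs the extra model.
        if overall_predictions is None:
            raise ValueError("'overall' needs a model trained on all data")
        return list(overall_predictions)
    if oos_method == "mean":
        return [mean(row) for row in zip(*fold_predictions)]
    if oos_method == "median":
        return [median(row) for row in zip(*fold_predictions)]
    raise ValueError(f"unknown oos_method: {oos_method!r}")


# Three fold models, two out-of-sample rows.
fold_predictions = [[1.0, 10.0], [2.0, 20.0], [6.0, 60.0]]
mean_pred = combine_oos(fold_predictions, "mean")
median_pred = combine_oos(fold_predictions, "median")
```

With these numbers the median is less affected by the third fold model's larger predictions than the mean is.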
- predict_proba(X, is_oos, oos_method=None)
Predict probability from X.
If is_oos, the oos_method will be used to generate predictions on ‘out of sample’ data. ‘Out of sample’ refers to this data not having been used in the fit method. The oos_method 'overall' can only be used if the CrossFitEstimator has been initialized with enable_overall=True.
- Parameters:
X (DataFrame | ndarray)
is_oos (bool)
oos_method (Literal['overall', 'median', 'mean'] | None)
- Return type:
ndarray
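For probabilities, the choice between combining by mean and by median has a purely arithmetic consequence worth keeping in mind. The sketch below uses made-up class probabilities from three hypothetical fold models for a single row. It says nothing about how the library itself post-processes its output, which is not stated here.

```python
from statistics import mean, median

# One probability vector per fold model, for a single row with three classes.
fold_probabilities = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.2, 0.7],
]

# Combine class by class across the fold models.
mean_proba = [mean(col) for col in zip(*fold_probabilities)]
median_proba = [median(col) for col in zip(*fold_probabilities)]

mean_total = sum(mean_proba)
median_total = sum(median_proba)
```

The class-wise mean of valid probability vectors always sums to one, while the class-wise median need not. In this example it sums to roughly 0.8.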