metalearners.utils module
- metalearners.utils.metalearner_factory(metalearner_prefix)
Returns the MetaLearner class corresponding to the given prefix.
The accepted metalearner_prefix values are:
- Parameters:
metalearner_prefix (str)
- Return type:
type[MetaLearner]
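The return type is type[MetaLearner], so the function hands back the class itself rather than an instance. The caller instantiates it. The following is a minimal, hypothetical sketch of that prefix-to-class dispatch pattern. The classes MetaLearner, TLearner and SLearner below are local placeholder stand-ins, and the prefixes "T" and "S" are assumed purely for illustration. They are not taken from the library's list of accepted values.

```python
# Hypothetical sketch of a prefix -> class factory in the spirit of
# metalearner_factory. All classes and prefixes here are illustrative stand-ins.


class MetaLearner:
    """Placeholder base class standing in for the library's MetaLearner."""


class TLearner(MetaLearner):
    """Placeholder subclass (assumed name, for illustration only)."""


class SLearner(MetaLearner):
    """Placeholder subclass (assumed name, for illustration only)."""


# Assumed mapping; the real accepted prefixes are defined by the library.
_PREFIX_TO_CLASS: dict[str, type[MetaLearner]] = {
    "T": TLearner,
    "S": SLearner,
}


def metalearner_factory_sketch(metalearner_prefix: str) -> type[MetaLearner]:
    """Return the class (not an instance) registered for ``metalearner_prefix``."""
    try:
        return _PREFIX_TO_CLASS[metalearner_prefix]
    except KeyError:
        raise ValueError(
            f"No MetaLearner registered for prefix {metalearner_prefix!r}; "
            f"accepted values are {sorted(_PREFIX_TO_CLASS)}."
        ) from None


# Usage: look up the class, then instantiate it yourself.
learner_cls = metalearner_factory_sketch("T")
learner = learner_cls()
print(learner_cls.__name__)  # TLearner
```

A dictionary lookup keeps the set of accepted prefixes in one place, and re-raising as ValueError with the accepted keys gives a clearer message than a bare KeyError.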
- metalearners.utils.simplify_output(tensor)
Reduces dimensions of a CATE estimation tensor if possible.
The returned result has shape:
\((n_{obs})\) if there are 2 treatment variants and the outcome is either a regression outcome or a binary classification outcome.
\((n_{obs}, n_{classes})\) if there are 2 treatment variants and the outcome is a classification outcome with at least 3 classes.
\((n_{obs}, n_{variants} - 1)\) if there are at least 3 treatment variants and the outcome is either a regression outcome or a binary classification outcome.
\((n_{obs}, n_{variants} - 1, n_{classes})\) if there are at least 3 treatment variants and the outcome is a classification outcome with at least 3 classes.
- Parameters:
tensor (ndarray)
- Return type:
ndarray
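The four shape rules above can be illustrated with a small NumPy sketch. This is a hypothetical reimplementation, not the library's code. It assumes that the input tensor is 3-dimensional with shape \((n_{obs}, n_{variants} - 1, n_{outputs})\), where \(n_{outputs}\) is 1 for regression and \(n_{classes}\) for classification. It also assumes that, for a binary outcome, only the positive-class slice (index 1) is kept. Both assumptions are illustrative.

```python
import numpy as np


def simplify_output_sketch(tensor: np.ndarray) -> np.ndarray:
    """Squeeze a CATE tensor following the four shape rules (hypothetical sketch).

    Assumes ``tensor`` has shape (n_obs, n_variants - 1, n_outputs), where
    n_outputs is 1 for regression and n_classes for classification.
    """
    if tensor.ndim != 3:
        raise ValueError(f"Expected a 3-dimensional tensor, got {tensor.ndim} dimensions.")
    n_obs, n_contrasts, n_outputs = tensor.shape
    if n_outputs == 2:
        # Binary classification (assumption): keep only the positive-class slice,
        # so it is treated like a scalar outcome below.
        tensor = tensor[:, :, 1:]
        n_outputs = 1
    if n_contrasts == 1 and n_outputs == 1:
        return tensor.reshape(n_obs)  # 2 variants, regression or binary outcome
    if n_contrasts == 1:
        return tensor.reshape(n_obs, n_outputs)  # 2 variants, >= 3 classes
    if n_outputs == 1:
        return tensor.reshape(n_obs, n_contrasts)  # >= 3 variants, regression or binary
    return tensor  # >= 3 variants, >= 3 classes: nothing to drop


# Usage: one call per case from the list above.
print(simplify_output_sketch(np.zeros((5, 1, 1))).shape)  # (5,)
print(simplify_output_sketch(np.zeros((5, 1, 3))).shape)  # (5, 3)
print(simplify_output_sketch(np.zeros((5, 2, 1))).shape)  # (5, 2)
print(simplify_output_sketch(np.zeros((5, 2, 3))).shape)  # (5, 2, 3)
```

Reducing the binary case to a single column first lets it share the reshaping branches with the regression case. This mirrors how the documented rules group "regression or binary classification" together.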