metalearners.utils module

metalearners.utils.metalearner_factory(metalearner_prefix)

Returns the MetaLearner class corresponding to the given prefix.

The accepted metalearner_prefix values are:

Parameters:

metalearner_prefix (str)

Return type:

type[MetaLearner]
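The return type is a class (type[MetaLearner]), not an instance, so the caller instantiates the returned class. The sketch below shows this prefix-to-class lookup pattern in isolation. It is not the library's implementation. The classes MetaLearner, FirstLearner and SecondLearner, the prefixes "A" and "B", the function name metalearner_factory_sketch, and the ValueError raised on an unknown prefix are all hypothetical placeholders. They do not reflect the prefixes the real function accepts or how it handles errors.

```python
class MetaLearner:
    """Hypothetical stand-in for a MetaLearner base class."""


class FirstLearner(MetaLearner):
    """Hypothetical concrete learner."""


class SecondLearner(MetaLearner):
    """Hypothetical concrete learner."""


# Hypothetical registry: these prefixes and classes are placeholders,
# not the values accepted by metalearners.utils.metalearner_factory.
_REGISTRY: dict[str, type[MetaLearner]] = {
    "A": FirstLearner,
    "B": SecondLearner,
}


def metalearner_factory_sketch(metalearner_prefix: str) -> type[MetaLearner]:
    """Return the class (not an instance) registered under the prefix."""
    try:
        return _REGISTRY[metalearner_prefix]
    except KeyError:
        # Assumed behaviour for unknown prefixes in this sketch.
        raise ValueError(
            f"Unknown prefix {metalearner_prefix!r}; "
            f"expected one of {sorted(_REGISTRY)}."
        ) from None


learner_cls = metalearner_factory_sketch("A")
learner = learner_cls()  # the caller instantiates the returned class
print(learner_cls.__name__)  # FirstLearner
```

Returning the class rather than an instance leaves construction, including any constructor arguments, to the caller.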

metalearners.utils.simplify_output(tensor)

Reduces dimensions of a CATE estimation tensor if possible.

The returned result has one of the following shapes:

  • \((n_{obs})\) if there are 2 treatment variants and the outcome is either a regression outcome or a binary classification outcome.

  • \((n_{obs}, n_{classes})\) if there are 2 treatment variants and the outcome is a classification outcome with at least 3 classes.

  • \((n_{obs}, n_{variants} - 1)\) if there are at least 3 variants and the outcome is either a regression outcome or a binary classification outcome.

  • \((n_{obs}, n_{variants} - 1, n_{classes})\) if there are at least 3 variants and the outcome is a classification outcome with at least 3 classes.

Parameters:

tensor (ndarray)

Return type:

ndarray
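The shape rules above can be illustrated with a short NumPy sketch. This is an illustrative reimplementation, not the library's code, and it rests on assumptions the documentation does not state. First, the input has shape \((n_{obs}, n_{variants} - 1, n_{outputs})\), where \(n_{outputs}\) is 1 for a regression outcome and \(n_{classes}\) for a classification outcome. Second, for a binary classification outcome only the column at class index 1 is kept. The name simplify_output_sketch is hypothetical.

```python
import numpy as np


def simplify_output_sketch(tensor: np.ndarray) -> np.ndarray:
    """Drop redundant axes from a CATE estimation tensor.

    Assumes ``tensor`` has shape (n_obs, n_variants - 1, n_outputs), with
    n_outputs == 1 for regression and n_outputs == n_classes for
    classification.
    """
    if tensor.ndim != 3:
        raise ValueError(f"Expected a 3-dimensional tensor, got {tensor.ndim}.")
    n_obs, n_contrasts, n_outputs = tensor.shape
    # Binary classification: keep only class index 1 (an assumption),
    # which makes it behave like a single-output regression estimate.
    if n_outputs == 2:
        tensor = tensor[:, :, 1:]
        n_outputs = 1
    if n_contrasts == 1 and n_outputs == 1:
        return tensor.reshape(n_obs)  # (n_obs,)
    if n_contrasts == 1:
        return tensor.reshape(n_obs, n_outputs)  # (n_obs, n_classes)
    if n_outputs == 1:
        return tensor.reshape(n_obs, n_contrasts)  # (n_obs, n_variants - 1)
    return tensor  # (n_obs, n_variants - 1, n_classes)


rng = np.random.default_rng(0)
print(simplify_output_sketch(rng.normal(size=(5, 1, 1))).shape)  # (5,)
print(simplify_output_sketch(rng.normal(size=(5, 1, 3))).shape)  # (5, 3)
print(simplify_output_sketch(rng.normal(size=(5, 2, 1))).shape)  # (5, 2)
print(simplify_output_sketch(rng.normal(size=(5, 2, 3))).shape)  # (5, 2, 3)
```

Slicing with 1: rather than indexing with 1 keeps the class axis, so the binary case falls through to the same reshape branches as a regression outcome.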