diff --git a/causalml/inference/meta/drlearner.py b/causalml/inference/meta/drlearner.py index ea51b234..a5b051dc 100644 --- a/causalml/inference/meta/drlearner.py +++ b/causalml/inference/meta/drlearner.py @@ -235,16 +235,17 @@ def predict( X, treatment, y = convert_pd_to_np(X, treatment, y) te = np.zeros((X.shape[0], self.t_groups.shape[0])) - yhat_cs = {} yhat_ts = {} + # models_mu_c is fold-specific but not group-specific; predict once and reuse. + yhat_c = np.r_[[model.predict(X) for model in self.models_mu_c]].mean(axis=0) + # Shared-reference dict preserves the public yhat_cs[group] API cheaply. + yhat_cs = {group: yhat_c for group in self.t_groups} + for i, group in enumerate(self.t_groups): models_tau = self.models_tau[group] _te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0) te[:, i] = np.ravel(_te) - yhat_cs[group] = np.r_[ - [model.predict(X) for model in self.models_mu_c] - ].mean(axis=0) yhat_ts[group] = np.r_[ [model.predict(X) for model in self.models_mu_t[group]] ].mean(axis=0) @@ -256,7 +257,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = yhat_cs[group][mask][w == 0] + yhat[w == 0] = yhat_c[mask][w == 0] yhat[w == 1] = yhat_ts[group][mask][w == 1] logger.info("Error metrics for group {}".format(group)) @@ -595,16 +596,18 @@ def predict( X, treatment, y = convert_pd_to_np(X, treatment, y) te = np.zeros((X.shape[0], self.t_groups.shape[0])) - yhat_cs = {} yhat_ts = {} + # models_mu_c is fold-specific but not group-specific; predict once and reuse. + yhat_c = np.r_[ + [model.predict_proba(X)[:, 1] for model in self.models_mu_c] + ].mean(axis=0) + yhat_cs = {group: yhat_c for group in self.t_groups} + for i, group in enumerate(self.t_groups): models_tau = self.models_tau[group] _te = np.r_[[model.predict(X) for model in models_tau]].mean(axis=0) te[:, i] = np.ravel(_te) - yhat_cs[group] = np.r_[ - [model.predict_proba(X)[:, 1] for model in self.models_mu_c] - ].mean(axis=0) yhat_ts[group] = np.r_[ [model.predict_proba(X)[:, 1] for model in self.models_mu_t[group]] ].mean(axis=0) @@ -616,7 +619,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = yhat_cs[group][mask][w == 0] + yhat[w == 0] = yhat_c[mask][w == 0] yhat[w == 1] = yhat_ts[group][mask][w == 1] logger.info("Error metrics for group {}".format(group)) diff --git a/causalml/inference/meta/slearner.py b/causalml/inference/meta/slearner.py index 6891f02d..a5df639c 100644 --- a/causalml/inference/meta/slearner.py +++ b/causalml/inference/meta/slearner.py @@ -106,16 +106,15 @@ def predict( yhat_cs = {} yhat_ts = {} + # Build the augmented arrays once; they are identical for every group. + # (Separate allocations avoid in-place mutation by learners like CatBoost + # that set the writeable flag to False on arrays passed to predict().) + X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X)) + X_new_t = np.hstack((np.ones((X.shape[0], 1)), X)) + for group in self.t_groups: model = self.models[group] - - # Build separate arrays for control and treatment to avoid in-place - # mutation, which fails when learners like CatBoost set the - # writeable flag to False on arrays passed to predict(). - X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X)) yhat_cs[group] = model.predict(X_new_c) - - X_new_t = np.hstack((np.ones((X.shape[0], 1)), X)) yhat_ts[group] = model.predict(X_new_t) if (y is not None) and (treatment is not None) and verbose: @@ -344,16 +343,13 @@ def predict( yhat_cs = {} yhat_ts = {} + # Build the augmented arrays once; they are identical for every group. + X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X)) + X_new_t = np.hstack((np.ones((X.shape[0], 1)), X)) + for group in self.t_groups: model = self.models[group] - - # Build separate arrays for control and treatment to avoid in-place - # mutation, which fails when learners like CatBoost set the - # writeable flag to False on arrays passed to predict(). - X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X)) yhat_cs[group] = model.predict_proba(X_new_c)[:, 1] - - X_new_t = np.hstack((np.ones((X.shape[0], 1)), X)) yhat_ts[group] = model.predict_proba(X_new_t)[:, 1] if y is not None and (treatment is not None) and verbose: diff --git a/causalml/inference/meta/tlearner.py b/causalml/inference/meta/tlearner.py index 04ca796f..7d0bc02d 100644 --- a/causalml/inference/meta/tlearner.py +++ b/causalml/inference/meta/tlearner.py @@ -55,6 +55,9 @@ def __init__( else: self.model_c = control_learner + # Preserve the unfitted template so repeated fit() calls always start fresh. + self._model_c_template = self.model_c + if treatment_learner is None: self.model_t = deepcopy(learner) else: @@ -82,18 +85,20 @@ def fit(self, X, treatment, y, p=None): self.t_groups = np.unique(treatment[treatment != self.control_name]) self.t_groups.sort() self._classes = {group: i for i, group in enumerate(self.t_groups)} - self.models_c = {group: deepcopy(self.model_c) for group in self.t_groups} self.models_t = {group: deepcopy(self.model_t) for group in self.t_groups} - for group in self.t_groups: - mask = (treatment == group) | (treatment == self.control_name) - treatment_filt = treatment[mask] - X_filt = X[mask] - y_filt = y[mask] - w = (treatment_filt == group).astype(int) + # model_c is trained on the control group, which is identical for every + # treatment group, so fit it once. Deepcopy from the unfitted template so + # re-calling fit() always starts from a clean state (safe with warm_start). + control_mask = treatment == self.control_name + self.model_c = deepcopy(self._model_c_template) + self.model_c.fit(X[control_mask], y[control_mask]) + # Expose as a shared-reference dict to preserve the public models_c API. + self.models_c = {group: self.model_c for group in self.t_groups} - self.models_c[group].fit(X_filt[w == 0], y_filt[w == 0]) - self.models_t[group].fit(X_filt[w == 1], y_filt[w == 1]) + for group in self.t_groups: + treatment_mask = treatment == group + self.models_t[group].fit(X[treatment_mask], y[treatment_mask]) def predict( self, X, treatment=None, y=None, p=None, return_components=False, verbose=True @@ -110,14 +115,15 @@ def predict( (numpy.ndarray): Predictions of treatment effects. """ X, treatment, y = convert_pd_to_np(X, treatment, y) - yhat_cs = {} yhat_ts = {} + yhat_c = self.model_c.predict(X) + # Build a shared-reference dict so return_components callers keep the + # yhat_cs[group] indexing API without duplicating the underlying array. + yhat_cs = {group: yhat_c for group in self.t_groups} + for group in self.t_groups: - model_c = self.models_c[group] - model_t = self.models_t[group] - yhat_cs[group] = model_c.predict(X) - yhat_ts[group] = model_t.predict(X) + yhat_ts[group] = self.models_t[group].predict(X) if (y is not None) and (treatment is not None) and verbose: mask = (treatment == group) | (treatment == self.control_name) @@ -126,7 +132,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = yhat_cs[group][mask][w == 0] + yhat[w == 0] = yhat_c[mask][w == 0] yhat[w == 1] = yhat_ts[group][mask][w == 1] logger.info("Error metrics for group {}".format(group)) @@ -134,7 +140,7 @@ def predict( te = np.zeros((X.shape[0], self.t_groups.shape[0])) for i, group in enumerate(self.t_groups): - te[:, i] = yhat_ts[group] - yhat_cs[group] + te[:, i] = yhat_ts[group] - yhat_c if not return_components: return te @@ -178,7 +184,7 @@ def fit_predict( else: t_groups_global = self.t_groups _classes_global = self._classes - models_c_global = deepcopy(self.models_c) + model_c_global = deepcopy(self.model_c) models_t_global = deepcopy(self.models_t) te_bootstraps = np.zeros( shape=(X.shape[0], self.t_groups.shape[0], n_bootstraps) @@ -197,7 +203,8 @@ def fit_predict( # set member variables back to global (currently last bootstrapped outcome) self.t_groups = t_groups_global self._classes = _classes_global - self.models_c = deepcopy(models_c_global) + self.model_c = deepcopy(model_c_global) + self.models_c = {group: self.model_c for group in self.t_groups} self.models_t = deepcopy(models_t_global) return (te, te_lower, te_upper) @@ -271,7 +278,7 @@ def estimate_ate( else: t_groups_global = self.t_groups _classes_global = self._classes - models_c_global = deepcopy(self.models_c) + model_c_global = deepcopy(self.model_c) models_t_global = deepcopy(self.models_t) logger.info("Bootstrap Confidence Intervals for ATE") @@ -291,7 +298,8 @@ def estimate_ate( # set member variables back to global (currently last bootstrapped outcome) self.t_groups = t_groups_global self._classes = _classes_global - self.models_c = deepcopy(models_c_global) + self.model_c = deepcopy(model_c_global) + self.models_c = {group: self.model_c for group in self.t_groups} self.models_t = deepcopy(models_t_global) return ate, ate_lower, ate_upper @@ -371,14 +379,13 @@ def predict( Returns: (numpy.ndarray): Predictions of treatment effects. """ - yhat_cs = {} yhat_ts = {} + yhat_c = self.model_c.predict_proba(X)[:, 1] + yhat_cs = {group: yhat_c for group in self.t_groups} + for group in self.t_groups: - model_c = self.models_c[group] - model_t = self.models_t[group] - yhat_cs[group] = model_c.predict_proba(X)[:, 1] - yhat_ts[group] = model_t.predict_proba(X)[:, 1] + yhat_ts[group] = self.models_t[group].predict_proba(X)[:, 1] if (y is not None) and (treatment is not None) and verbose: mask = (treatment == group) | (treatment == self.control_name) @@ -387,7 +394,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = yhat_cs[group][mask][w == 0] + yhat[w == 0] = yhat_c[mask][w == 0] yhat[w == 1] = yhat_ts[group][mask][w == 1] logger.info("Error metrics for group {}".format(group)) @@ -395,7 +402,7 @@ def predict( te = np.zeros((X.shape[0], self.t_groups.shape[0])) for i, group in enumerate(self.t_groups): - te[:, i] = yhat_ts[group] - yhat_cs[group] + te[:, i] = yhat_ts[group] - yhat_c if not return_components: return te diff --git a/causalml/inference/meta/xlearner.py b/causalml/inference/meta/xlearner.py index 88b5dc1d..968fa870 100644 --- a/causalml/inference/meta/xlearner.py +++ b/causalml/inference/meta/xlearner.py @@ -56,6 +56,9 @@ def __init__( else: self.model_mu_c = control_outcome_learner + # Preserve the unfitted template so repeated fit() calls always start fresh. + self._model_mu_c_template = self.model_mu_c + if treatment_outcome_learner is None: self.model_mu_t = deepcopy(learner) else: @@ -114,7 +117,6 @@ def fit(self, X, treatment, y, p=None): p = self._format_p(p, self.t_groups) self._classes = {group: i for i, group in enumerate(self.t_groups)} - self.models_mu_c = {group: deepcopy(self.model_mu_c) for group in self.t_groups} self.models_mu_t = {group: deepcopy(self.model_mu_t) for group in self.t_groups} self.models_tau_c = { group: deepcopy(self.model_tau_c) for group in self.t_groups @@ -125,32 +127,36 @@ def fit(self, X, treatment, y, p=None): self.vars_c = {} self.vars_t = {} + # model_mu_c is trained on control data, which is the same for every treatment + # group. Deepcopy from the unfitted template so re-calling fit() starts fresh. + control_mask = treatment == self.control_name + self.model_mu_c = deepcopy(self._model_mu_c_template) + self.model_mu_c.fit(X[control_mask], y[control_mask]) + # Expose as a shared-reference dict to preserve the public models_mu_c API. + self.models_mu_c = {group: self.model_mu_c for group in self.t_groups} + + # var_c depends only on model_mu_c and control data — constant across groups. + y_control_pred = self.model_mu_c.predict(X[control_mask]) + self.var_c = (y[control_mask] - y_control_pred).var() + # Keep vars_c dict for backward compatibility with existing callers. + self.vars_c = {group: self.var_c for group in self.t_groups} + for group in self.t_groups: - mask = (treatment == group) | (treatment == self.control_name) - treatment_filt = treatment[mask] - X_filt = X[mask] - y_filt = y[mask] - w = (treatment_filt == group).astype(int) + treatment_mask = treatment == group + X_treat = X[treatment_mask] + y_treat = y[treatment_mask] - # Train outcome models - self.models_mu_c[group].fit(X_filt[w == 0], y_filt[w == 0]) - self.models_mu_t[group].fit(X_filt[w == 1], y_filt[w == 1]) + self.models_mu_t[group].fit(X_treat, y_treat) - # Calculate variances and treatment effects - var_c = ( - y_filt[w == 0] - self.models_mu_c[group].predict(X_filt[w == 0]) - ).var() - self.vars_c[group] = var_c - var_t = ( - y_filt[w == 1] - self.models_mu_t[group].predict(X_filt[w == 1]) + self.vars_t[group] = ( + y_treat - self.models_mu_t[group].predict(X_treat) ).var() - self.vars_t[group] = var_t - # Train treatment models - d_c = self.models_mu_t[group].predict(X_filt[w == 0]) - y_filt[w == 0] - d_t = y_filt[w == 1] - self.models_mu_c[group].predict(X_filt[w == 1]) - self.models_tau_c[group].fit(X_filt[w == 0], d_c) - self.models_tau_t[group].fit(X_filt[w == 1], d_t) + # Train treatment effect models using cross-group imputation + d_c = self.models_mu_t[group].predict(X[control_mask]) - y[control_mask] + d_t = y_treat - self.model_mu_c.predict(X_treat) + self.models_tau_c[group].fit(X[control_mask], d_c) + self.models_tau_t[group].fit(X_treat, d_t) def predict( self, X, treatment=None, y=None, p=None, return_components=False, verbose=True @@ -184,6 +190,12 @@ def predict( dhat_cs = {} dhat_ts = {} + # For verbose metrics, control predictions are constant across groups. + yhat_c_verbose = None + if (y is not None) and (treatment is not None) and verbose: + control_mask = treatment == self.control_name + yhat_c_verbose = self.model_mu_c.predict(X[control_mask]) + for i, group in enumerate(self.t_groups): model_tau_c = self.models_tau_c[group] model_tau_t = self.models_tau_t[group] @@ -195,7 +207,7 @@ def predict( ) te[:, i] = np.ravel(_te) - if (y is not None) and (treatment is not None) and verbose: + if yhat_c_verbose is not None: mask = (treatment == group) | (treatment == self.control_name) treatment_filt = treatment[mask] X_filt = X[mask] @@ -203,7 +215,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = self.models_mu_c[group].predict(X_filt[w == 0]) + yhat[w == 0] = yhat_c_verbose yhat[w == 1] = self.models_mu_t[group].predict(X_filt[w == 1]) logger.info("Error metrics for group {}".format(group)) @@ -262,7 +274,7 @@ def fit_predict( else: t_groups_global = self.t_groups _classes_global = self._classes - models_mu_c_global = deepcopy(self.models_mu_c) + model_mu_c_global = deepcopy(self.model_mu_c) models_mu_t_global = deepcopy(self.models_mu_t) models_tau_c_global = deepcopy(self.models_tau_c) models_tau_t_global = deepcopy(self.models_tau_t) @@ -283,7 +295,8 @@ def fit_predict( # set member variables back to global (currently last bootstrapped outcome) self.t_groups = t_groups_global self._classes = _classes_global - self.models_mu_c = deepcopy(models_mu_c_global) + self.model_mu_c = deepcopy(model_mu_c_global) + self.models_mu_c = {group: self.model_mu_c for group in self.t_groups} self.models_mu_t = deepcopy(models_mu_t_global) self.models_tau_c = deepcopy(models_tau_c_global) self.models_tau_t = deepcopy(models_tau_t_global) @@ -362,7 +375,7 @@ def estimate_ate( se = np.sqrt( ( self.vars_t[group] / prob_treatment - + self.vars_c[group] / (1 - prob_treatment) + + self.var_c / (1 - prob_treatment) + (p_filt * dhat_c + (1 - p_filt) * dhat_t).var() ) / w.shape[0] @@ -380,7 +393,7 @@ def estimate_ate( else: t_groups_global = self.t_groups _classes_global = self._classes - models_mu_c_global = deepcopy(self.models_mu_c) + model_mu_c_global = deepcopy(self.model_mu_c) models_mu_t_global = deepcopy(self.models_mu_t) models_tau_c_global = deepcopy(self.models_tau_c) models_tau_t_global = deepcopy(self.models_tau_t) @@ -402,7 +415,8 @@ def estimate_ate( # set member variables back to global (currently last bootstrapped outcome) self.t_groups = t_groups_global self._classes = _classes_global - self.models_mu_c = deepcopy(models_mu_c_global) + self.model_mu_c = deepcopy(model_mu_c_global) + self.models_mu_c = {group: self.model_mu_c for group in self.t_groups} self.models_mu_t = deepcopy(models_mu_t_global) self.models_tau_c = deepcopy(models_tau_c_global) self.models_tau_t = deepcopy(models_tau_t_global) @@ -528,7 +542,6 @@ def fit(self, X, treatment, y, p=None): p = self._format_p(p, self.t_groups) self._classes = {group: i for i, group in enumerate(self.t_groups)} - self.models_mu_c = {group: deepcopy(self.model_mu_c) for group in self.t_groups} self.models_mu_t = {group: deepcopy(self.model_mu_t) for group in self.t_groups} self.models_tau_c = { group: deepcopy(self.model_tau_c) for group in self.t_groups @@ -539,40 +552,37 @@ def fit(self, X, treatment, y, p=None): self.vars_c = {} self.vars_t = {} + # model_mu_c is trained on control data, which is the same for every treatment + # group, so fit it once and store as a single model (not a per-group dict). + control_mask = treatment == self.control_name + self.model_mu_c = deepcopy(self._model_mu_c_template) + self.model_mu_c.fit(X[control_mask], y[control_mask]) + self.models_mu_c = {group: self.model_mu_c for group in self.t_groups} + + # var_c depends only on model_mu_c and control data — constant across groups. + y_control_pred = self.model_mu_c.predict_proba(X[control_mask])[:, 1] + self.var_c = (y[control_mask] - y_control_pred).var() + self.vars_c = {group: self.var_c for group in self.t_groups} + for group in self.t_groups: - mask = (treatment == group) | (treatment == self.control_name) - treatment_filt = treatment[mask] - X_filt = X[mask] - y_filt = y[mask] - w = (treatment_filt == group).astype(int) + treatment_mask = treatment == group + X_treat = X[treatment_mask] + y_treat = y[treatment_mask] - # Train outcome models - self.models_mu_c[group].fit(X_filt[w == 0], y_filt[w == 0]) - self.models_mu_t[group].fit(X_filt[w == 1], y_filt[w == 1]) + self.models_mu_t[group].fit(X_treat, y_treat) - # Calculate variances and treatment effects - var_c = ( - y_filt[w == 0] - - self.models_mu_c[group].predict_proba(X_filt[w == 0])[:, 1] - ).var() - self.vars_c[group] = var_c - var_t = ( - y_filt[w == 1] - - self.models_mu_t[group].predict_proba(X_filt[w == 1])[:, 1] + self.vars_t[group] = ( + y_treat - self.models_mu_t[group].predict_proba(X_treat)[:, 1] ).var() - self.vars_t[group] = var_t - # Train treatment models + # Train treatment effect models using cross-group imputation d_c = ( - self.models_mu_t[group].predict_proba(X_filt[w == 0])[:, 1] - - y_filt[w == 0] + self.models_mu_t[group].predict_proba(X[control_mask])[:, 1] + - y[control_mask] ) - d_t = ( - y_filt[w == 1] - - self.models_mu_c[group].predict_proba(X_filt[w == 1])[:, 1] - ) - self.models_tau_c[group].fit(X_filt[w == 0], d_c) - self.models_tau_t[group].fit(X_filt[w == 1], d_t) + d_t = y_treat - self.model_mu_c.predict_proba(X_treat)[:, 1] + self.models_tau_c[group].fit(X[control_mask], d_c) + self.models_tau_t[group].fit(X_treat, d_t) def predict( self, X, treatment=None, y=None, p=None, return_components=False, verbose=True @@ -607,6 +617,12 @@ def predict( dhat_cs = {} dhat_ts = {} + # For verbose metrics, control predictions are constant across groups. + yhat_c_verbose = None + if (y is not None) and (treatment is not None) and verbose: + control_mask = treatment == self.control_name + yhat_c_verbose = self.model_mu_c.predict_proba(X[control_mask])[:, 1] + for i, group in enumerate(self.t_groups): model_tau_c = self.models_tau_c[group] model_tau_t = self.models_tau_t[group] @@ -618,7 +634,7 @@ def predict( ) te[:, i] = np.ravel(_te) - if (y is not None) and (treatment is not None) and verbose: + if yhat_c_verbose is not None: mask = (treatment == group) | (treatment == self.control_name) treatment_filt = treatment[mask] X_filt = X[mask] @@ -626,9 +642,7 @@ def predict( w = (treatment_filt == group).astype(int) yhat = np.zeros_like(y_filt, dtype=float) - yhat[w == 0] = self.models_mu_c[group].predict_proba(X_filt[w == 0])[ - :, 1 - ] + yhat[w == 0] = yhat_c_verbose yhat[w == 1] = self.models_mu_t[group].predict_proba(X_filt[w == 1])[ :, 1 ] diff --git a/tests/test_meta_learners.py b/tests/test_meta_learners.py index d5a60216..c067e449 100644 --- a/tests/test_meta_learners.py +++ b/tests/test_meta_learners.py @@ -1220,3 +1220,377 @@ def test_BaseDRClassifier(generate_classification_data): te_separate = learner_separate.fit_predict(X=X, treatment=treatment, y=y) assert te_separate.shape == te.shape + + +def test_multi_treatment_learners(): + """Comprehensive multi-treatment (N=3) contract test for all meta-learners. + + Verifies three classes of invariants: + 1. Common API contracts — return types and shapes for every public method. + 2. Structural post-fit invariants — shared-reference dicts, attribute presence. + 3. Optimisation correctness — control models trained once and shared. + + Covers: BaseTLearner, BaseXLearner, BaseSLearner, BaseDRLearner, BaseRLearner. + + Shared return-type contracts (regression learners below): + - ``fit(...)`` → ``None`` + - ``predict(...)`` → ``np.ndarray`` of shape ``(n_samples, n_treatment_groups)`` + - ``predict(..., return_components=True)`` → ``tuple`` of length 3 ``(te, comp_a, comp_b)`` + (not implemented for R-learner; its ``predict`` only returns CATE). + - ``fit_predict(..., return_ci=False)`` → CATE ``np.ndarray`` only (not a tuple) + - ``fit_predict(..., return_ci=True)`` → ``tuple`` ``(te, lb, ub)`` of three ndarrays + - ``estimate_ate(...)`` → ``tuple`` ``(ate, lb, ub)`` with each vector of shape + ``(n_treatment_groups,)`` for T/X/R/DR by default; **BaseSLearner** returns only + ``ate`` unless ``return_ci=True`` (then same triple as the others). + """ + np.random.seed(RANDOM_SEED) + n, p, n_groups = 600, 5, 3 + X = np.random.randn(n, p) + # Three treatment groups (1, 2, 3) plus control (0), ~150 obs each. + treatment = np.tile([0, 1, 2, 3], n // 4) + tau = np.where( + treatment == 1, + 1.0, + np.where(treatment == 2, 2.0, np.where(treatment == 3, 3.0, 0.0)), + ) + y = X[:, 0] + tau + 0.1 * np.random.randn(n) + # Flat propensity scores for learners that require them (X, R). + p_scores = {g: np.full(n, 1.0 / (n_groups + 1)) for g in [1, 2, 3]} + + # ── Shared assertion helpers ─────────────────────────────────────────────── + + def _assert_fit_attrs(lrn, name): + """t_groups must be a sorted ndarray; _classes must map each group to 0..N-1.""" + assert hasattr(lrn, "t_groups"), f"{name}: missing t_groups after fit" + assert isinstance(lrn.t_groups, np.ndarray), f"{name}: t_groups must be ndarray" + assert lrn.t_groups.shape == ( + n_groups, + ), f"{name}: t_groups shape {lrn.t_groups.shape}" + np.testing.assert_array_equal( + lrn.t_groups, + np.sort(lrn.t_groups), + err_msg=f"{name}: t_groups must be sorted", + ) + assert hasattr(lrn, "_classes") and isinstance(lrn._classes, dict) + assert set(lrn._classes.keys()) == set(lrn.t_groups) + assert set(lrn._classes.values()) == set(range(n_groups)) + + def _assert_te(te, name, method): + """te must be ndarray (n, n_groups) of finite values.""" + assert isinstance(te, np.ndarray), f"{name}.{method}: te must be ndarray" + assert te.shape == (n, n_groups), f"{name}.{method}: te.shape={te.shape}" + assert np.all(np.isfinite(te)), f"{name}.{method}: te has non-finite values" + + def _assert_components(yhat_cs, yhat_ts, t_groups, name): + """yhat_cs and yhat_ts must be dicts of finite (n,) arrays covering all groups.""" + for label, d in [("yhat_cs", yhat_cs), ("yhat_ts", yhat_ts)]: + assert isinstance(d, dict), f"{name}: {label} must be dict, got {type(d)}" + assert set(d.keys()) == set(t_groups), f"{name}: {label} keys != t_groups" + for g in t_groups: + assert isinstance(d[g], np.ndarray) and d[g].shape == (n,) + assert np.all(np.isfinite(d[g])), f"{name}: {label}[{g}] has non-finite" + + def _assert_ate(result, name): + """estimate_ate must return (ate, lb, ub) — finite ndarrays of shape (n_groups,), lb<=ub.""" + assert ( + isinstance(result, tuple) and len(result) == 3 + ), f"{name}.estimate_ate: expected 3-tuple, got {type(result)}" + ate, lb, ub = result + for arr, label in [(ate, "ate"), (lb, "lb"), (ub, "ub")]: + assert isinstance( + arr, np.ndarray + ), f"{name}.estimate_ate {label} must be ndarray" + assert arr.shape == ( + n_groups, + ), f"{name}.estimate_ate {label}.shape={arr.shape}" + assert np.all( + np.isfinite(arr) + ), f"{name}.estimate_ate {label} has non-finite" + assert np.all(lb <= ub), f"{name}.estimate_ate: lb > ub" + + def _assert_ci_triple(result, name, method): + """fit_predict(return_ci=True) must return (te, lb, ub), each (n, n_groups).""" + assert isinstance(result, tuple) and len(result) == 3 + for arr, label in zip(result, ["te", "lb", "ub"]): + _assert_te(arr, name, f"{method}[{label}]") + + def _assert_shared_ref_dict(d, single_obj, keys, name, attr): + """Every value in d must be the same Python object as single_obj.""" + assert isinstance(d, dict), f"{name}: {attr} must be dict" + assert set(d.keys()) == set(keys), f"{name}: {attr} keys mismatch" + assert all( + d[g] is single_obj for g in keys + ), f"{name}: all {attr} values must be shared refs to the single fitted model" + + def _assert_fit_returns_none(result, name): + assert result is None, f"{name}.fit(): expected None, got {type(result)}" + + def _assert_plain_fit_predict(result, name): + """fit_predict(return_ci=False) must return a single ndarray (CATE), not a tuple.""" + assert isinstance( + result, np.ndarray + ), f"{name}.fit_predict(return_ci=False): expected ndarray, got {type(result)}" + + # ── T-Learner ───────────────────────────────────────────────────────────── + name = "BaseTLearner" + tl = BaseTLearner(learner=LinearRegression()) + _assert_fit_returns_none(tl.fit(X=X, treatment=treatment, y=y), name) + + _assert_fit_attrs(tl, name) + assert hasattr(tl, "model_c"), f"{name}: missing model_c" + _assert_shared_ref_dict(tl.models_c, tl.model_c, tl.t_groups, name, "models_c") + assert hasattr(tl, "models_t") and isinstance(tl.models_t, dict) + assert set(tl.models_t.keys()) == set(tl.t_groups) + # Treatment models must be distinct objects (trained on different per-group data). + assert all( + tl.models_t[g1] is not tl.models_t[g2] + for g1, g2 in zip(tl.t_groups[:-1], tl.t_groups[1:]) + ), f"{name}: models_t must be distinct objects per group" + + te = tl.predict(X=X) + _assert_te(te, name, "predict()") + + out_pc = tl.predict(X=X, return_components=True) + assert ( + isinstance(out_pc, tuple) and len(out_pc) == 3 + ), f"{name}.predict(return_components=True) must return (te, yhat_cs, yhat_ts)" + te2, yhat_cs, yhat_ts = out_pc + np.testing.assert_array_equal(te, te2, err_msg=f"{name}: predict inconsistency") + _assert_components(yhat_cs, yhat_ts, tl.t_groups, name) + assert all( + yhat_cs[g] is yhat_cs[tl.t_groups[0]] for g in tl.t_groups + ), f"{name}: yhat_cs values must share the same underlying array" + + fp_plain = tl.fit_predict(X=X, treatment=treatment, y=y) + _assert_plain_fit_predict(fp_plain, name) + _assert_te(fp_plain, name, "fit_predict()") + _assert_ci_triple( + tl.fit_predict( + X=X, + treatment=treatment, + y=y, + return_ci=True, + n_bootstraps=5, + bootstrap_size=150, + ), + name, + "fit_predict", + ) + _assert_ate(tl.estimate_ate(X=X, treatment=treatment, y=y), name) + _assert_ate(tl.estimate_ate(X=X, treatment=treatment, y=y, pretrain=True), name) + + # ── X-Learner ───────────────────────────────────────────────────────────── + name = "BaseXLearner" + xl = BaseXLearner(learner=LinearRegression()) + _assert_fit_returns_none(xl.fit(X=X, treatment=treatment, y=y, p=p_scores), name) + + _assert_fit_attrs(xl, name) + assert hasattr(xl, "model_mu_c"), f"{name}: missing model_mu_c" + _assert_shared_ref_dict( + xl.models_mu_c, xl.model_mu_c, xl.t_groups, name, "models_mu_c" + ) + assert ( + hasattr(xl, "var_c") and np.isscalar(xl.var_c) and np.isfinite(xl.var_c) + ), f"{name}: var_c must be a finite scalar" + assert hasattr(xl, "vars_c") and isinstance(xl.vars_c, dict) + assert all( + xl.vars_c[g] == xl.var_c for g in xl.t_groups + ), f"{name}: vars_c values must all equal var_c" + for attr in ("models_mu_t", "models_tau_c", "models_tau_t", "vars_t"): + assert hasattr(xl, attr) and isinstance( + getattr(xl, attr), dict + ), f"{name}: missing {attr}" + assert set(getattr(xl, attr).keys()) == set( + xl.t_groups + ), f"{name}: {attr} keys mismatch" + + te = xl.predict(X=X, p=p_scores) + _assert_te(te, name, "predict()") + + out_pc = xl.predict(X=X, p=p_scores, return_components=True) + assert ( + isinstance(out_pc, tuple) and len(out_pc) == 3 + ), f"{name}.predict(return_components=True) must return (te, dhat_cs, dhat_ts)" + te2, dhat_cs, dhat_ts = out_pc + np.testing.assert_array_equal(te, te2, err_msg=f"{name}: predict inconsistency") + for label, d in [("dhat_cs", dhat_cs), ("dhat_ts", dhat_ts)]: + assert isinstance(d, dict) and set(d.keys()) == set( + xl.t_groups + ), f"{name}: {label} mismatch" + + fp_plain_x = xl.fit_predict(X=X, treatment=treatment, y=y, p=p_scores) + _assert_plain_fit_predict(fp_plain_x, name) + _assert_te(fp_plain_x, name, "fit_predict()") + _assert_ci_triple( + xl.fit_predict( + X=X, + treatment=treatment, + y=y, + p=p_scores, + return_ci=True, + n_bootstraps=5, + bootstrap_size=150, + ), + name, + "fit_predict", + ) + _assert_ate(xl.estimate_ate(X=X, treatment=treatment, y=y, p=p_scores), name) + _assert_ate( + xl.estimate_ate(X=X, treatment=treatment, y=y, p=p_scores, pretrain=True), name + ) + + # ── S-Learner ───────────────────────────────────────────────────────────── + name = "BaseSLearner" + sl = BaseSLearner(learner=LinearRegression()) + _assert_fit_returns_none(sl.fit(X=X, treatment=treatment, y=y), name) + + _assert_fit_attrs(sl, name) + assert hasattr(sl, "models") and isinstance( + sl.models, dict + ), f"{name}: missing models dict" + assert set(sl.models.keys()) == set(sl.t_groups) + # Each group's model is trained on different data so must be a distinct object. + assert all( + sl.models[g1] is not sl.models[g2] + for g1, g2 in zip(sl.t_groups[:-1], sl.t_groups[1:]) + ), f"{name}: models must be distinct per group" + + te = sl.predict(X=X) + _assert_te(te, name, "predict()") + + out_pc = sl.predict(X=X, return_components=True) + assert isinstance(out_pc, tuple) and len(out_pc) == 3 + te2, yhat_cs, yhat_ts = out_pc + np.testing.assert_array_equal(te, te2, err_msg=f"{name}: predict inconsistency") + _assert_components(yhat_cs, yhat_ts, sl.t_groups, name) + + fp_plain_s = sl.fit_predict(X=X, treatment=treatment, y=y) + _assert_plain_fit_predict(fp_plain_s, name) + _assert_te(fp_plain_s, name, "fit_predict()") + _assert_ci_triple( + sl.fit_predict( + X=X, + treatment=treatment, + y=y, + return_ci=True, + n_bootstraps=5, + bootstrap_size=150, + ), + name, + "fit_predict", + ) + ate_only = sl.estimate_ate(X=X, treatment=treatment, y=y, return_ci=False) + assert isinstance(ate_only, np.ndarray) and ate_only.shape == ( + n_groups, + ), f"{name}.estimate_ate(return_ci=False) must be shape (n_groups,)" + _assert_ate(sl.estimate_ate(X=X, treatment=treatment, y=y, return_ci=True), name) + _assert_ate( + sl.estimate_ate(X=X, treatment=treatment, y=y, return_ci=True, pretrain=True), + name, + ) + + # ── DR-Learner ──────────────────────────────────────────────────────────── + name = "BaseDRLearner" + dr = BaseDRLearner( + learner=LinearRegression(), treatment_effect_learner=LinearRegression() + ) + _assert_fit_returns_none(dr.fit(X=X, treatment=treatment, y=y), name) + + _assert_fit_attrs(dr, name) + # models_mu_c: list of 3 fold models (fold-specific, NOT per-group). + assert hasattr(dr, "models_mu_c") and isinstance( + dr.models_mu_c, list + ), f"{name}: models_mu_c must be a list" + assert len(dr.models_mu_c) == 3, f"{name}: models_mu_c must have 3 fold models" + # Per-group outcome and effect models: each a list of 3 fold models. + for attr in ("models_mu_t", "models_tau"): + assert hasattr(dr, attr) and isinstance(getattr(dr, attr), dict) + assert set(getattr(dr, attr).keys()) == set(dr.t_groups) + for g in dr.t_groups: + val = getattr(dr, attr)[g] + assert ( + isinstance(val, list) and len(val) == 3 + ), f"{name}: {attr}[{g}] must be list of 3 fold models" + + te = dr.predict(X=X) + _assert_te(te, name, "predict()") + + out_pc = dr.predict(X=X, return_components=True) + assert isinstance(out_pc, tuple) and len(out_pc) == 3 + te2, yhat_cs, yhat_ts = out_pc + np.testing.assert_array_equal(te, te2, err_msg=f"{name}: predict inconsistency") + _assert_components(yhat_cs, yhat_ts, dr.t_groups, name) + # yhat_cs must be a shared-reference dict (one fold-averaged control prediction). + assert all( + yhat_cs[g] is yhat_cs[dr.t_groups[0]] for g in dr.t_groups + ), f"{name}: yhat_cs values must share the same underlying array" + + fp_plain_dr = dr.fit_predict(X=X, treatment=treatment, y=y) + _assert_plain_fit_predict(fp_plain_dr, name) + _assert_te(fp_plain_dr, name, "fit_predict()") + _assert_ci_triple( + dr.fit_predict( + X=X, + treatment=treatment, + y=y, + return_ci=True, + n_bootstraps=5, + bootstrap_size=150, + ), + name, + "fit_predict", + ) + _assert_ate(dr.estimate_ate(X=X, treatment=treatment, y=y), name) + _assert_ate(dr.estimate_ate(X=X, treatment=treatment, y=y, pretrain=True), name) + + # ── R-Learner ───────────────────────────────────────────────────────────── + name = "BaseRLearner" + rl = BaseRLearner( + learner=LinearRegression(), + effect_learner=LinearRegression(), + cv_n_jobs=1, + ) + _assert_fit_returns_none( + rl.fit(X=X, treatment=treatment, y=y, p=p_scores, verbose=False), name + ) + + _assert_fit_attrs(rl, name) + # R-learner: single shared outcome model fitted once via cross-validation. + assert hasattr(rl, "model_mu"), f"{name}: missing model_mu" + assert hasattr(rl, "models_tau") and isinstance(rl.models_tau, dict) + assert set(rl.models_tau.keys()) == set(rl.t_groups) + assert all( + rl.models_tau[g1] is not rl.models_tau[g2] + for g1, g2 in zip(rl.t_groups[:-1], rl.t_groups[1:]) + ), f"{name}: models_tau must be distinct per group" + for attr in ("vars_c", "vars_t"): + assert hasattr(rl, attr) and isinstance(getattr(rl, attr), dict) + assert set(getattr(rl, attr).keys()) == set(rl.t_groups) + + # R-learner: predict(X, p=...) returns CATE only (no return_components path). + te = rl.predict(X=X, p=p_scores) + _assert_te(te, name, "predict()") + + fp_plain_r = rl.fit_predict( + X=X, treatment=treatment, y=y, p=p_scores, verbose=False + ) + _assert_plain_fit_predict(fp_plain_r, name) + _assert_te(fp_plain_r, name, "fit_predict()") + _assert_ci_triple( + rl.fit_predict( + X=X, + treatment=treatment, + y=y, + p=p_scores, + return_ci=True, + n_bootstraps=5, + bootstrap_size=150, + verbose=False, + ), + name, + "fit_predict", + ) + _assert_ate(rl.estimate_ate(X=X, treatment=treatment, y=y, p=p_scores), name) + _assert_ate( + rl.estimate_ate(X=X, treatment=treatment, y=y, p=p_scores, pretrain=True), name + )