diff --git a/.gitignore b/.gitignore
index e5742f754..b495229ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ cython_debug/
 CLAUDE.md
 .claude
 uv.lock
+.claude/settings.local.json
diff --git a/scripts/check_distributional_calibration.py b/scripts/check_distributional_calibration.py
new file mode 100644
index 000000000..32e7a6ec5
--- /dev/null
+++ b/scripts/check_distributional_calibration.py
@@ -0,0 +1,163 @@
+"""Diagnostic script to verify whether softmax_temperature and average_before_softmax
+affect distributional calibration metrics.
+
+Hypothesis (from Jonas Landsgesell paper + email thread):
+- softmax_temperature=0.9 sharpens distributions -> hurts log-score / CRPS / CRLS
+- averaging probabilities vs logits before ensembling may matter for calibration
+
+Metrics:
+    NLL   - Negative log-likelihood (log score), sensitive to sharpness
+    CRPS  - Continuous Ranked Probability Score, integral over quantile losses
+    CRLS  - Continuous Ranked Log Score = (CRPS + NLL) / 2 (Brehmer & Gneiting 2021)
+    IS95  - Interval Score at 95%: penalises width + miscoverage (lower = better)
+    Cov95 - Empirical coverage of 95% prediction interval (target: 0.95)
+    Sharp - Mean width of 95% PI (lower = sharper, but only good if well-calibrated)
+    MACE  - Mean Absolute Calibration Error from PIT (lower = better)
+    KS_p  - p-value of KS test for PIT uniformity (higher = better calibrated)
+    RMSE  - Root mean squared error of the mean prediction
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from scipy.stats import kstest
+
+from tabpfn import TabPFNRegressor
+
+
+# ---------------------------------------------------------------------------
+# Metric helpers
+# ---------------------------------------------------------------------------
+
+def compute_pit(criterion, logits: torch.Tensor, y: np.ndarray) -> np.ndarray:
+    """P(Y <= y_true) under predicted distribution. Uniform => calibrated."""
+    y_t = torch.as_tensor(y, dtype=logits.dtype, device=logits.device).unsqueeze(-1)
+    return criterion.cdf(logits, y_t).squeeze(-1).cpu().detach().numpy()
+
+
+def compute_nll(criterion, logits: torch.Tensor, y: np.ndarray) -> float:
+    """Mean negative log-likelihood (log score)."""
+    y_t = torch.as_tensor(y, dtype=logits.dtype, device=logits.device)
+    return criterion(logits, y_t).mean().item()
+
+
+def compute_crps(criterion, logits: torch.Tensor, y: np.ndarray) -> float:
+    """CRPS via quantile (pinball) decomposition:
+    CRPS ~= 2 * E_q[(1{y <= F^{-1}(q)} - q) * (F^{-1}(q) - y)],
+    averaged over a grid of quantile levels q.
+    """
+    quantile_levels = np.linspace(0.01, 0.99, 99)
+    crps_sum = 0.0
+    for q in quantile_levels:
+        q_pred = criterion.icdf(logits, q).cpu().detach().numpy()
+        indicator = (y <= q_pred).astype(float)
+        # Pinball loss at level q; (indicator - q) * (q_pred - y) is
+        # non-negative for every sample, and CRPS = 2 * integral over q.
+        crps_sum += 2.0 * np.mean((indicator - q) * (q_pred - y))
+    return crps_sum / len(quantile_levels)
+
+
+def compute_crls(crps: float, nll: float) -> float:
+    """Continuous Ranked Log Score (Brehmer & Gneiting 2021).
+    Combines sharpness of log score with calibration of CRPS:
+        CRLS = (CRPS + NLL) / 2
+    """
+    return (crps + nll) / 2
+
+
+def compute_is95(criterion, logits: torch.Tensor, y: np.ndarray) -> tuple[float, float, float]:
+    """Interval Score at 95% PI.
+    IS_alpha = (u - l) + (2/alpha) * [max(0, l-y) + max(0, y-u)]
+    Returns (interval score, empirical coverage, mean interval width).
+    """
+    alpha = 0.05
+    l = criterion.icdf(logits, alpha / 2).cpu().detach().numpy()
+    u = criterion.icdf(logits, 1 - alpha / 2).cpu().detach().numpy()
+    width = u - l
+    penalty = (2 / alpha) * (np.maximum(0, l - y) + np.maximum(0, y - u))
+    is95 = np.mean(width + penalty)
+    coverage = np.mean((y >= l) & (y <= u))
+    sharpness = np.mean(width)
+    return is95, coverage, sharpness
+
+
+def compute_mace(pit: np.ndarray, n_bins: int = 10) -> float:
+    """Mean Absolute Calibration Error from PIT histogram."""
+    expected = 1.0 / n_bins
+    counts, _ = np.histogram(pit, bins=n_bins, range=(0, 1))
+    observed = counts / len(pit)
+    return np.mean(np.abs(observed - expected))
+
+
+# ---------------------------------------------------------------------------
+# Evaluation
+# ---------------------------------------------------------------------------
+
+def evaluate_config(X_train, y_train, X_test, y_test, *, softmax_temperature,
+                    average_before_softmax, ensemble_temperature=1.0):
+    """Fit one TabPFNRegressor configuration and return all calibration metrics."""
+    reg = TabPFNRegressor(
+        n_estimators=8,
+        softmax_temperature=softmax_temperature,
+        average_before_softmax=average_before_softmax,
+        ensemble_temperature=ensemble_temperature,
+        random_state=42,
+    )
+    reg.fit(X_train, y_train)
+    result = reg.predict(X_test, output_type="full")
+
+    criterion = result["criterion"]
+    logits = result["logits"]
+
+    pit = compute_pit(criterion, logits, y_test)
+    nll = compute_nll(criterion, logits, y_test)
+    crps = compute_crps(criterion, logits, y_test)
+    crls = compute_crls(crps, nll)
+    is95, cov95, sharp = compute_is95(criterion, logits, y_test)
+    mace = compute_mace(pit)
+    _, ks_p = kstest(pit, "uniform")
+    rmse = np.sqrt(np.mean((result["mean"] - y_test) ** 2))
+
+    return dict(nll=nll, crps=crps, crls=crls, is95=is95,
+                cov95=cov95, sharp=sharp, mace=mace, ks_p=ks_p, rmse=rmse)
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+    rng = np.random.default_rng(42)
+    n_train, n_test, n_features = 100, 200, 2
+
+    X = rng.normal(0, 1, (n_train + n_test, n_features))
+    w = rng.normal(0, 1, n_features)
+    y = X @ w + rng.normal(0, 1, n_train + n_test)
+
+    X_train, X_test = X[:n_train], X[n_train:]
+    y_train, y_test = y[:n_train], y[n_train:]
+
+    configs = [
+        {"softmax_temperature": 0.9, "average_before_softmax": False, "ensemble_temperature": 1.0},  # old default
+        {"softmax_temperature": 1.0, "average_before_softmax": False, "ensemble_temperature": 1.0},  # no temp scaling
+        {"softmax_temperature": 0.9, "average_before_softmax": False, "ensemble_temperature": 1/0.9},  # NEW default
+        {"softmax_temperature": 1.0, "average_before_softmax": True, "ensemble_temperature": 1.0},  # average probs, neutral temps
+    ]
+
+    cols = ["NLL", "CRPS", "CRLS", "IS95", "Cov95", "Sharp", "MACE", "KS_p", "RMSE"]
+    header = f"{'Config':<45}" + "".join(f"{c:>8}" for c in cols)
+    print(header)
+    print("-" * len(header))
+
+    for cfg in configs:
+        # Label includes every config field so no two rows print identically.
+        label = (f"sm_t={cfg['softmax_temperature']}, "
+                 f"avg={cfg['average_before_softmax']}, "
+                 f"ens_t={cfg['ensemble_temperature']:.3f}")
+        m = evaluate_config(X_train, y_train, X_test, y_test, **cfg)
+        print(
+            f"{label:<45}"
+            f"{m['nll']:>8.4f}"
+            f"{m['crps']:>8.4f}"
+            f"{m['crls']:>8.4f}"
+            f"{m['is95']:>8.4f}"
+            f"{m['cov95']:>8.4f}"
+            f"{m['sharp']:>8.4f}"
+            f"{m['mace']:>8.4f}"
+            f"{m['ks_p']:>8.4f}"
+            f"{m['rmse']:>8.4f}"
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py
index 038df7b22..f002c2ed4 100644
--- a/src/tabpfn/regressor.py
+++ b/src/tabpfn/regressor.py
@@ -212,6 +212,7 @@ def __init__(  # noqa: PLR0913
         categorical_features_indices: Sequence[int] | None = None,
         softmax_temperature: float = 0.9,
         average_before_softmax: bool = False,
+        ensemble_temperature: float = 1 / 0.9,
         model_path: str
         | Path
         | list[str]
@@ -280,6 +281,27 @@ def __init__(  # noqa: PLR0913
                 - If `False`, the softmax function is applied to each set of logits.
                   Then, we average the resulting probabilities of each forward pass.
 
+            ensemble_temperature:
+                Temperature applied to the final ensemble log-probabilities after all
+                estimators have been combined. This is a post-hoc recalibration step
+                that adjusts the sharpness of the predictive distribution without
+                affecting how individual estimators are aggregated.
+
+                The default ``1 / 0.9 ≈ 1.111`` is deliberately paired with the
+                default ``softmax_temperature=0.9``: the per-estimator sharpening
+                (``softmax_temperature=0.9``) improves point-estimate quality (RMSE)
+                by helping ensemble members mix better, while ``ensemble_temperature``
+                compensates for the resulting over-confidence of the final predictive
+                distribution, restoring calibration of quantiles, log-score, and CRPS.
+
+                - Values **below 1.0** sharpen the distribution (more confident,
+                  narrower prediction intervals). Use when the model is under-confident.
+                - Values **above 1.0** flatten the distribution (more uncertain, wider
+                  intervals). Use to improve calibration when the model is
+                  over-confident.
+                - Set both this and ``softmax_temperature`` to ``1.0`` to disable all
+                  temperature scaling.
+
             model_path:
                 The path to the TabPFN model file, i.e., the pre-trained weights.
 
@@ -440,6 +462,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
         self.categorical_features_indices = categorical_features_indices
         self.softmax_temperature = softmax_temperature
         self.average_before_softmax = average_before_softmax
+        self.ensemble_temperature = ensemble_temperature
         self.model_path = model_path
         self.device = device
         self.ignore_pretraining_limits = ignore_pretraining_limits
@@ -487,6 +510,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
                 ),
                 "n_estimators": 8,
                 "softmax_temperature": 0.9,
+                "ensemble_temperature": 1 / 0.9,
             }
         elif version == ModelVersion.V2_5:
             options = {
@@ -495,6 +519,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
                 ),
                 "n_estimators": 8,
                 "softmax_temperature": 0.9,
+                "ensemble_temperature": 1 / 0.9,
             }
         elif version == ModelVersion.V2_6:
             options = {
@@ -503,6 +528,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
                 ),
                 "n_estimators": 8,
                 "softmax_temperature": 0.9,
+                "ensemble_temperature": 1 / 0.9,
             }
         else:
             raise ValueError(f"Unknown version: {version}")
@@ -980,6 +1006,16 @@ def predict(  # noqa: C901, PLR0912
         if logits.dtype == torch.float16:
             logits = logits.float()
 
+        # Apply ensemble_temperature as a final recalibration of the predictive
+        # distribution. Dividing log-probabilities by T is equivalent to raising
+        # each probability to the power 1/T before renormalising via softmax.
+        # T > 1 flattens (widens) the distribution; T < 1 sharpens it.
+        # The default 1/0.9 ≈ 1.111 compensates for the sharpening introduced by
+        # softmax_temperature=0.9, restoring calibration while keeping the RMSE
+        # benefit from per-estimator sharpening.
+        if self.ensemble_temperature != 1.0:
+            logits = logits / self.ensemble_temperature
+
         # Determine and return intended output type
         logit_to_output = partial(
             _logits_to_output,