1 change: 1 addition & 0 deletions .gitignore
@@ -166,3 +166,4 @@ cython_debug/
CLAUDE.md
.claude
uv.lock
.claude/settings.local.json
163 changes: 163 additions & 0 deletions scripts/check_distributional_calibration.py
@@ -0,0 +1,163 @@
"""Diagnostic script to verify whether softmax_temperature and average_before_softmax
affect distributional calibration metrics.

Hypothesis (from Jonas Landsgesell paper + email thread):
- softmax_temperature=0.9 sharpens distributions -> hurts log-score / CRPS / CRLS
- averaging probabilities vs logits before ensembling may matter for calibration

Metrics:
NLL - Negative log-likelihood (log score), sensitive to sharpness
CRPS - Continuous Ranked Probability Score, integral over quantile losses
CRLS - Continuous Ranked Log Score = (CRPS + NLL) / 2 (Brehmer & Gneiting 2021)
IS95 - Interval Score at 95%: penalises width + miscoverage (lower = better)
Cov95 - Empirical coverage of 95% prediction interval (target: 0.95)
Sharp - Mean width of 95% PI (lower = sharper, but only good if well-calibrated)
MACE - Mean Absolute Calibration Error from PIT (lower = better)
KS_p - p-value of KS test for PIT uniformity (higher = better calibrated)
RMSE - Root mean squared error of the mean prediction
"""

from __future__ import annotations

import numpy as np
import torch
from scipy.stats import kstest

from tabpfn import TabPFNRegressor

[Ruff I001] scripts/check_distributional_calibration.py:20:1: Import block is un-sorted or un-formatted


# ---------------------------------------------------------------------------
# Metric helpers
# ---------------------------------------------------------------------------

def compute_pit(criterion, logits: torch.Tensor, y: np.ndarray) -> np.ndarray:

[Ruff ANN001] scripts/check_distributional_calibration.py:33:17: Missing type annotation for function argument `criterion`
"""P(Y <= y_true) under predicted distribution. Uniform => calibrated."""
y_t = torch.as_tensor(y, dtype=logits.dtype, device=logits.device).unsqueeze(-1)
return criterion.cdf(logits, y_t).squeeze(-1).cpu().detach().numpy()


def compute_nll(criterion, logits: torch.Tensor, y: np.ndarray) -> float:

[Ruff ANN001] scripts/check_distributional_calibration.py:39:17: Missing type annotation for function argument `criterion`
"""Mean negative log-likelihood (log score)."""
y_t = torch.as_tensor(y, dtype=logits.dtype, device=logits.device)
return criterion(logits, y_t).mean().item()


def compute_crps(criterion, logits: torch.Tensor, y: np.ndarray) -> float:

[Ruff ANN001] scripts/check_distributional_calibration.py:45:18: Missing type annotation for function argument `criterion`
"""CRPS via quantile decomposition: E_q[(F^{-1}(q) - y)*(q - 1{y<=F^{-1}(q)})]."""
quantile_levels = np.linspace(0.01, 0.99, 99)
crps_sum = 0.0
for q in quantile_levels:
q_pred = criterion.icdf(logits, q).cpu().detach().numpy()
indicator = (y <= q_pred).astype(float)
crps_sum += np.mean((indicator - q) ** 2)
return crps_sum / len(quantile_levels)
Comment on lines +51 to +53
Contributor (severity: high)

The implementation of CRPS using quantile decomposition is incorrect. The current code calculates the Brier score for each quantile (the variance of the indicator function), which does not represent the Continuous Ranked Probability Score. CRPS is equivalent to twice the integral of the quantile loss (pinball loss) over all quantile levels.

Suggested change:
-        indicator = (y <= q_pred).astype(float)
-        crps_sum += np.mean((indicator - q) ** 2)
-    return crps_sum / len(quantile_levels)
+        q_pred = criterion.icdf(logits, q).cpu().detach().numpy()
+        errors = y - q_pred
+        crps_sum += np.mean(np.maximum(q * errors, (q - 1) * errors))
+    return 2 * crps_sum / len(quantile_levels)
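
As an independent sanity check of the suggested decomposition, the pinball form can be compared against the closed-form CRPS of a Gaussian. This is a self-contained, illustrative sketch in which `scipy.stats.norm` stands in for the `criterion` object:

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
mu, sigma = 0.5, 2.0
y = rng.normal(mu, sigma, size=10_000)

# Closed-form Gaussian CRPS: sigma * (z * (2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi)), z = (y - mu) / sigma
z = (y - mu) / sigma
crps_exact = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))

# Pinball (quantile-loss) decomposition, mirroring the suggested change above
quantile_levels = np.linspace(0.01, 0.99, 99)
crps_sum = 0.0
for q in quantile_levels:
    q_pred = norm.ppf(q, loc=mu, scale=sigma)  # plays the role of criterion.icdf(logits, q)
    errors = y - q_pred
    crps_sum += np.mean(np.maximum(q * errors, (q - 1) * errors))
crps_pinball = 2 * crps_sum / len(quantile_levels)

print(crps_exact.mean(), crps_pinball)  # nearly identical; small gap from the 99-point grid and truncated tails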



def compute_crls(crps: float, nll: float) -> float:
"""Continuous Ranked Log Score (Brehmer & Gneiting 2021).
Combines sharpness of log score with calibration of CRPS:
CRLS = (CRPS + NLL) / 2
"""

[Ruff D415] scripts/check_distributional_calibration.py:57:5: First line should end with a period, question mark, or exclamation point
return (crps + nll) / 2

Contributor

Different from the definition of CRLS (aka the Exceedance Probability Score): https://link.springer.com/article/10.1080/15598608.2012.695663



def compute_is95(criterion, logits: torch.Tensor, y: np.ndarray) -> tuple[float, float, float]:

[Ruff E501] scripts/check_distributional_calibration.py:64:89: Line too long (95 > 88)
[Ruff ANN001] scripts/check_distributional_calibration.py:64:18: Missing type annotation for function argument `criterion`
"""Interval Score at 95% PI.
IS_alpha = (u - l) + (2/alpha) * [max(0, l-y) + max(0, y-u)]
"""

[Ruff D415] scripts/check_distributional_calibration.py:65:5: First line should end with a period, question mark, or exclamation point
alpha = 0.05
l = criterion.icdf(logits, alpha / 2).cpu().detach().numpy()

[Ruff E741] scripts/check_distributional_calibration.py:69:5: Ambiguous variable name: `l`
u = criterion.icdf(logits, 1 - alpha / 2).cpu().detach().numpy()
width = u - l
penalty = (2 / alpha) * (np.maximum(0, l - y) + np.maximum(0, y - u))
is95 = np.mean(width + penalty)
coverage = np.mean((y >= l) & (y <= u))
sharpness = np.mean(width)
return is95, coverage, sharpness


def compute_mace(pit: np.ndarray, n_bins: int = 10) -> float:
"""Mean Absolute Calibration Error from PIT histogram."""
expected = 1.0 / n_bins
counts, _ = np.histogram(pit, bins=n_bins, range=(0, 1))
observed = counts / len(pit)
return np.mean(np.abs(observed - expected))


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

def evaluate_config(X_train, y_train, X_test, y_test, *, softmax_temperature, average_before_softmax, ensemble_temperature=1.0):

[Ruff ANN201] scripts/check_distributional_calibration.py:91:5: Missing return type annotation for public function `evaluate_config`
reg = TabPFNRegressor(
n_estimators=8,
softmax_temperature=softmax_temperature,
average_before_softmax=average_before_softmax,
ensemble_temperature=ensemble_temperature,
random_state=42,
)
reg.fit(X_train, y_train)
result = reg.predict(X_test, output_type="full")

criterion = result["criterion"]
logits = result["logits"]

pit = compute_pit(criterion, logits, y_test)
nll = compute_nll(criterion, logits, y_test)
crps = compute_crps(criterion, logits, y_test)
crls = compute_crls(crps, nll)
is95, cov95, sharp = compute_is95(criterion, logits, y_test)
mace = compute_mace(pit)
_, ks_p = kstest(pit, "uniform")
rmse = np.sqrt(np.mean((result["mean"] - y_test) ** 2))

return dict(nll=nll, crps=crps, crls=crls, is95=is95,
cov95=cov95, sharp=sharp, mace=mace, ks_p=ks_p, rmse=rmse)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
rng = np.random.default_rng(42)
n_train, n_test, n_features = 100, 200, 2

X = rng.normal(0, 1, (n_train + n_test, n_features))
w = rng.normal(0, 1, n_features)
y = X @ w + rng.normal(0, 1, n_train + n_test)

X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

configs = [
{"softmax_temperature": 0.9, "average_before_softmax": False, "ensemble_temperature": 1.0}, # old default
{"softmax_temperature": 1.0, "average_before_softmax": False, "ensemble_temperature": 1.0}, # no temp scaling
{"softmax_temperature": 0.9, "average_before_softmax": False, "ensemble_temperature": 1/0.9}, # NEW default
{"softmax_temperature": 1.0, "average_before_softmax": False, "ensemble_temperature": 1.0}, # fully neutral
]

cols = ["NLL", "CRPS", "CRLS", "IS95", "Cov95", "Sharp", "MACE", "KS_p", "RMSE"]
header = f"{'Config':<45}" + "".join(f"{c:>8}" for c in cols)
print(header)
print("-" * len(header))

for cfg in configs:
label = f"sm_t={cfg['softmax_temperature']}, ens_t={cfg['ensemble_temperature']:.3f}"
m = evaluate_config(X_train, y_train, X_test, y_test, **cfg)
print(
f"{label:<45}"
f"{m['nll']:>8.4f}"
f"{m['crps']:>8.4f}"
f"{m['crls']:>8.4f}"
f"{m['is95']:>8.4f}"
f"{m['cov95']:>8.4f}"
f"{m['sharp']:>8.4f}"
f"{m['mace']:>8.4f}"
f"{m['ks_p']:>8.4f}"
f"{m['rmse']:>8.4f}"
)


if __name__ == "__main__":
main()
36 changes: 36 additions & 0 deletions src/tabpfn/regressor.py
@@ -212,6 +212,7 @@ def __init__( # noqa: PLR0913
categorical_features_indices: Sequence[int] | None = None,
softmax_temperature: float = 0.9,
average_before_softmax: bool = False,
ensemble_temperature: float = 1 / 0.9,
model_path: str
| Path
| list[str]
@@ -280,6 +281,27 @@ def __init__( # noqa: PLR0913
- If `False`, the softmax function is applied to each set of logits.
Then, we average the resulting probabilities of each forward pass.

ensemble_temperature:
Temperature applied to the final ensemble log-probabilities after all
estimators have been combined. This is a post-hoc recalibration step
that adjusts the sharpness of the predictive distribution without
affecting how individual estimators are aggregated.

The default ``1 / 0.9 ≈ 1.111`` is deliberately paired with the default
``softmax_temperature=0.9``: the per-estimator sharpening improves
point-estimate quality (RMSE) by helping ensemble members mix better,
while ``ensemble_temperature`` compensates for the resulting
over-confidence of the final predictive distribution, restoring the
calibration of quantiles, log score, and CRPS.

- Values **below 1.0** sharpen the distribution (more confident,
narrower prediction intervals). Use when the model is under-confident.
- Values **above 1.0** flatten the distribution (more uncertain, wider
intervals). Use to improve calibration when the model is
over-confident.
- Set both this and ``softmax_temperature`` to ``1.0`` to disable all
temperature scaling.
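
With these defaults the pairing is exact in the single-estimator case:
raising ``softmax(z / 0.9)`` to the power ``0.9`` and renormalising
recovers ``softmax(z)``, so ``ensemble_temperature = 1 / 0.9`` undoes the
per-estimator sharpening when ``n_estimators=1`` and approximately
compensates for it once the estimators' probabilities have been averaged.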

model_path:
The path to the TabPFN model file, i.e., the pre-trained weights.

@@ -440,6 +462,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
self.categorical_features_indices = categorical_features_indices
self.softmax_temperature = softmax_temperature
self.average_before_softmax = average_before_softmax
self.ensemble_temperature = ensemble_temperature
self.model_path = model_path
self.device = device
self.ignore_pretraining_limits = ignore_pretraining_limits
@@ -487,6 +510,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
),
"n_estimators": 8,
"softmax_temperature": 0.9,
"ensemble_temperature": 1 / 0.9,
}
elif version == ModelVersion.V2_5:
options = {
@@ -495,6 +519,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
),
"n_estimators": 8,
"softmax_temperature": 0.9,
"ensemble_temperature": 1 / 0.9,
}
elif version == ModelVersion.V2_6:
options = {
@@ -503,6 +528,7 @@ def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
),
"n_estimators": 8,
"softmax_temperature": 0.9,
"ensemble_temperature": 1 / 0.9,
}
else:
raise ValueError(f"Unknown version: {version}")
@@ -980,6 +1006,16 @@ def predict( # noqa: C901, PLR0912
if logits.dtype == torch.float16:
logits = logits.float()

# Apply ensemble_temperature as a final recalibration of the predictive
# distribution. Dividing log-probabilities by T is equivalent to raising
# each probability to the power 1/T before renormalising via softmax.
# T > 1 flattens (widens) the distribution; T < 1 sharpens it.
# The default 1/0.9 ≈ 1.111 compensates for the sharpening introduced by
# softmax_temperature=0.9, restoring calibration while keeping the RMSE
# benefit from per-estimator sharpening.
if self.ensemble_temperature != 1.0:
logits = logits / self.ensemble_temperature

# Determine and return intended output type
logit_to_output = partial(
_logits_to_output,
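
The equivalence asserted in the new comment in ``predict`` (dividing log-probabilities by T matches raising each probability to the power 1/T and then renormalising), and the single-estimator cancellation of the paired defaults, can both be checked with a few lines of plain torch. This sketch uses generic softmax probabilities and assumes nothing about TabPFN internals:

import torch

torch.manual_seed(0)
z = torch.randn(5, 10)  # raw logits: 5 samples, 10 distribution buckets
T = 1 / 0.9             # the new ensemble_temperature default

# Path A: divide log-probabilities by T, then renormalise via softmax.
p_a = torch.softmax(torch.log_softmax(z, dim=-1) / T, dim=-1)

# Path B: raise probabilities to the power 1/T and renormalise explicitly.
p = torch.softmax(z, dim=-1)
p_b = p ** (1 / T)
p_b = p_b / p_b.sum(dim=-1, keepdim=True)

print(torch.allclose(p_a, p_b, atol=1e-6))  # True

# With the paired defaults a single estimator round-trips exactly:
# sharpen with softmax_temperature=0.9, then flatten with ensemble_temperature=1/0.9.
p_sharp = torch.softmax(z / 0.9, dim=-1)
p_restored = torch.softmax(torch.log(p_sharp) / T, dim=-1)
print(torch.allclose(p_restored, torch.softmax(z, dim=-1), atol=1e-6))  # True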