Skip to content

Commit 8000ea5

Browse files
committed
add progress bar
1 parent e22bdb4 commit 8000ea5

6 files changed

Lines changed: 114 additions & 8 deletions

File tree

changelog/899.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add an optional `show_progress_bar` flag to TabPFN classifier and regressor inference, defaulting to `False`.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
# Once Python 3.10 is the minimum version, this can be removed.
2222
"eval-type-backport>=0.2.2",
2323
"joblib>=1.2.0",
24+
"tqdm>=4.66.0",
2425
"tabpfn-common-utils[telemetry-interactive]>=0.2.13",
2526
"filelock>=3.11.0",
2627
]

src/tabpfn/classifier.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sklearn import config_context
3232
from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
3333
from tabpfn_common_utils.telemetry import track_model_call
34+
from tqdm.auto import tqdm
3435

3536
from tabpfn.base import (
3637
ClassifierModelSpecs,
@@ -230,6 +231,7 @@ def __init__( # noqa: PLR0913
230231
differentiable_input: bool = False,
231232
eval_metric: str | ClassifierEvalMetrics | None = None,
232233
tuning_config: dict | ClassifierTuningConfig | None = None,
234+
show_progress_bar: bool = False,
233235
) -> None:
234236
"""Construct a TabPFN classifier.
235237
@@ -453,6 +455,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
453455
`eval_metric`. See
454456
[tabpfn.inference_tuning.ClassifierTuningConfig][] for details
455457
and options.
458+
459+
show_progress_bar:
460+
Whether to show a progress bar during inference. Defaults to False.
456461
"""
457462
super().__init__()
458463
self.n_estimators = n_estimators
@@ -467,6 +472,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
467472
inference_precision
468473
)
469474
self.fit_mode = fit_mode
475+
self.show_progress_bar = show_progress_bar
470476
self.memory_saving_mode: MemorySavingMode = memory_saving_mode
471477
self.random_state = random_state
472478
self.inference_config = inference_config
@@ -1442,10 +1448,16 @@ def forward( # noqa: C901, PLR0912
14421448
self.executor_.use_torch_inference_mode(use_inference=actual_inference_mode)
14431449

14441450
outputs = []
1445-
for output, config in self.executor_.iter_outputs(
1446-
X,
1447-
autocast=self.use_autocast_,
1448-
task_type="multiclass",
1451+
for output, config in tqdm(
1452+
self.executor_.iter_outputs(
1453+
X,
1454+
autocast=self.use_autocast_,
1455+
task_type="multiclass",
1456+
),
1457+
total=self.n_estimators,
1458+
desc="TabPFN inference",
1459+
unit="estimator",
1460+
disable=not self.show_progress_bar,
14491461
):
14501462
original_ndim = output.ndim
14511463

src/tabpfn/regressor.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
check_is_fitted,
3838
)
3939
from tabpfn_common_utils.telemetry import track_model_call
40+
from tqdm.auto import tqdm
4041

4142
from tabpfn.architectures.base.bar_distribution import FullSupportBarDistribution
4243
from tabpfn.base import (
@@ -236,6 +237,7 @@ def __init__( # noqa: PLR0913
236237
n_preprocessing_jobs: int = 1,
237238
inference_config: dict | InferenceConfig | None = None,
238239
differentiable_input: bool = False,
240+
show_progress_bar: bool = False,
239241
) -> None:
240242
"""Construct a TabPFN regressor.
241243
@@ -436,6 +438,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
436438
If true, preprocessing attempts to be end-to-end differentiable.
437439
Less relevant for standard regression fine-tuning compared to
438440
prompt-tuning.
441+
442+
show_progress_bar:
443+
Whether to show a progress bar during inference. Defaults to False.
439444
"""
440445
super().__init__()
441446
self.n_estimators = n_estimators
@@ -454,6 +459,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
454459
"fit_with_cache",
455460
"batched",
456461
] = fit_mode
462+
self.show_progress_bar = show_progress_bar
457463
self.memory_saving_mode: MemorySavingMode = memory_saving_mode
458464
self.random_state = random_state
459465
self.inference_config = inference_config
@@ -967,8 +973,12 @@ def predict( # noqa: C901, PLR0912
967973
n_estimators = 0
968974
accumulated_logits: torch.Tensor | None = None
969975
with handle_oom_errors(self.devices_, X, model_type="regressor"):
970-
for borders_t, output in self._iter_forward_executor(
971-
X, use_inference_mode=True
976+
for borders_t, output in tqdm(
977+
self._iter_forward_executor(X, use_inference_mode=True),
978+
total=self.n_estimators,
979+
desc="TabPFN inference",
980+
unit="estimator",
981+
disable=not self.show_progress_bar,
972982
):
973983
transformed = translate_probs_across_borders(
974984
output,
@@ -1164,8 +1174,12 @@ def forward(
11641174
outputs: list[torch.Tensor] = []
11651175
borders: list[np.ndarray] = []
11661176

1167-
for border, output in self._iter_forward_executor(
1168-
X, use_inference_mode=use_inference_mode
1177+
for border, output in tqdm(
1178+
self._iter_forward_executor(X, use_inference_mode=use_inference_mode),
1179+
total=self.n_estimators,
1180+
desc="TabPFN inference",
1181+
unit="estimator",
1182+
disable=not self.show_progress_bar,
11691183
):
11701184
borders.append(border)
11711185
outputs.append(output)

tests/test_classifier_interface.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,34 @@ def X_y() -> tuple[np.ndarray, np.ndarray]:
6565
fit_modes = ["low_memory", "fit_preprocessors"]
6666

6767

68+
def test__show_progress_bar__is_configurable() -> None:
69+
model = TabPFNClassifier(show_progress_bar=True)
70+
assert model.show_progress_bar is True
71+
assert model.get_params()["show_progress_bar"] is True
72+
73+
default_model = TabPFNClassifier()
74+
assert default_model.show_progress_bar is False
75+
assert default_model.get_params()["show_progress_bar"] is False
76+
77+
78+
def test__predict__show_progress_bar_true__tiny_dataset_does_not_crash() -> None:
79+
model = TabPFNClassifier(n_estimators=1, show_progress_bar=True, random_state=42)
80+
X, y = sklearn.datasets.make_classification(
81+
n_samples=9,
82+
n_features=3,
83+
n_informative=3,
84+
n_redundant=0,
85+
n_classes=3,
86+
random_state=0,
87+
)
88+
89+
model.fit(X, y)
90+
91+
predictions = model.predict(X)
92+
93+
assert predictions.shape == (X.shape[0],)
94+
95+
6896
@pytest.mark.parametrize(
6997
("device", "n_estimators", "fit_mode", "inference_precision"),
7098
mark_mps_configs_as_slow(

tests/test_regressor_interface.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sklearn.utils.estimator_checks import parametrize_with_checks
1919
from torch import nn
2020

21+
import tabpfn.regressor as regressor_module
2122
from tabpfn import TabPFNRegressor
2223
from tabpfn.base import RegressorModelSpecs, initialize_tabpfn_model
2324
from tabpfn.constants import ModelVersion
@@ -40,6 +41,55 @@
4041
fit_modes = ["low_memory", "fit_preprocessors"]
4142

4243

44+
def test__show_progress_bar__is_configurable() -> None:
45+
model = TabPFNRegressor(show_progress_bar=True)
46+
assert model.show_progress_bar is True
47+
assert model.get_params()["show_progress_bar"] is True
48+
49+
default_model = TabPFNRegressor()
50+
assert default_model.show_progress_bar is False
51+
assert default_model.get_params()["show_progress_bar"] is False
52+
53+
54+
def test__predict__show_progress_bar_true__tiny_dataset_does_not_crash() -> None:
55+
model = TabPFNRegressor(n_estimators=1, show_progress_bar=True, random_state=42)
56+
X, y = sklearn.datasets.make_regression(
57+
n_samples=9,
58+
n_features=3,
59+
random_state=0,
60+
coef=False,
61+
)
62+
63+
model.fit(X, y)
64+
65+
predictions = model.predict(X)
66+
67+
assert predictions.shape == (X.shape[0],)
68+
69+
70+
def test__forward__passes_progress_bar_flag(monkeypatch: pytest.MonkeyPatch) -> None:
71+
model = TabPFNRegressor(n_estimators=1, show_progress_bar=True, random_state=42)
72+
X, y = sklearn.datasets.make_regression(
73+
n_samples=9, n_features=3, random_state=0, coef=False
74+
)
75+
model.fit(X, y)
76+
77+
captured_kwargs: dict[str, object] = {}
78+
79+
def fake_tqdm(iterable, **kwargs) -> typing.Iterable[object]:
80+
captured_kwargs.update(kwargs)
81+
return iterable
82+
83+
monkeypatch.setattr(regressor_module, "tqdm", fake_tqdm)
84+
85+
averaged_logits, outputs, borders = model.forward(X, use_inference_mode=True)
86+
87+
assert averaged_logits is not None
88+
assert outputs
89+
assert borders
90+
assert captured_kwargs["disable"] is False
91+
92+
4393
@pytest.fixture(scope="module")
4494
def X_y() -> tuple[np.ndarray, np.ndarray]:
4595
X, y, _ = sklearn.datasets.make_regression(

0 commit comments

Comments
 (0)