Skip to content

Commit b3bbf94

Browse files
committed
add progress bar
1 parent 231de0c commit b3bbf94

6 files changed

Lines changed: 50 additions & 6 deletions

File tree

changelog/899.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add an optional `show_progress_bar` flag to TabPFN classifier and regressor inference, defaulting to `False`.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
# Once Python 3.10 is the minimum version, this can be removed.
2222
"eval-type-backport>=0.2.2",
2323
"joblib>=1.2.0",
24+
"tqdm>=4.66.0",
2425
"tabpfn-common-utils[telemetry-interactive]>=0.2.13",
2526
"filelock>=3.11.0",
2627
]

src/tabpfn/classifier.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sklearn import config_context
3232
from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
3333
from tabpfn_common_utils.telemetry import track_model_call
34+
from tqdm.auto import tqdm
3435

3536
from tabpfn.base import (
3637
ClassifierModelSpecs,
@@ -222,6 +223,7 @@ def __init__( # noqa: PLR0913
222223
"fit_with_cache",
223224
"batched",
224225
] = "fit_preprocessors",
226+
show_progress_bar: bool = False,
225227
memory_saving_mode: MemorySavingMode = "auto",
226228
random_state: int | np.random.RandomState | np.random.Generator | None = 0,
227229
n_jobs: Annotated[int | None, deprecated("Use n_preprocessing_jobs")] = None,
@@ -453,6 +455,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
453455
`eval_metric`. See
454456
[tabpfn.inference_tuning.ClassifierTuningConfig][] for details
455457
and options.
458+
459+
show_progress_bar:
460+
Whether to show a progress bar during inference. Defaults to False.
456461
"""
457462
super().__init__()
458463
self.n_estimators = n_estimators
@@ -467,6 +472,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
467472
inference_precision
468473
)
469474
self.fit_mode = fit_mode
475+
self.show_progress_bar = show_progress_bar
470476
self.memory_saving_mode: MemorySavingMode = memory_saving_mode
471477
self.random_state = random_state
472478
self.inference_config = inference_config
@@ -1428,10 +1434,16 @@ def forward( # noqa: C901, PLR0912
14281434
self.executor_.use_torch_inference_mode(use_inference=actual_inference_mode)
14291435

14301436
outputs = []
1431-
for output, config in self.executor_.iter_outputs(
1432-
X,
1433-
autocast=self.use_autocast_,
1434-
task_type="multiclass",
1437+
for output, config in tqdm(
1438+
self.executor_.iter_outputs(
1439+
X,
1440+
autocast=self.use_autocast_,
1441+
task_type="multiclass",
1442+
),
1443+
total=self.n_estimators,
1444+
desc="TabPFN inference",
1445+
unit="estimator",
1446+
disable=not getattr(self, "show_progress_bar", False),
14351447
):
14361448
original_ndim = output.ndim
14371449

src/tabpfn/regressor.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
check_is_fitted,
3737
)
3838
from tabpfn_common_utils.telemetry import track_model_call
39+
from tqdm.auto import tqdm
3940

4041
from tabpfn.architectures.base.bar_distribution import FullSupportBarDistribution
4142
from tabpfn.base import (
@@ -229,6 +230,7 @@ def __init__( # noqa: PLR0913
229230
"fit_with_cache",
230231
"batched",
231232
] = "fit_preprocessors",
233+
show_progress_bar: bool = False,
232234
memory_saving_mode: MemorySavingMode = "auto",
233235
random_state: int | np.random.RandomState | np.random.Generator | None = 0,
234236
n_jobs: Annotated[int | None, deprecated("Use n_preprocessing_jobs")] = None,
@@ -435,6 +437,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
435437
If true, preprocessing attempts to be end-to-end differentiable.
436438
Less relevant for standard regression fine-tuning compared to
437439
prompt-tuning.
440+
441+
show_progress_bar:
442+
Whether to show a progress bar during inference. Defaults to False.
438443
"""
439444
super().__init__()
440445
self.n_estimators = n_estimators
@@ -453,6 +458,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
453458
"fit_with_cache",
454459
"batched",
455460
] = fit_mode
461+
self.show_progress_bar = show_progress_bar
456462
self.memory_saving_mode: MemorySavingMode = memory_saving_mode
457463
self.random_state = random_state
458464
self.inference_config = inference_config
@@ -952,8 +958,12 @@ def predict( # noqa: C901, PLR0912
952958
n_estimators = 0
953959
accumulated_logits: torch.Tensor | None = None
954960
with handle_oom_errors(self.devices_, X, model_type="regressor"):
955-
for borders_t, output in self._iter_forward_executor(
956-
X, use_inference_mode=True
961+
for borders_t, output in tqdm(
962+
self._iter_forward_executor(X, use_inference_mode=True),
963+
total=self.n_estimators,
964+
desc="TabPFN inference",
965+
unit="estimator",
966+
disable=not getattr(self, "show_progress_bar", False),
957967
):
958968
transformed = translate_probs_across_borders(
959969
output,

tests/test_classifier_interface.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ def X_y() -> tuple[np.ndarray, np.ndarray]:
6565
fit_modes = ["low_memory", "fit_preprocessors"]
6666

6767

68+
def test__show_progress_bar__is_configurable() -> None:
69+
model = TabPFNClassifier(show_progress_bar=True)
70+
assert model.show_progress_bar is True
71+
assert model.get_params()["show_progress_bar"] is True
72+
73+
default_model = TabPFNClassifier()
74+
assert default_model.show_progress_bar is False
75+
assert default_model.get_params()["show_progress_bar"] is False
76+
77+
6878
@pytest.mark.parametrize(
6979
("device", "n_estimators", "fit_mode", "inference_precision"),
7080
mark_mps_configs_as_slow(

tests/test_regressor_interface.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@
4040
fit_modes = ["low_memory", "fit_preprocessors"]
4141

4242

43+
def test__show_progress_bar__is_configurable() -> None:
44+
model = TabPFNRegressor(show_progress_bar=True)
45+
assert model.show_progress_bar is True
46+
assert model.get_params()["show_progress_bar"] is True
47+
48+
default_model = TabPFNRegressor()
49+
assert default_model.show_progress_bar is False
50+
assert default_model.get_params()["show_progress_bar"] is False
51+
52+
4353
@pytest.fixture(scope="module")
4454
def X_y() -> tuple[np.ndarray, np.ndarray]:
4555
X, y, _ = sklearn.datasets.make_regression(

0 commit comments

Comments (0)