3131from sklearn import config_context
3232from sklearn .base import BaseEstimator , ClassifierMixin , check_is_fitted
3333from tabpfn_common_utils .telemetry import track_model_call
34+ from tqdm .auto import tqdm
3435
3536from tabpfn .base import (
3637 ClassifierModelSpecs ,
@@ -222,6 +223,7 @@ def __init__( # noqa: PLR0913
222223 "fit_with_cache" ,
223224 "batched" ,
224225 ] = "fit_preprocessors" ,
226+ show_progress_bar : bool = False ,
225227 memory_saving_mode : MemorySavingMode = "auto" ,
226228 random_state : int | np .random .RandomState | np .random .Generator | None = 0 ,
227229 n_jobs : Annotated [int | None , deprecated ("Use n_preprocessing_jobs" )] = None ,
@@ -453,6 +455,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
453455 `eval_metric`. See
454456 [tabpfn.inference_tuning.ClassifierTuningConfig][] for details
455457 and options.
458+
459+ show_progress_bar:
460+ Whether to show a progress bar during inference. Defaults to False.
456461 """
457462 super ().__init__ ()
458463 self .n_estimators = n_estimators
@@ -467,6 +472,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
467472 inference_precision
468473 )
469474 self .fit_mode = fit_mode
475+ self .show_progress_bar = show_progress_bar
470476 self .memory_saving_mode : MemorySavingMode = memory_saving_mode
471477 self .random_state = random_state
472478 self .inference_config = inference_config
@@ -1428,10 +1434,16 @@ def forward( # noqa: C901, PLR0912
14281434 self .executor_ .use_torch_inference_mode (use_inference = actual_inference_mode )
14291435
14301436 outputs = []
1431- for output , config in self .executor_ .iter_outputs (
1432- X ,
1433- autocast = self .use_autocast_ ,
1434- task_type = "multiclass" ,
1437+ for output , config in tqdm (
1438+ self .executor_ .iter_outputs (
1439+ X ,
1440+ autocast = self .use_autocast_ ,
1441+ task_type = "multiclass" ,
1442+ ),
1443+ total = self .n_estimators ,
1444+ desc = "TabPFN inference" ,
1445+ unit = "estimator" ,
1446+ disable = not getattr (self , "show_progress_bar" , False ),
14351447 ):
14361448 original_ndim = output .ndim
14371449
0 commit comments