3737 check_is_fitted ,
3838)
3939from tabpfn_common_utils .telemetry import track_model_call
40+ from tqdm .auto import tqdm
4041
4142from tabpfn .architectures .base .bar_distribution import FullSupportBarDistribution
4243from tabpfn .base import (
@@ -236,6 +237,7 @@ def __init__( # noqa: PLR0913
236237 n_preprocessing_jobs : int = 1 ,
237238 inference_config : dict | InferenceConfig | None = None ,
238239 differentiable_input : bool = False ,
240+ show_progress_bar : bool = False ,
239241 ) -> None :
240242 """Construct a TabPFN regressor.
241243
@@ -436,6 +438,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
436438 If true, preprocessing attempts to be end-to-end differentiable.
437439 Less relevant for standard regression fine-tuning compared to
438440 prompt-tuning.
441+
442+ show_progress_bar:
443+ Whether to show a progress bar during inference. Defaults to False.
439444 """
440445 super ().__init__ ()
441446 self .n_estimators = n_estimators
@@ -454,6 +459,7 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
454459 "fit_with_cache" ,
455460 "batched" ,
456461 ] = fit_mode
462+ self .show_progress_bar = show_progress_bar
457463 self .memory_saving_mode : MemorySavingMode = memory_saving_mode
458464 self .random_state = random_state
459465 self .inference_config = inference_config
@@ -967,8 +973,12 @@ def predict( # noqa: C901, PLR0912
967973 n_estimators = 0
968974 accumulated_logits : torch .Tensor | None = None
969975 with handle_oom_errors (self .devices_ , X , model_type = "regressor" ):
970- for borders_t , output in self ._iter_forward_executor (
971- X , use_inference_mode = True
976+ for borders_t , output in tqdm (
977+ self ._iter_forward_executor (X , use_inference_mode = True ),
978+ total = self .n_estimators ,
979+ desc = "TabPFN inference" ,
980+ unit = "estimator" ,
981+ disable = not self .show_progress_bar ,
972982 ):
973983 transformed = translate_probs_across_borders (
974984 output ,
@@ -1164,8 +1174,12 @@ def forward(
11641174 outputs : list [torch .Tensor ] = []
11651175 borders : list [np .ndarray ] = []
11661176
1167- for border , output in self ._iter_forward_executor (
1168- X , use_inference_mode = use_inference_mode
1177+ for border , output in tqdm (
1178+ self ._iter_forward_executor (X , use_inference_mode = use_inference_mode ),
1179+ total = self .n_estimators ,
1180+ desc = "TabPFN inference" ,
1181+ unit = "estimator" ,
1182+ disable = not self .show_progress_bar ,
11691183 ):
11701184 borders .append (border )
11711185 outputs .append (output )
0 commit comments