src/tabpfn/finetuning/finetuned_base.py (136 changes: 79 additions & 57 deletions)
@@ -136,6 +136,10 @@ class FinetunedTabPFNBase(BaseEstimator, ABC):
             data batches. This is helpful in most cases because, e.g., the column order
             will stay the same across batches.
             If False, the preprocessing will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).
     """

     def __init__(  # noqa: PLR0913
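
A quick usage sketch of the new parameter. This is hedged: the import path is inferred from the `src/tabpfn/finetuning/` file paths in this diff, and all other constructor arguments are assumed to have usable defaults.

```python
# Sketch only: import path inferred from src/tabpfn/finetuning/; other
# constructor arguments are left at their (assumed) defaults.
from tabpfn.finetuning import FinetunedTabPFNClassifier

# Default behaviour: validate after every epoch.
clf = FinetunedTabPFNClassifier(validation_frequency=1)

# Validate only every 5th epoch to cut validation overhead.
clf_sparse = FinetunedTabPFNClassifier(validation_frequency=5)

# Skip validation entirely; early stopping is disabled too, and a
# UserWarning is emitted if early_stopping was requested.
clf_no_val = FinetunedTabPFNClassifier(validation_frequency=None)
```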
@@ -163,6 +167,7 @@ def __init__(  # noqa: PLR0913
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
     ):
         super().__init__()
         self.device = device
@@ -188,6 +193,7 @@ def __init__(  # noqa: PLR0913
         self.save_checkpoint_interval = save_checkpoint_interval
         self.meta_batch_size = META_BATCH_SIZE
         self.use_fixed_preprocessing_seed = use_fixed_preprocessing_seed
+        self.validation_frequency = validation_frequency

         if self.use_fixed_preprocessing_seed and not (
             self.n_estimators_finetune
Expand Down Expand Up @@ -528,16 +534,26 @@ def _fit( # noqa: C901,PLR0912
use_amp = self.device.startswith("cuda") and torch.cuda.is_available()
scaler = GradScaler() if use_amp else None # type: ignore

logger.info("--- 🚀 Eval default model ---")
eval_result = self._evaluate_model(
validation_eval_config,
X_train, # pyright: ignore[reportArgumentType]
y_train, # pyright: ignore[reportArgumentType]
X_val, # pyright: ignore[reportArgumentType]
y_val, # pyright: ignore[reportArgumentType]
)
self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
best_metric: float = eval_result.primary
if self.validation_frequency is not None:
logger.info("--- 🚀 Eval default model ---")
eval_result = self._evaluate_model(
validation_eval_config,
X_train, # pyright: ignore[reportArgumentType]
y_train, # pyright: ignore[reportArgumentType]
X_val, # pyright: ignore[reportArgumentType]
y_val, # pyright: ignore[reportArgumentType]
)
self._log_epoch_evaluation(-1, eval_result, mean_train_loss=None)
best_metric: float = eval_result.primary
else:
if self.early_stopping:
warnings.warn(
"`early_stopping` is enabled but `validation_frequency` is None. "
"Early stopping requires validation; it will be disabled.",
UserWarning,
stacklevel=2,
)
best_metric = self._get_initial_best_metric()

static_seed, rng = infer_random_state(self.random_state)
preprocessing_random_state = (
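
The new `else` branch falls back to `self._get_initial_best_metric()`, whose body is not part of this diff. A plausible sketch of such a helper, assuming the class records whether its primary metric is maximized or minimized (the `_higher_is_better` attribute here is hypothetical, not shown in the diff):

```python
# Hypothetical sketch of the helper used above; the real implementation
# lives elsewhere in FinetunedTabPFNBase and may differ.
def _get_initial_best_metric(self) -> float:
    # Seed best_metric with the worst possible value so that any real
    # validation result (or checkpoint comparison) counts as an improvement.
    return float("-inf") if self._higher_is_better else float("inf")
```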
@@ -684,61 +700,67 @@ def _fit(  # noqa: C901,PLR0912
                 epoch_loss_sum / epoch_batches if epoch_batches > 0 else None
             )

-            eval_result = self._evaluate_model(
-                validation_eval_config,
-                X_train,  # pyright: ignore[reportArgumentType]
-                y_train,  # pyright: ignore[reportArgumentType]
-                X_val,  # pyright: ignore[reportArgumentType]
-                y_val,  # pyright: ignore[reportArgumentType]
-            )
-
-            self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)
-
-            primary_metric = eval_result.primary
-
-            if output_dir is not None and not np.isnan(primary_metric):
-                save_interval_checkpoint = (
-                    self.save_checkpoint_interval is not None
-                    and (epoch + 1) % self.save_checkpoint_interval == 0
-                )
-
-                is_best = self._is_improvement(primary_metric, best_metric)
-
-                if save_interval_checkpoint or is_best:
-                    save_checkpoint(
-                        estimator=self.finetuned_estimator_,
-                        output_dir=output_dir,
-                        epoch=epoch + 1,
-                        optimizer=optimizer,
-                        metrics=self._get_checkpoint_metrics(eval_result),
-                        train_size=train_size,
-                        is_best=is_best,
-                        save_interval_checkpoint=save_interval_checkpoint,
-                    )
-
-            if self.early_stopping and not np.isnan(primary_metric):
-                if self._is_improvement(primary_metric, best_metric):
-                    best_metric = primary_metric
-                    patience_counter = 0
-                    best_model = copy.deepcopy(self.finetuned_estimator_)
-                else:
-                    patience_counter += 1
-                    logger.info(
-                        "⚠️ No improvement for %s epochs. Best %s: %.4f",
-                        patience_counter,
-                        self._metric_name,
-                        best_metric,
-                    )
-
-                if patience_counter >= self.early_stopping_patience:
-                    logger.info(
-                        "🛑 Early stopping triggered. Best %s: %.4f",
-                        self._metric_name,
-                        best_metric,
-                    )
-                    if best_model is not None:
-                        self.finetuned_estimator_ = best_model
-                    break
+            run_validation = (
+                self.validation_frequency is not None
+                and (epoch + 1) % self.validation_frequency == 0
+            )
+
+            if run_validation:
+                eval_result = self._evaluate_model(
+                    validation_eval_config,
+                    X_train,  # pyright: ignore[reportArgumentType]
+                    y_train,  # pyright: ignore[reportArgumentType]
+                    X_val,  # pyright: ignore[reportArgumentType]
+                    y_val,  # pyright: ignore[reportArgumentType]
+                )
+
+                self._log_epoch_evaluation(epoch, eval_result, mean_train_loss)
+
+                primary_metric = eval_result.primary
+
+                if output_dir is not None and not np.isnan(primary_metric):
+                    save_interval_checkpoint = (
+                        self.save_checkpoint_interval is not None
+                        and (epoch + 1) % self.save_checkpoint_interval == 0
+                    )
+
+                    is_best = self._is_improvement(primary_metric, best_metric)
+
+                    if save_interval_checkpoint or is_best:
+                        save_checkpoint(
+                            estimator=self.finetuned_estimator_,
+                            output_dir=output_dir,
+                            epoch=epoch + 1,
+                            optimizer=optimizer,
+                            metrics=self._get_checkpoint_metrics(eval_result),
+                            train_size=train_size,
+                            is_best=is_best,
+                            save_interval_checkpoint=save_interval_checkpoint,
+                        )
+
+                if self.early_stopping and not np.isnan(primary_metric):
+                    if self._is_improvement(primary_metric, best_metric):
+                        best_metric = primary_metric
+                        patience_counter = 0
+                        best_model = copy.deepcopy(self.finetuned_estimator_)
+                    else:
+                        patience_counter += 1
+                        logger.info(
+                            "⚠️ No improvement for %s epochs. Best %s: %.4f",
+                            patience_counter,
+                            self._metric_name,
+                            best_metric,
+                        )
+
+                    if patience_counter >= self.early_stopping_patience:
+                        logger.info(
+                            "🛑 Early stopping triggered. Best %s: %.4f",
+                            self._metric_name,
+                            best_metric,
+                        )
+                        if best_model is not None:
+                            self.finetuned_estimator_ = best_model
+                        break

             if self.time_limit is not None:
                 elapsed_time = time.monotonic() - start_time
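
Because epochs are zero-based, the `(epoch + 1) % validation_frequency == 0` gate validates on the Nth, 2Nth, ... training epochs, and `patience_counter` now advances only on those epochs, so `early_stopping_patience` is effectively measured in validation rounds rather than raw epochs. A minimal check of that arithmetic:

```python
# Which zero-based epochs trigger validation for validation_frequency=3?
validation_frequency = 3
validated = [e for e in range(10) if (e + 1) % validation_frequency == 0]
print(validated)  # [2, 5, 8] -> the 3rd, 6th and 9th training epochs

# Consequence: with validation_frequency=5 and early_stopping_patience=3,
# training can run up to 5 * 3 = 15 epochs past the last improving
# validation round before early stopping triggers.
```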
src/tabpfn/finetuning/finetuned_classifier.py (6 changes: 6 additions & 0 deletions)
@@ -115,6 +115,10 @@ class FinetunedTabPFNClassifier(FinetunedTabPFNBase, ClassifierMixin):
             data batches. This is helpful in most cases because, e.g., the column order
             will stay the same across batches.
             If False, the preprocessing will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).

     FinetunedTabPFNClassifier specific arguments:

@@ -150,6 +154,7 @@ def __init__(  # noqa: PLR0913
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
         extra_classifier_kwargs: dict[str, Any] | None = None,
         eval_metric: Literal["roc_auc", "log_loss"] | None = None,
     ):
@@ -176,6 +181,7 @@ def __init__(  # noqa: PLR0913
             use_activation_checkpointing=use_activation_checkpointing,
             save_checkpoint_interval=save_checkpoint_interval,
             use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
+            validation_frequency=validation_frequency,
         )
         self.extra_classifier_kwargs = extra_classifier_kwargs
         self.eval_metric = eval_metric
src/tabpfn/finetuning/finetuned_regressor.py (6 changes: 6 additions & 0 deletions)
@@ -282,6 +282,10 @@ class FinetunedTabPFNRegressor(FinetunedTabPFNBase, RegressorMixin):
             data batches. This is helpful in most cases because, e.g., the column order
             will stay the same across batches.
             If False, the preprocessing will use a different random seed for each batch.
+        validation_frequency: How often (in epochs) to run validation. If set to
+            an integer N, validation is run every N epochs. If None, validation is
+            disabled entirely, which also disables early stopping. Defaults to 1
+            (validate every epoch).

     FinetunedTabPFNRegressor specific arguments:

@@ -333,6 +337,7 @@ def __init__(  # noqa: PLR0913
         use_activation_checkpointing: bool = True,
         save_checkpoint_interval: int | None = 10,
         use_fixed_preprocessing_seed: bool = True,
+        validation_frequency: int | None = 1,
         extra_regressor_kwargs: dict[str, Any] | None = None,
         ce_loss_weight: float = 0.0,
         crps_loss_weight: float = 1.0,
@@ -366,6 +371,7 @@ def __init__(  # noqa: PLR0913
             use_activation_checkpointing=use_activation_checkpointing,
             save_checkpoint_interval=save_checkpoint_interval,
             use_fixed_preprocessing_seed=use_fixed_preprocessing_seed,
+            validation_frequency=validation_frequency,
         )
         self.extra_regressor_kwargs = extra_regressor_kwargs
         self.eval_metric = eval_metric