diff --git a/examples/asr/run_ipl.py b/examples/asr/run_ipl.py new file mode 100644 index 000000000000..8de8c1aa0140 --- /dev/null +++ b/examples/asr/run_ipl.py @@ -0,0 +1,263 @@ +import copy +import glob +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import torch +from omegaconf import OmegaConf, open_dict + +from nemo.collections.asr.parts.utils.run_ipl_utils import * +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +def get_command_for_inference( + inference_config: str, inference_config_dir: Union[str, Path], p_cache: float, checkpoint: str +) -> Tuple[str, List[str], List[str]]: + """ + Generates the command string for running speech inference with transcribe_speech_parallel. + + Args: + inference_config (str): Path to the base inference configuration file. + inference_config_dir (Union[str, Path]): Directory to store temporary modified configurations. + p_cache (float): Proportion of the dataset to be cached for pseudo-labeling. + checkpoint (str): Path to the model checkpoint to use for inference. + + Returns: + Tuple[str, List[str], List[str]]: + - The command string to execute inference for all specified manifests. + - List of output directories corresponding to each manifest. + - List of completed full pass transcribed manifest paths, if any. 
def get_command_for_inference(
    inference_config: str, inference_config_dir: Union[str, Path], p_cache: float, checkpoint: str
) -> Tuple[str, List[str], List[str]]:
    """
    Generates the command string for running speech inference with transcribe_speech_parallel.

    One temporary config (``temp_configs/modified_config_<i>.yaml``) is written per manifest
    and the per-manifest commands are chained with ``&&`` so that a failure stops the chain.

    Args:
        inference_config (str): Path to the base inference configuration file.
        inference_config_dir (Union[str, Path]): Directory to store temporary modified configurations.
        p_cache (float): Proportion of the dataset to be cached for pseudo-labeling.
        checkpoint (str): Path to the model checkpoint to use for inference.

    Returns:
        Tuple[str, List[str], List[str]]:
            - The command string to execute inference for all specified manifests.
            - List of output directories corresponding to each manifest.
            - ``transcribed_manifest*`` files found in the output directory of the *last*
              manifest (empty when no full pass has been done yet or there are no manifests).
    """
    manifests, tarr_audio_files = separate_multiple_transcriptions(inference_config)
    # A CPU-only host reports 0 devices; clamp so the batch computation below cannot divide by zero.
    num_gpus = max(1, torch.cuda.device_count())
    print(f"manifests {manifests}")
    print(f"tarr_audio_files {tarr_audio_files}")

    # Loop invariants: the base config and the temp directory are the same for every manifest.
    base_cfg = OmegaConf.load(inference_config)
    temp_config_dir = Path(str(inference_config_dir) + "/temp_configs").absolute()
    os.makedirs(temp_config_dir, exist_ok=True)

    output_dirs = []
    commands = []
    full_pass_done = []  # stays defined (and empty) when there are no manifests
    for i, manifest in enumerate(manifests):
        output_dir = os.path.dirname(manifest)
        output_dirs.append(output_dir)
        modified_cfg = copy.deepcopy(base_cfg)

        # Check if we need to run inference on the whole set or update part of it
        full_pass_done = glob.glob(os.path.join(output_dir, 'transcribed_manifest*'))
        if full_pass_done:
            number_of_files = count_files_for_pseudo_labeling(manifest, bool(tarr_audio_files))
            limit_predict_batches = int((number_of_files * p_cache) / (modified_cfg.predict_ds.batch_size * num_gpus))
            OmegaConf.update(modified_cfg, "trainer.limit_predict_batches", limit_predict_batches)

        OmegaConf.update(modified_cfg, "output_path", output_dir)
        OmegaConf.update(modified_cfg, "predict_ds.manifest_filepath", manifest)
        if tarr_audio_files:
            OmegaConf.update(modified_cfg, "predict_ds.tarred_audio_filepaths", tarr_audio_files[i])
        OmegaConf.update(modified_cfg, "model", checkpoint)

        temp_config_file = os.path.join(temp_config_dir, f"modified_config_{i}.yaml")
        OmegaConf.save(modified_cfg, temp_config_file)
        commands.append(
            f"python examples/asr/transcribe_speech_parallel.py --config-path {temp_config_dir} --config-name modified_config_{i}.yaml"
        )

    # Joining avoids having to trim a trailing '&&' afterwards.
    cmd = " && ".join(commands)

    print(f"Inference command: {cmd}")
    return cmd, output_dirs, full_pass_done
def get_execution_script(
    cluster_script_path: str,
    config_name: str,
    config_path: str,
    updated_manifest_filepaths=None,
    updated_tarred_filepaths=None,
) -> str:
    """
    Constructs a command string to execute a training with the specified configuration.

    Args:
        cluster_script_path (str): Path to the cluster script to be executed.
        config_name (str): Name of the configuration file or object to be passed as a parameter.
        config_path (str): Path to the directory where the configuration resides.
        updated_manifest_filepaths (Optional[str]): If given, appended as a
            ``model.train_ds.manifest_filepath=...`` override.
        updated_tarred_filepaths (Optional[str]): If given, appended as a
            ``model.train_ds.tarred_audio_filepaths=...`` override.

    Returns:
        str: A formatted command string ready for execution.
    """
    # A bare script name has an empty dirname; a bare `cd` would jump to $HOME, so stay in place.
    script_dir = os.path.dirname(cluster_script_path) or "."
    script_name = os.path.basename(cluster_script_path)
    cmd = f'cd {script_dir} && python {script_name} --config-path {config_path} --config-name "{config_name}"'
    if updated_manifest_filepaths:
        cmd += f" model.train_ds.manifest_filepath={updated_manifest_filepaths}"
    if updated_tarred_filepaths:
        cmd += f" model.train_ds.tarred_audio_filepaths={updated_tarred_filepaths}"
    print(f"Training command: {cmd}")
    return cmd


def find_checkpoint_dir(base_path: str) -> Tuple[Union[str, None], Union[str, None]]:
    """
    Find the 'checkpoints' folder in the directory structure.

    Args:
        base_path (str): The base directory path to search from.

    Returns:
        Tuple: ``(checkpoints_dir, parent_dir)`` for the first directory named
        ``checkpoints`` met by ``os.walk``, or ``(None, None)`` when there is none.
    """
    for root, dirs, _files in os.walk(base_path):
        if "checkpoints" in dirs:
            return os.path.join(root, "checkpoints"), root
    return None, None
def run_command(cmd: str, shell: bool = True, log_file: str = None) -> bool:
    """
    Safely run a shell command using subprocess and stream output in real-time.

    stderr is merged into stdout. Reading only stdout while stderr sits in its own
    undrained pipe blocks the child as soon as it writes more than one pipe buffer
    of diagnostics, which training and inference scripts routinely do.

    Args:
        cmd (str): Command to execute
        shell (bool): Whether to use shell for command execution
        log_file (str): Optional path to save logs to file (opened in append mode)

    Returns:
        bool: True if command executed successfully (exit code 0), False otherwise
    """
    log_handle = None
    try:
        # Create log file if specified
        if log_file:
            log_handle = open(log_file, 'a')
            log_handle.write(f"\n{'='*80}\n")
            log_handle.write(f"Command: {cmd}\n")
            log_handle.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            log_handle.write(f"{'='*80}\n\n")

        # The context manager closes the pipe and reaps the child on every exit path.
        with subprocess.Popen(
            cmd,
            shell=shell,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,  # Line buffered
        ) as process:
            # Stream output in real-time; iteration ends when the child closes its output.
            for output in process.stdout:
                print(output.strip())
                if log_handle:
                    log_handle.write(output)
                    log_handle.flush()
            return_code = process.wait()

        if log_handle:
            log_handle.write(f"\nProcess completed with return code: {return_code}\n")

        return return_code == 0

    except Exception as e:
        # Top-level boundary: report the failure to the caller as False.
        print(f"Command failed with error: {e}")
        if log_handle:
            log_handle.write(f"Command failed with error: {e}\n")
        return False
    finally:
        if log_handle:
            log_handle.close()
@hydra_runner(config_path='./', config_name='run_ipl')
def main(run_config):
    """
    Drive iterative pseudo-labeling: alternate training and inference for ``ipl_epochs`` rounds.

    Reads ``script_config``, ``script_path``, ``inference_config`` and ``ipl_epochs`` from
    ``run_config``. Each round runs the training script, locates the resulting ``.nemo``
    checkpoint, transcribes the unlabeled data with it and refreshes the pseudo-label manifests.
    The loop stops early when training or inference fails or no checkpoint directory exists.
    """
    script_config = run_config.script_config
    script_path = run_config.script_path
    ipl_epochs = run_config.ipl_epochs

    # Resolve once. Joining the absolute directory with the original (possibly relative)
    # path would duplicate its directory components, e.g. conf/x.yaml -> .../conf/conf/x.yaml.
    inference_config = str(Path(run_config.inference_config).absolute())
    inference_config_dir = os.path.dirname(inference_config)
    script_config_abs = Path(script_config).absolute()
    script_config_path = os.path.dirname(script_config_abs)
    script_config_name = os.path.basename(script_config_abs)

    # Create logs directory
    logs_dir = os.path.join(script_config_path, "ipl_logs")
    os.makedirs(logs_dir, exist_ok=True)
    log_file = os.path.join(logs_dir, f"ipl_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

    # Load the config directly
    config = OmegaConf.load(script_config)
    config.exp_manager.resume_if_exists = True

    # Share of the unlabeled set that is re-transcribed after the first full pass.
    p_cache = 0.5

    add_pl_datasets = True
    updated_manifest_filepaths = None
    updated_tarred_audio_filepaths = None
    for epoch in range(ipl_epochs):
        print(f"\nStarting IPL epoch {epoch + 1}/{ipl_epochs}")

        # First run training
        training_command = get_execution_script(
            script_path, script_config_name, script_config_path, updated_manifest_filepaths, updated_tarred_audio_filepaths
        )
        if not run_command(training_command, log_file=log_file):
            print("Training failed, stopping IPL process")
            break

        # Update checkpoint after training
        checkpoint_path, _experiment_dir = find_checkpoint_dir(
            os.path.join(config.exp_manager.exp_dir, config.exp_manager.name)
        )
        if checkpoint_path is None:
            print("No 'checkpoints' directory found, stopping IPL process")
            break
        checkpoint = os.path.join(checkpoint_path, config.exp_manager.name + ".nemo")

        # Then run inference
        cmd, output_dirs, full_pass_done = get_command_for_inference(
            inference_config, inference_config_dir, p_cache, checkpoint
        )
        if not run_command(cmd, log_file=log_file):
            print("Inference failed, stopping IPL process")
            break

        # Create manifests based on whether it's first pass or not
        if not full_pass_done:
            if config.model.train_ds.is_tarred:
                all_manifest_filepaths = create_transcribed_shard_manifests(output_dirs)
            else:
                all_manifest_filepaths = create_transcribed_manifests(output_dirs)
        else:
            if config.model.train_ds.is_tarred:
                all_manifest_filepaths = write_sampled_shard_transcriptions(output_dirs)
            else:
                all_manifest_filepaths = write_sampled_transcriptions(output_dirs)

        # Update training sets if needed (only once: later rounds rewrite the same files)
        if add_pl_datasets:
            base_cfg = OmegaConf.load(inference_config)
            updated_manifest_filepaths, updated_tarred_audio_filepaths = update_training_sets(
                config, all_manifest_filepaths, base_cfg.predict_ds.get("tarred_audio_filepaths", None)
            )
            add_pl_datasets = False

        # Save updated config for next iteration
        config_filepath = os.path.join(script_config_path, "update_script_config.yaml")
        OmegaConf.save(config, config_filepath)

        print(f"Completed IPL epoch {epoch + 1}/{ipl_epochs}")
updated_manifest_filepaths, updated_tarred_audio_filepaths = update_training_sets( + config, all_manifest_filepaths, base_cfg.predict_ds.get("tarred_audio_filepaths", None) + ) + add_pl_datasets = False + + # Save updated config for next iteration + config_filepath = os.path.join(script_config_path, "update_script_config.yaml") + OmegaConf.save(config, config_filepath) + + print(f"Completed IPL epoch {epoch + 1}/{ipl_epochs}") + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/parts/utils/ipl_utils.py b/nemo/collections/asr/parts/utils/ipl_utils.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/asr/parts/utils/run_ipl_utils.py b/nemo/collections/asr/parts/utils/run_ipl_utils.py new file mode 100644 index 000000000000..4a154833c1c7 --- /dev/null +++ b/nemo/collections/asr/parts/utils/run_ipl_utils.py @@ -0,0 +1,307 @@ +import glob +import json +import os +from typing import List, Optional, Tuple, Union + +from omegaconf import OmegaConf + + +def separate_multiple_transcriptions(inference_config: str) -> Tuple[List[str], Optional[List[str]]]: + """ + Separates and returns the manifest and tarred audio file paths from the configuration. + This function makes it easier to run transcribe_speech_parallel for each bucket separately + + Args: + inference_config (str): Path to the inference configuration file. + Returns: + Tuple[List[str], Optional[List[str]]]: A tuple containing: + - A list of manifest file paths. + - An optional list of tarred audio file paths, or None if not applicable. 
def separate_multiple_transcriptions(inference_config: str) -> Tuple[List[str], Optional[List[str]]]:
    """
    Separates and returns the manifest and tarred audio file paths from the configuration.
    This function makes it easier to run transcribe_speech_parallel for each bucket separately

    Each of ``predict_ds.manifest_filepath`` / ``predict_ds.tarred_audio_filepaths`` may be a
    single string, a flat list of strings, or a list of one-element lists (one per bucket);
    all three forms are normalised to a flat list of strings.

    Args:
        inference_config (str): Path to the inference configuration file.
    Returns:
        Tuple[List[str], Optional[List[str]]]: A tuple containing:
            - A list of manifest file paths.
            - An optional list of tarred audio file paths, or None if not applicable.
    Raises:
        ValueError: If a tarred dataset lists a different number of manifests and tar paths.
    """

    def _as_path_list(value) -> List[str]:
        # Indexing a plain string entry with [0] would return its first character.
        if isinstance(value, str):
            return [value]
        return [entry if isinstance(entry, str) else entry[0] for entry in value]

    cfg = OmegaConf.load(inference_config)
    print(f"config {cfg}")
    predict_ds = cfg.predict_ds

    manifests = _as_path_list(predict_ds.manifest_filepath)
    if not predict_ds.get("is_tarred", False):
        return manifests, None

    tarr_audio_files = _as_path_list(predict_ds.tarred_audio_filepaths)
    if len(manifests) != len(tarr_audio_files):
        raise ValueError(
            f"predict_ds lists {len(manifests)} manifest(s) but {len(tarr_audio_files)} tarred audio path(s)"
        )
    return manifests, tarr_audio_files
def create_transcribed_shard_manifests(
    prediction_filepaths: List[str],
) -> List[List[str]]:
    """
    Processes prediction files and generates transcribed shard manifests.

    This function reads prediction JSON files grouped by `shard_id` from
    specified directories, organizes the entries by shard, and writes the
    results to new JSON manifest files. Entries whose ``audio_filepath`` does
    not end in ``.wav`` are not written.

    Args:
        prediction_filepaths (List[str]): A list of filepaths to directories
            containing prediction JSON files (named like `predictions_[0-9]*.json`).

    Returns:
        List[List[str]]: One single-element list per directory holding the brace-style
        manifest pattern ``transcribed_manifest__OP_0..<max_shard_id>_CL_.json``.

    Raises:
        ValueError: If a prediction entry has no ``shard_id``.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        max_shard_id = 0
        shard_data = {}
        for full_path in glob.glob(os.path.join(prediction_filepath, "predictions_[0-9]*.json")):
            # Collect data based on their shard id
            with open(full_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines
                    data_entry = json.loads(line)
                    shard_id = data_entry.get("shard_id")
                    if shard_id is None:
                        raise ValueError(f"Entry without 'shard_id' in {full_path}")
                    max_shard_id = max(max_shard_id, shard_id)
                    shard_data.setdefault(shard_id, []).append(data_entry)

        # Write each shard's data to a new JSON file in the output directory
        for shard_id, entries in shard_data.items():
            output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json")
            # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
            with open(output_filename, 'w', encoding='utf-8') as f:
                for data_entry in entries:
                    if data_entry['audio_filepath'].endswith(".wav"):
                        json.dump(data_entry, f, ensure_ascii=False)
                        f.write("\n")

        shard_manifest_filepath = os.path.join(
            prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json"
        )
        all_manifest_filepaths.append([shard_manifest_filepath])
    return all_manifest_filepaths
def create_transcribed_manifests(
    prediction_filepaths: List[str],
) -> List[str]:
    """
    Renames prediction files to 'transcribed_manifest.json' for each directory
    and returns a list of the new file paths.

    An existing 'transcribed_manifest.json' is overwritten.

    Args:
        prediction_filepaths (List[str]): A list of file paths to directories
            containing the 'predictions_all.json' file.
    Returns:
        List[str]: A list of file paths to the renamed 'transcribed_manifest.json' files.
    Raises:
        FileNotFoundError: If a directory has no 'predictions_all.json'.
    """
    all_manifest_filepaths = []

    for prediction_filepath in prediction_filepaths:
        prediction_name = os.path.join(prediction_filepath, "predictions_all.json")
        transcribed_name = os.path.join(prediction_filepath, "transcribed_manifest.json")

        # os.replace overwrites an existing target on every platform; os.rename fails on Windows.
        os.replace(prediction_name, transcribed_name)
        all_manifest_filepaths.append(transcribed_name)

    return all_manifest_filepaths
def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[List[str]]:
    """
    Updates transcriptions by merging predicted shard data and transcribed manifest data.

    For every ``transcribed_manifest_<n>.json`` in a directory, each entry is replaced by the
    prediction with the same ``shard_id`` and ``audio_filepath`` (from
    ``predictions_[0-9]*.json``) when one exists, and the shard file is rewritten in place.
    Entries whose ``audio_filepath`` does not end in ``.wav`` are dropped.

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            prediction and transcribed manifest files.

    Returns:
        List[List[str]]: A list of lists containing the file paths to the generated
        transcribed shard manifest files.
    """
    all_manifest_filepaths = []

    # Process each prediction directory
    for prediction_filepath in manifest_filepaths:
        predicted_shard_data = {}

        # Collect entries from prediction files based on shard id
        for prediction_path in glob.glob(os.path.join(prediction_filepath, "predictions_[0-9]*.json")):
            with open(prediction_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    data_entry = json.loads(line)
                    shard_id = data_entry.get("shard_id")
                    audio_filepath = data_entry['audio_filepath']
                    predicted_shard_data.setdefault(shard_id, {})[audio_filepath] = data_entry

        max_shard_id = 0
        for full_path in glob.glob(os.path.join(prediction_filepath, "transcribed_manifest_[0-9]*.json")):
            # Read one shard at a time so entries never leak from one shard file into another.
            with open(full_path, 'r', encoding='utf-8') as f:
                shard_entries = [json.loads(line) for line in f if line.strip()]

            # Rewrite the same shard file, keeping new transcriptions where available.
            with open(full_path, 'w', encoding='utf-8') as f:
                for data_entry in shard_entries:
                    # Look up by the entry's own shard id, not the id of the last line read.
                    shard_id = data_entry.get("shard_id")
                    max_shard_id = max(max_shard_id, shard_id)
                    audio_filepath = data_entry['audio_filepath']
                    # Escape duplicated audio files that end with *dup
                    if not audio_filepath.endswith(".wav"):
                        continue
                    merged_entry = predicted_shard_data.get(shard_id, {}).get(audio_filepath, data_entry)
                    json.dump(merged_entry, f, ensure_ascii=False)
                    f.write("\n")

        shard_manifest_filepath = os.path.join(
            prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json"
        )
        all_manifest_filepaths.append([shard_manifest_filepath])

    return all_manifest_filepaths
def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]:
    """
    Updates transcriptions by merging predicted data with transcribed manifest data.

    For each directory, entries of ``transcribed_manifest.json`` are replaced by the entry
    with the same ``audio_filepath`` from ``predictions_all.json`` when one exists, and the
    manifest is rewritten in place. The order of the manifest is preserved.

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            prediction and transcribed manifest files.
    Returns:
        List[str]: A list of file paths to the generated transcribed manifest files.
    """
    all_manifest_filepaths = []

    # Process each prediction directory
    for prediction_filepath in manifest_filepaths:
        # Collect entries from prediction files, keyed by audio path
        predicted_data = {}
        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(prediction_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                data_entry = json.loads(line)
                predicted_data[data_entry['audio_filepath']] = data_entry

        # Read the whole manifest before reopening the same path for writing
        transcribed_manifest_path = os.path.join(prediction_filepath, "transcribed_manifest.json")
        with open(transcribed_manifest_path, 'r', encoding='utf-8') as f:
            all_data_entries = [json.loads(line) for line in f if line.strip()]

        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
        with open(transcribed_manifest_path, 'w', encoding='utf-8') as f:
            for data_entry in all_data_entries:
                merged_entry = predicted_data.get(data_entry['audio_filepath'], data_entry)
                json.dump(merged_entry, f, ensure_ascii=False)
                f.write("\n")

        all_manifest_filepaths.append(transcribed_manifest_path)

    return all_manifest_filepaths
def update_training_sets(
    config: "OmegaConf",
    final_cache_manifests: list,
    updated_tarred_audio_filepaths: Union[list, str],
    prefix: str = None,
) -> Tuple[str, Optional[str]]:
    """
    Updates the training dataset configuration by adding pseudo-labeled datasets
    to the training paths based on the dataset type.

    Neither ``config`` nor the input lists are modified; the combined paths are returned.

    Args:
        config (DictConfig): Training config; ``config.model.train_ds`` is read.
        final_cache_manifests (list): Pseudo-labeled manifest file paths to be included.
        updated_tarred_audio_filepaths (Union[list, str, None]): Tarred audio path(s) of the
            pseudo-labeled data: a single string, a list of strings, or a list of per-bucket lists.
        prefix (str): Unused; kept for interface compatibility.
    Returns:
        Tuple[str, Optional[str]]: A tuple containing:
            - Updated manifest file paths as a string, formatted for Omegaconf.
            - Updated tarred audio file paths as a string, formatted for Omegaconf,
              or None when there are none.
    """

    def _plain(value):
        # Turn nested config lists into plain lists so str() renders them uniformly.
        return value if isinstance(value, str) else [_plain(item) for item in value]

    # Work on copies: appending to the argument would silently change the caller's list.
    updated_manifest_filepaths = list(final_cache_manifests)
    train_ds = config.model.train_ds
    manifest_filepath = train_ds.manifest_filepath

    # Iterating over a bare string would yield single characters, so wrap it first.
    if isinstance(updated_tarred_audio_filepaths, str):
        updated_tarred_audio_filepaths = [updated_tarred_audio_filepaths]
    # One bucket per entry; entries that are already bucket lists are not wrapped again.
    tarred_buckets = [
        [path] if isinstance(path, str) else _plain(path) for path in (updated_tarred_audio_filepaths or [])
    ]

    # Updating the configuration based on dataset types
    if train_ds.get("is_tarred", False):
        tarred_audio_filepaths = train_ds.tarred_audio_filepaths
        if isinstance(tarred_audio_filepaths, str):
            tarred_buckets.append([tarred_audio_filepaths])
            updated_manifest_filepaths.append([manifest_filepath])
        else:
            tarred_buckets += _plain(tarred_audio_filepaths)
            updated_manifest_filepaths += _plain(manifest_filepath)
    elif not isinstance(manifest_filepath, str):
        updated_manifest_filepaths += _plain(manifest_filepath)
    elif train_ds.get("use_lhotse", False):
        # Lhotse takes one list per manifest group.
        updated_manifest_filepaths.append([manifest_filepath])
    else:
        updated_manifest_filepaths.append(manifest_filepath)

    # Returning strings formatted for Omegaconf
    return (
        str(updated_manifest_filepaths).replace(", ", ","),
        str(tarred_buckets).replace(", ", ",") if tarred_buckets else None,
    )
def count_files_for_pseudo_labeling(manifest_filepath: str, is_tarred: bool) -> int:
    """
    Counts the number of files for pseudo-labeling.

    For tarred data, ``manifest_filepath`` is a brace-style pattern such as
    ``<prefix>__OP_0..63_CL_.json`` (or ``<prefix>_{0..63}.json``) and the entries of every
    ``<prefix>_<n>.json`` shard manifest in the same directory are summed.

    Args:
        manifest_filepath (str): The path to the manifest file(s).
        is_tarred (bool): Flag to determine whether to count files for multiple shard manifests.

    Returns:
        int: The total number of audio files given for pseudo labeling (non-blank manifest lines).
    """
    if not is_tarred:
        with open(manifest_filepath, 'r', encoding='utf-8') as f:
            return sum(1 for line in f if line.strip())

    dir_path, filename = os.path.split(manifest_filepath)
    # Take everything before the brace marker as the prefix; cutting at the first underscore
    # would reduce "tarred_audio_manifest" to "tarred" and match no shard file.
    for marker in ("_OP_", "{"):
        if marker in filename:
            prefix = filename.split(marker, 1)[0].rstrip("_")
            break
    else:
        prefix = filename.split('_', 1)[0]

    number_of_files = 0
    for full_path in glob.glob(os.path.join(dir_path, f"{glob.escape(prefix)}_[0-9]*.json")):
        with open(full_path, 'r', encoding='utf-8') as f:
            number_of_files += sum(1 for line in f if line.strip())
    return number_of_files
class IPLEpochStopper(Callback):
    r"""
    Gracefully terminates training at the *end* of an epoch.
    This is done to generate pseudo-labels dynamically.

    enable_stop : bool, default=False
        If ``True`` the callback will request a stop in
        :py:meth:`on_train_epoch_end`. If ``False`` it is inert.
    stop_every_n_epochs : int, default=1
        Number of completed training epochs between stop requests.
        A value ``<= 0`` disables stopping.
    """

    def __init__(self, enable_stop: bool = False, stop_every_n_epochs: int = 1) -> None:
        super().__init__()
        self.enable_stop = bool(enable_stop)
        self.stop_every_n_epochs = int(stop_every_n_epochs)
        # Separate counter so the configured interval is never overwritten.
        self._epochs_since_stop = 0

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Sets `should_stop` stop flag to terminate the training.

        Does nothing unless ``enable_stop`` is set and ``stop_every_n_epochs`` is positive.
        """
        if not self.enable_stop or self.stop_every_n_epochs <= 0:
            return

        self._epochs_since_stop += 1
        if self._epochs_since_stop >= self.stop_every_n_epochs:
            # Reset so the interval applies again if the trainer keeps running.
            self._epochs_since_stop = 0
            trainer.should_stop = True
Optional[IPLEpochStopperParams] = field( + default_factory=lambda: IPLEpochStopperParams() + ) create_preemption_callback: Optional[bool] = True # Additional exp_manager arguments files_to_copy: Optional[List[str]] = None @@ -707,6 +719,9 @@ def exp_manager(trainer: 'lightning.pytorch.Trainer', cfg: Optional[Union[DictCo if cfg.create_early_stopping_callback: early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params) trainer.callbacks.append(early_stop_callback) + if cfg.create_ipl_epoch_stopper_callback: + ipl_epoch_stopper_callback = IPLEpochStopper(**cfg.ipl_epoch_stopper_callback_params) + trainer.callbacks.append(ipl_epoch_stopper_callback) if cfg.create_checkpoint_callback: configure_checkpointing(