diff --git a/biapy/config/config.py b/biapy/config/config.py index 6add747a..1d7e5137 100644 --- a/biapy/config/config.py +++ b/biapy/config/config.py @@ -1412,6 +1412,178 @@ def __init__(self, job_dir: str, job_identifier: str): # _C.MODEL.TORCHVISION_MODEL_NAME = "" + # + # BIAPY BACKEND MODELS + # + # Architecture of the network. Possible values are: + # * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' + # * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' + # * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' + # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet' + # * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' + # * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2' + # * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D) + # * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' + _C.MODEL.ARCHITECTURE = "unet" + # Number of feature maps on each level of the network. + _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] + # Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr' + # a list with just one number must be provided + _C.MODEL.DROPOUT_VALUES = [0.0, 0.0, 0.0, 0.0, 0.0] + # Normalization layer (one of 'bn', 'sync_bn' 'in', 'gn' or 'none'). + _C.MODEL.NORMALIZATION = "bn" + # Kernel size + _C.MODEL.KERNEL_SIZE = 3 + # Upsampling layer to use in the model. Options: ["upsampling", "convtranspose"] + _C.MODEL.UPSAMPLE_LAYER = "convtranspose" + # Activation function to use along the model + _C.MODEL.ACTIVATION = "ELU" + # Number of classes including the background class (that should be using 0 label) + _C.DATA.N_CLASSES = 2 + # Downsampling to be made in Z. This value will be the third integer of the MaxPooling operation. When facing + # anysotropic datasets set it to get better performance + _C.MODEL.Z_DOWN = [0, 0, 0, 0] + # For each level of the model (U-Net levels), set to true or false if the dimensions of the feature maps are isotropic. + _C.MODEL.ISOTROPY = [True, True, True, True, True] + # Include extra convolutional layers with larger kernel at the beginning and end of the U-Net-like model. + _C.MODEL.LARGER_IO = False + # Checkpoint: set to True to load previous training weigths (needed for inference or to make fine-tunning) + _C.MODEL.LOAD_CHECKPOINT = False + # When loading checkpoints whether only model's weights are going to be loaded or optimizer, epochs and loss_scaler. + _C.MODEL.LOAD_CHECKPOINT_ONLY_WEIGHTS = True + # Decide which checkpoint to load from job's dir if PATHS.CHECKPOINT_FILE is ''. + # Options: 'best_on_val' or 'last_on_train' + _C.MODEL.LOAD_CHECKPOINT_EPOCH = "best_on_val" + # Whether to load the model from the checkpoint instead of builiding it following 'MODEL.ARCHITECTURE' when 'MODEL.SOURCE' is "biapy" + _C.MODEL.LOAD_MODEL_FROM_CHECKPOINT = True + # Format of the output checkpoint. Options are 'pth' (native PyTorch format) or 'safetensors' (https://github.com/huggingface/safetensors) + _C.MODEL.OUT_CHECKPOINT_FORMAT = "pth" + # To skip loading those layers that do not match in shape with the given checkpoint. If this is set to False a regular load function will be + # done, which will fail if a layer mismatch is found. Only works when 'MODEL.LOAD_MODEL_FROM_CHECKPOINT' is True + _C.MODEL.SKIP_UNMATCHED_LAYERS = False + # Epochs to save a checkpoint of the model apart from the ones saved with LOAD_CHECKPOINT_ONLY_WEIGHTS. Set it to -1 to + # not do it. + _C.MODEL.SAVE_CKPT_FREQ = -1 + # Number of ConvNeXtBlocks in each level. + _C.MODEL.CONVNEXT_LAYERS = [2, 2, 2, 2, 2] # CONVNEXT_LAYERS + # Maximum Stochastic Depth probability for the U-NeXt model. + _C.MODEL.CONVNEXT_SD_PROB = 0.1 + # Layer Scale parameter for the U-NeXt model. + _C.MODEL.CONVNEXT_LAYER_SCALE = 1e-6 + # Size of the stem kernel in the U-NeXt model. + _C.MODEL.CONVNEXT_STEM_K_SIZE = 2 + + # TRANSFORMERS MODELS + # Type of model. Options are "custom", "vit_base_patch16", "vit_large_patch16" and "vit_huge_patch16". On custom setting + # the rest of the ViT parameters can be modified as other options will set them automatically. + _C.MODEL.VIT_MODEL = "custom" + # Size of the patches that are extracted from the input image. + _C.MODEL.VIT_TOKEN_SIZE = 16 + # Dimension of the embedding space + _C.MODEL.VIT_EMBED_DIM = 768 + # Number of transformer encoder layers + _C.MODEL.VIT_NUM_LAYERS = 12 + # Number of heads in the multi-head attention layer. + _C.MODEL.VIT_NUM_HEADS = 12 + # Size of the dense layers of the final classifier. This value will mutiply 'VIT_EMBED_DIM' + _C.MODEL.VIT_MLP_RATIO = 4.0 + # Normalization layer epsion + _C.MODEL.VIT_NORM_EPS = 1e-6 + + # Dimension of the embedding space for the MAE decoder + _C.MODEL.MAE_DEC_HIDDEN_SIZE = 512 + # Number of transformer decoder layers + _C.MODEL.MAE_DEC_NUM_LAYERS = 8 + # Number of heads in the multi-head attention layer. + _C.MODEL.MAE_DEC_NUM_HEADS = 16 + # Size of the dense layers of the final classifier + _C.MODEL.MAE_DEC_MLP_DIMS = 2048 + # Type of the masking strategy. Options: ["grid", "random"] + _C.MODEL.MAE_MASK_TYPE = "grid" + # Percentage of the input image to mask (applied only when MODEL.MAE_MASK_TYPE == "random"). Value between 0 and 1. + _C.MODEL.MAE_MASK_RATIO = 0.5 + + # UNETR + # Multiple of the transformer encoder layers from of which the skip connection signal is going to be extracted + _C.MODEL.UNETR_VIT_HIDD_MULT = 3 + # Number of filters in the first UNETR's layer of the decoder. In each layer the previous number of filters is doubled. + _C.MODEL.UNETR_VIT_NUM_FILTERS = 16 + # Decoder activation + _C.MODEL.UNETR_DEC_ACTIVATION = "relu" + # Decoder convolutions' kernel size + _C.MODEL.UNETR_DEC_KERNEL_SIZE = 3 + + # Specific for SR models based on U-Net architectures. Options are ["pre", "post"] + _C.MODEL.UNET_SR_UPSAMPLE_POSITION = "pre" + + # RCAN + # Number of RG modules + _C.MODEL.RCAN_RG_BLOCK_NUM = 10 + # Number of RCAB modules in each RG block + _C.MODEL.RCAN_RCAB_BLOCK_NUM = 20 + # Filters in the convolutions + _C.MODEL.RCAN_CONV_FILTERS = 16 + # Channel reduction ratio for channel attention + _C.MODEL.RCAN_REDUCTION_RATIO = 16 + # Whether to maintain or not the upscaling layer. + _C.MODEL.RCAN_UPSCALING_LAYER = True + + # These parameters can be used as a template for building custom HRNet versions + _C.MODEL.HRNET = CN() + # Whether to downsample the input in Z or not + _C.MODEL.HRNET.Z_DOWN = True + # Type of block to use in HRNet. Options: 'BASIC', 'BOTTLENECK', 'CONVNEXT_V1' and 'CONVNEXT_V2' + _C.MODEL.HRNET.BLOCK_TYPE = 'BASIC' + # Indicate whether to use a custom configuration for HRNet or use a predefined one. If set to True + # MODEL.HRNET.STAGE2, MODEL.HRNET.STAGE3 and MODEL.HRNET.STAGE4 will be used. If False, the configuration + # will be set depending on the selected architecture (see PROBLEM.MODEL_ARCHITECTURE) + _C.MODEL.HRNET.HEAD_TYPE = "FCN" # Options: "OCR", "ASPP", "PSP", "FCN" + _C.MODEL.HRNET.CUSTOM = False + + # These stages are used for HRNet18, HRNet32, HRNet48 and HRNet64 + _C.MODEL.HRNET.STAGE2 = CN() + _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1 + _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2 + _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4] + _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [18, 36] + _C.MODEL.HRNET.STAGE3 = CN() + _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4 + _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3 + _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4] + _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [18, 36, 72] + _C.MODEL.HRNET.STAGE4 = CN() + _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3 + _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4 + _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] + _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] + + _C.MODEL.STUNET = CN() + # Variant of the STUNet model. Options are: 'small', 'base', 'large' + _C.MODEL.STUNET.VARIANT = 'base' + # Whether to use a pretrained version of STUNet on ImageNet + _C.MODEL.STUNET.PRETRAINED = False + + # NafNet + _C.MODEL.NAFNET = CN() + # Base number of channels (width) used in the first layer and base levels. + _C.MODEL.NAFNET.WIDTH = 16 + # Number of NAFBlocks stacked at the bottleneck (deepest level). + _C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12 + # Number of NAFBlocks assigned to each downsampling level of the encoder. + _C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8] + # Number of NAFBlocks assigned to each upsampling level of the decoder. + _C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2] + # Channel expansion factor for the depthwise convolution within the gating unit. + _C.MODEL.NAFNET.DW_EXPAND = 2 + # Expansion factor for the hidden layer within the feed-forward network. + _C.MODEL.NAFNET.FFN_EXPAND = 2 + # Discriminator architecture + _C.MODEL.NAFNET.ARCHITECTURE_D = "patchgan" + # Discriminator PATCHGAN + _C.MODEL.NAFNET.PATCHGAN = CN() + # Number of initial convolutional filters in the first layer of the discriminator. + _C.MODEL.NAFNET.PATCHGAN.BASE_FILTERS = 64 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 6. Loss definition options # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1474,6 +1646,20 @@ def __init__(self, job_dir: str, job_identifier: str): _C.LOSS.CONTRAST.PROJ_DIM = 256 _C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10 + # Fine-grained GAN composition. Set any weight to 0.0 to disable that term. + # Used when LOSS.TYPE == "CYCLEGAN". + _C.LOSS.CYCLEGAN = CN() + # Weight for adversarial BCE term. + _C.LOSS.CYCLEGAN.LAMBDA_GAN = 1.0 + # Weight for L1 reconstruction term. + _C.LOSS.CYCLEGAN.LAMBDA_RECON = 10.0 + # Weight for MSE reconstruction term. + _C.LOSS.CYCLEGAN.DELTA_MSE = 0.0 + # Weight for VGG perceptual term. + _C.LOSS.CYCLEGAN.ALPHA_PERCEPTUAL = 0.0 + # Weight for SSIM term. + _C.LOSS.CYCLEGAN.GAMMA_SSIM = 1.0 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 7. Training phase options # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1481,19 +1667,16 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.ENABLE = False # Enable verbosity _C.TRAIN.VERBOSE = False - # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" - _C.TRAIN.OPTIMIZER = "SGD" - # Learning rate - _C.TRAIN.LR = 1.0e-4 + # Optimizer(s) to use. Possible values: "SGD", "ADAM" or "ADAMW". + _C.TRAIN.OPTIMIZER = ["SGD"] + # Learning rate(s). + _C.TRAIN.LR = [1.0e-4] # Weight decay _C.TRAIN.W_DECAY = 0.02 # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers - _C.TRAIN.OPT_BETAS = (0.9, 0.999) + _C.TRAIN.OPT_BETAS = [[0.9, 0.999]] # Batch size _C.TRAIN.BATCH_SIZE = 2 - # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is - # batch_size (per gpu) * nodes * (gpus per node) * accum_iter. - _C.TRAIN.ACCUM_ITER = 1 # Number of epochs to train the model _C.TRAIN.EPOCHS = 360 # Epochs to wait with no validation data improvement until the training is stopped @@ -1509,6 +1692,9 @@ def __init__(self, job_dir: str, job_identifier: str): # * Classification: 'accuracy', 'top-5-accuracy' # * Image to image: "psnr", "mae", "mse", "ssim" _C.TRAIN.METRICS = [] + + # Gradient clipping max norm applied per optimizer. 0 = disabled. + _C.TRAIN.GRADIENT_CLIP_NORM = 0.0 # Callbacks # To determine which value monitor to consider which epoch consider the best to save. Currently not used. @@ -1526,7 +1712,7 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.LR_SCHEDULER = CN() _C.TRAIN.LR_SCHEDULER.NAME = "" # Possible options: 'warmupcosine', 'reduceonplateau', 'onecycle' # Lower bound on the learning rate used in 'warmupcosine' and 'reduceonplateau' - _C.TRAIN.LR_SCHEDULER.MIN_LR = -1.0 + _C.TRAIN.LR_SCHEDULER.MIN_LR = [-1.0] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 7.1.1 Reduce on plateau options diff --git a/biapy/data/generators/__init__.py b/biapy/data/generators/__init__.py index 0a352233..8b0cdfa6 100644 --- a/biapy/data/generators/__init__.py +++ b/biapy/data/generators/__init__.py @@ -251,7 +251,7 @@ def create_train_val_augmentors( dic["zflip"] = cfg.AUGMENTOR.ZFLIP if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -297,7 +297,7 @@ def create_train_val_augmentors( ) if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -317,7 +317,7 @@ def create_train_val_augmentors( ) # Training dataset - total_batch_size = cfg.TRAIN.BATCH_SIZE * get_world_size() * cfg.TRAIN.ACCUM_ITER + total_batch_size = cfg.TRAIN.BATCH_SIZE * get_world_size() training_samples = len(train_generator) # ---- Choose num_workers for this DataLoader ---- @@ -352,7 +352,6 @@ def worker_init_fn(worker_id): num_training_steps_per_epoch = training_samples // total_batch_size print(f"Train/val generators with {num_workers} workers") - print("Accumulate grad iterations: %d" % cfg.TRAIN.ACCUM_ITER) print("Effective batch size: %d" % total_batch_size) print("Sampler_train = %s" % str(sampler_train)) train_dataset = DataLoader( diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 8991943f..21ef97c6 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -21,7 +21,7 @@ def prepare_optimizer( cfg: CN, model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, steps_per_epoch: int, -) -> Tuple[Optimizer, Scheduler | None]: +) -> Tuple[list[Optimizer], list[Scheduler | None]]: """ Create and configure the optimizer and learning rate scheduler for the given model. @@ -40,50 +40,63 @@ def prepare_optimizer( Returns ------- - optimizer : Optimizer - Configured optimizer for the model. - lr_scheduler : Scheduler or None - Configured learning rate scheduler, or None if not specified. + optimizers : List[Optimizer] + Configured optimizers for the models. + lr_schedulers : List[Scheduler | None] + Configured learning rate schedulers, or None if not specified. """ - lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR - opt_args = {} - if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]: - opt_args["betas"] = cfg.TRAIN.OPT_BETAS - optimizer = timm.optim.create_optimizer_v2( - model_without_ddp, - opt=cfg.TRAIN.OPTIMIZER, - lr=lr, - weight_decay=cfg.TRAIN.W_DECAY, - **opt_args, - ) - print(optimizer) - - # Learning rate schedulers - lr_scheduler = None - if cfg.TRAIN.LR_SCHEDULER.NAME != "": - if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": - lr_scheduler = ReduceLROnPlateau( - optimizer, - patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, - factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": - lr_scheduler = WarmUpCosineDecayScheduler( - lr=cfg.TRAIN.LR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, - epochs=cfg.TRAIN.EPOCHS, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler = OneCycleLR( - optimizer, - cfg.TRAIN.LR, - epochs=cfg.TRAIN.EPOCHS, - steps_per_epoch=steps_per_epoch, - ) - - return optimizer, lr_scheduler + + optimizers = [] + lr_schedulers = [] + + if hasattr(model_without_ddp, 'param_groups'): + param_groups = model_without_ddp.param_groups + else: + param_groups = [[p for p in model_without_ddp.parameters()]] + + for i in range(len(cfg.TRAIN.OPTIMIZER)): + lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR + opt_args = {} + if cfg.TRAIN.OPTIMIZER[i] in ["ADAM", "ADAMW"]: + opt_args["betas"] = cfg.TRAIN.OPT_BETAS[i] + optimizer = timm.optim.create_optimizer_v2( + param_groups[i], + opt=cfg.TRAIN.OPTIMIZER[i], + lr=lr[i], + weight_decay=cfg.TRAIN.W_DECAY, + **opt_args, + ) + print(optimizer) + optimizers.append(optimizer) + + # Learning rate schedulers + lr_scheduler = None + if cfg.TRAIN.LR_SCHEDULER.NAME != "": + if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + lr_scheduler = ReduceLROnPlateau( + optimizer, + patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, + factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR[i], + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": + lr_scheduler = WarmUpCosineDecayScheduler( + lr=cfg.TRAIN.LR[i], + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR[i], + warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, + epochs=cfg.TRAIN.EPOCHS, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler = OneCycleLR( + optimizer, + cfg.TRAIN.LR[i], + epochs=cfg.TRAIN.EPOCHS, + steps_per_epoch=steps_per_epoch, + ) + + lr_schedulers.append(lr_scheduler) + + return optimizers, lr_schedulers def build_callbacks(cfg: CN) -> EarlyStopping | None: diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 43ec47c7..dd690188 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -231,6 +231,7 @@ def __init__( self.gt_channels_expected = -1 self.train_metrics_message = "" self.test_metrics_message = "" + self.loss_names = ["loss"] self.resolution: List[int | float] = list(self.cfg.DATA.TEST.RESOLUTION) if self.cfg.PROBLEM.NDIM == "2D": @@ -893,8 +894,9 @@ def prepare_logging_tool(self): self.log_writer = None self.plot_values = {} - self.plot_values["loss"] = [] - self.plot_values["val_loss"] = [] + for loss_name in self.loss_names: + self.plot_values[loss_name] = [] + self.plot_values[f"val_{loss_name}"] = [] for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]] = [] self.plot_values["val_" + self.train_metric_names[i]] = [] @@ -969,6 +971,7 @@ def train(self): memory_bank=self.memory_bank, total_iters=total_iters, contrast_warmup_iters=contrast_init_iter, + loss_names=self.loss_names, ) total_iters += iterations_done @@ -1003,6 +1006,7 @@ def train(self): data_loader=self.val_generator, lr_scheduler=self.lr_scheduler, memory_bank=self.memory_bank, + loss_names=self.loss_names, ) # Save checkpoint is val loss improved @@ -1065,9 +1069,10 @@ def train(self): f.write(json.dumps(log_stats) + "\n") # Create training plot - self.plot_values["loss"].append(train_stats["loss"]) - if self.val_generator: - self.plot_values["val_loss"].append(test_stats["loss"]) + for loss_name in self.loss_names: + self.plot_values[loss_name].append(train_stats[loss_name]) + if self.val_generator: + self.plot_values[f"val_{loss_name}"].append(test_stats[loss_name]) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: @@ -1078,6 +1083,7 @@ def train(self): create_plots( self.plot_values, self.train_metric_names, + self.loss_names, self.job_identifier, self.cfg.PATHS.CHARTS, ) diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index ee1fe576..f69b653e 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -1242,7 +1242,7 @@ def sort_key(item): assert sum(cfg.LOSS.WEIGHTS) == 1, "'LOSS.WEIGHTS' values need to sum 1" elif cfg.PROBLEM.TYPE == "DENOISING": loss = "MSE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE - assert loss == "MSE", "LOSS.TYPE must be 'MSE'" + assert loss in ["MSE", "CYCLEGAN"], "LOSS.TYPE must be in ['MSE', 'CYCLEGAN'] for DENOISING" elif cfg.PROBLEM.TYPE == "CLASSIFICATION": loss = "CE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE assert loss == "CE", "LOSS.TYPE must be 'CE'" @@ -1803,12 +1803,19 @@ def sort_key(item): #### Denoising #### elif cfg.PROBLEM.TYPE == "DENOISING": - if cfg.DATA.TEST.LOAD_GT: - raise ValueError( - "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" - ) - if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): - raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") + if cfg.PROBLEM.DENOISING.LOAD_GT_DATA or cfg.LOSS.TYPE == "CYCLEGAN": + if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA: + raise ValueError( + "Supervised denoising (e.g., with CYCLEGAN or LOAD_GT_DATA=True) " + "requires ground truth. 'DATA.TRAIN.GT_PATH' must be provided." + ) + else: + if cfg.DATA.TEST.LOAD_GT: + raise ValueError( + "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" + ) + if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): + raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") if cfg.MODEL.SOURCE == "torchvision": raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in denoising workflow") @@ -2348,6 +2355,7 @@ def sort_key(item): "unext_v2", "hrnet", "stunet", + "nafnet", ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2', 'hrnet', 'stunet']" if ( model_arch @@ -2566,6 +2574,7 @@ def sort_key(item): "unext_v2", "hrnet", "stunet", + "nafnet", ]: raise ValueError( "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2', 'hrnet', 'stunet']".format( @@ -2771,11 +2780,53 @@ def sort_key(item): assert cfg.MODEL.OUT_CHECKPOINT_FORMAT in ["pth", "safetensors"], "MODEL.OUT_CHECKPOINT_FORMAT not in ['pth', 'safetensors']" ### Train ### - assert cfg.TRAIN.OPTIMIZER in [ - "SGD", - "ADAM", - "ADAMW", - ], "TRAIN.OPTIMIZER not in ['SGD', 'ADAM', 'ADAMW']" + ## Optimizers ## + if not isinstance(cfg.TRAIN.OPTIMIZER, list): + raise ValueError("'TRAIN.OPTIMIZER' must be a list") + if cfg.MODEL.ARCHITECTURE in ['nafnet'] and cfg.MODEL.NAFNET.ARCHITECTURE_D != "": + if len(cfg.TRAIN.OPTIMIZER) != 2: + raise ValueError( + f"Configuration mismatch: You requested {len(cfg.TRAIN.OPTIMIZER)} optimizers, " + f"but the model has 2 parameter group(s). " + f"Check your TRAIN.OPTIMIZER list in the config." + ) + elif len(cfg.TRAIN.OPTIMIZER) > 1: + raise ValueError( + "Multiple optimizers were provided but no discriminator architecture is configured. " + "Either set a discriminator (e.g. 'MODEL.NAFNET.ARCHITECTURE_D') or reduce 'TRAIN.OPTIMIZER' to a single entry." + ) + for opt in cfg.TRAIN.OPTIMIZER: + if opt not in ["SGD", "ADAM", "ADAMW"]: + raise ValueError("'TRAIN.OPTIMIZER' values must be in ['SGD', 'ADAM', 'ADAMW']") + + ## LR ## + if not isinstance(cfg.TRAIN.LR, list): + raise ValueError("'TRAIN.LR' must be a list") + if len(cfg.TRAIN.OPTIMIZER) != len(cfg.TRAIN.LR): + raise ValueError("'TRAIN.OPTIMIZER' and 'TRAIN.LR' must have the same length") + + ## Betas ## + if not isinstance(cfg.TRAIN.OPT_BETAS, list): + raise ValueError("'TRAIN.OPT_BETAS' must be a list") + for idx, beta_pair in enumerate(cfg.TRAIN.OPT_BETAS): + if isinstance(beta_pair, str): + raise ValueError( + f"Config Error in 'TRAIN.OPT_BETAS': Found a string '{beta_pair}'. " + f"You must use nested square brackets `[]`. " + f"Change it to: [[0.9, 0.999]]" + ) + if not isinstance(beta_pair, list): + raise ValueError( + f"Config Error: Each item in 'TRAIN.OPT_BETAS' must be a list. " + f"Got {type(beta_pair).__name__} at index {idx}." + ) + if len(cfg.TRAIN.OPT_BETAS) not in [1, len(cfg.TRAIN.OPTIMIZER)]: + raise ValueError("'TRAIN.OPT_BETAS' must have length 1 or match 'TRAIN.OPTIMIZER' length") + if len(cfg.TRAIN.OPT_BETAS) == 1 and len(cfg.TRAIN.OPTIMIZER) > 1: + cfg.TRAIN.OPT_BETAS = cfg.TRAIN.OPT_BETAS * len(cfg.TRAIN.OPTIMIZER) + for beta_pair in cfg.TRAIN.OPT_BETAS: + if len(beta_pair) != 2: + raise ValueError("Each entry in 'TRAIN.OPT_BETAS' must be a tuple/list of length 2") if cfg.TRAIN.ENABLE and cfg.TRAIN.LR_SCHEDULER.NAME != "": if cfg.TRAIN.LR_SCHEDULER.NAME not in [ @@ -2784,7 +2835,21 @@ def sort_key(item): "onecycle", ]: raise ValueError("'TRAIN.LR_SCHEDULER.NAME' must be in ['reduceonplateau', 'warmupcosine', 'onecycle']") - if cfg.TRAIN.LR_SCHEDULER.MIN_LR == -1.0 and cfg.TRAIN.LR_SCHEDULER.NAME != "onecycle": + if cfg.TRAIN.LR_SCHEDULER.NAME != "onecycle": + if not isinstance(cfg.TRAIN.LR_SCHEDULER.MIN_LR, list): + raise ValueError("'TRAIN.LR_SCHEDULER.MIN_LR' must be a list") + if len(cfg.TRAIN.LR_SCHEDULER.MIN_LR) not in [1, len(cfg.TRAIN.OPTIMIZER)]: + raise ValueError("'TRAIN.LR_SCHEDULER.MIN_LR' must have length 1 or match 'TRAIN.OPTIMIZER' length") + if len(cfg.TRAIN.LR_SCHEDULER.MIN_LR) == 1 and len(cfg.TRAIN.OPTIMIZER) > 1: + opts.extend(["TRAIN.LR_SCHEDULER.MIN_LR", cfg.TRAIN.LR_SCHEDULER.MIN_LR * len(cfg.TRAIN.OPTIMIZER)]) + if all(x == -1.0 for x in cfg.TRAIN.LR_SCHEDULER.MIN_LR): + raise ValueError( + "'TRAIN.LR_SCHEDULER.MIN_LR' needs to be set when 'TRAIN.LR_SCHEDULER.NAME' is between ['reduceonplateau', 'warmupcosine']" + ) + elif len(cfg.TRAIN.LR_SCHEDULER.MIN_LR) > 1 and len(cfg.TRAIN.LR_SCHEDULER.MIN_LR) != len(cfg.TRAIN.OPTIMIZER): + raise ValueError("'TRAIN.LR_SCHEDULER.MIN_LR' must have length 1 or match 'TRAIN.OPTIMIZER' length") + + if cfg.TRAIN.LR_SCHEDULER.NAME != "onecycle" and all(x == -1.0 for x in cfg.TRAIN.LR_SCHEDULER.MIN_LR): raise ValueError( "'TRAIN.LR_SCHEDULER.MIN_LR' needs to be set when 'TRAIN.LR_SCHEDULER.NAME' is between ['reduceonplateau', 'warmupcosine']" ) @@ -2807,6 +2872,12 @@ def sort_key(item): if cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS > cfg.TRAIN.EPOCHS: raise ValueError("'TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS' needs to be less than 'TRAIN.EPOCHS'") + # Gradient clipping validation + if not isinstance(cfg.TRAIN.GRADIENT_CLIP_NORM, (int, float)): + raise ValueError("'TRAIN.GRADIENT_CLIP_NORM' must be a number") + if cfg.TRAIN.GRADIENT_CLIP_NORM < 0: + raise ValueError("'TRAIN.GRADIENT_CLIP_NORM' must be non-negative (0 to disable)") + #### Augmentation #### if cfg.AUGMENTOR.ENABLE: if not check_value(cfg.AUGMENTOR.DA_PROB): @@ -3000,6 +3071,87 @@ def _assert_list_of_pos_ints(x, ctx): for i, v in enumerate(x): assert isinstance(v, int) and v > 0, f"'{ctx}[{i}]' must be a positive integer" +def compare_configurations_without_model(actual_cfg, old_cfg, header_message="", old_cfg_version=None): + """ + Compare two BiaPy configurations and raise an error if critical workflow variables differ. + + This function checks that key configuration variables (such as problem type, patch size, + number of classes, and data channels) match between the current and previous configuration. + It ignores model-specific parameters and allows for some backward compatibility. + + Parameters + ---------- + actual_cfg : yacs.config.CfgNode + The current configuration object. + old_cfg : yacs.config.CfgNode or dict + The previous configuration object to compare against. + header_message : str, optional + Message to prepend to any error or warning (default: ""). + old_cfg_version : str or None, optional + Version string of the old configuration, for backward compatibility (default: None). + + Raises + ------ + ValueError + If a critical configuration variable does not match and cannot be ignored. + """ + print("Comparing configurations . . .") + + vars_to_compare = [ + "PROBLEM.TYPE", + "PROBLEM.NDIM", + "DATA.PATCH_SIZE", + "PROBLEM.INSTANCE_SEG.DATA_CHANNELS", + "PROBLEM.SUPER_RESOLUTION.UPSCALING", + "DATA.N_CLASSES", + "TRAIN.OPTIMIZER", # yeah not so sure how many + ] + + def get_attribute_recursive(var, attr): + att = attr.split(".") + if len(att) == 1: + return getattr(var, att[0]) + else: + return get_attribute_recursive(getattr(var, att[0]), ".".join(att[1:])) + + # Old configuration translation + dim_count = 2 if old_cfg.PROBLEM.NDIM == "2D" else 3 + # BiaPy version less than 3.5.5 + if old_cfg_version is None: + if isinstance(old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"], int): + old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"] = ( + old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"], + ) * dim_count + + for var_to_compare in vars_to_compare: + current_value = get_attribute_recursive(actual_cfg, var_to_compare) + old_value = get_attribute_recursive(old_cfg, var_to_compare) + if current_value != old_value: + error_message, warning_message = "", "" + if var_to_compare == "DATA.N_CLASSES": + if not actual_cfg.MODEL.SKIP_UNMATCHED_LAYERS: + error_message = header_message \ + + f"The '{var_to_compare}' value of the compared configurations does not match: " \ + + f"{current_value} (current configuration) vs {old_value} (from loaded configuration). " \ + + "If you want to load all weights from the checkpoint that match in shape with your model " \ + + "(e.g., to fine-tune the head), set 'MODEL.SKIP_UNMATCHED_LAYERS' to True." + # Allow SSL pretrainings + elif not (var_to_compare == "PROBLEM.TYPE" and old_value == "SELF_SUPERVISED"): + error_message = header_message \ + + f"The '{var_to_compare}' value of the compared configurations does not match: " \ + + f"{current_value} (current configuration) vs {old_value} (from loaded configuration)" + elif var_to_compare == "DATA.PATCH_SIZE" and any([new for new, old in zip(current_value,old_value) if new < old]): + warning_message = \ + f"WARNING: The 'DATA.PATCH_SIZE' value used for training the model that you are trying to load was {old_value}." \ + + f"It seems that one of the values in your 'DATA.PATCH_SIZE', which is {current_value}, is smaller so may be causing " \ + + "an error during model building process" + + if error_message != "": + raise ValueError( error_message ) + if warning_message != "": + print( warning_message ) + + print("Configurations seem to be compatible. Continuing . . .") def convert_old_model_cfg_to_current_version(old_cfg: dict) -> dict: """ @@ -3018,6 +3170,20 @@ def convert_old_model_cfg_to_current_version(old_cfg: dict) -> dict: new_cfg : dict Updated configuration to the current BiaPy version. """ + if "TRAIN" in old_cfg: + if "OPTIMIZER" in old_cfg["TRAIN"] and isinstance(old_cfg["TRAIN"]["OPTIMIZER"], str): + old_cfg["TRAIN"]["OPTIMIZER"] = [old_cfg["TRAIN"]["OPTIMIZER"]] + if "LR" in old_cfg["TRAIN"] and isinstance(old_cfg["TRAIN"]["LR"], float): + old_cfg["TRAIN"]["LR"] = [old_cfg["TRAIN"]["LR"]] + if "OPT_BETAS" in old_cfg["TRAIN"] and isinstance(old_cfg["TRAIN"]["OPT_BETAS"], str): + clean_str = old_cfg["TRAIN"]["OPT_BETAS"].strip().strip("()") + number_list = [float(x.strip()) for x in clean_str.split(",")] + old_cfg["TRAIN"]["OPT_BETAS"] = [number_list] + if "ACCUM_ITER" in old_cfg["TRAIN"]: + del old_cfg["TRAIN"]["ACCUM_ITER"] + if "LR_SCHEDULER" in old_cfg["TRAIN"]: + if "MIN_LR" in old_cfg["TRAIN"]["LR_SCHEDULER"] and isinstance(old_cfg["TRAIN"]["LR_SCHEDULER"]["MIN_LR"], float): + old_cfg["TRAIN"]["LR_SCHEDULER"]["MIN_LR"] = [old_cfg["TRAIN"]["LR_SCHEDULER"]["MIN_LR"]] * len(old_cfg["TRAIN"]["OPTIMIZER"]) workflow = old_cfg.get("PROBLEM", {}).get("TYPE", "SEMANTIC_SEG") if "TEST" in old_cfg: if "STATS" in old_cfg["TEST"]: diff --git a/biapy/engine/denoising.py b/biapy/engine/denoising.py index fc328bc1..71bde5a0 100644 --- a/biapy/engine/denoising.py +++ b/biapy/engine/denoising.py @@ -26,7 +26,7 @@ from biapy.engine.base_workflow import Base_Workflow from biapy.data.data_manipulation import save_tif from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger -from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation +from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, CycleGanLoss from biapy.data.norm import undo_image_norm class Denoising_Workflow(Base_Workflow): @@ -176,9 +176,40 @@ def define_metrics(self): # print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)") if self.cfg.LOSS.TYPE == "MSE": self.loss = loss_encapsulation(n2v_loss_mse) + elif self.cfg.LOSS.TYPE == "CYCLEGAN": + self.cyclegan_loss = CycleGanLoss(cfg=self.cfg, device=self.device) + self.loss = self.NAFNetGan_loss_wrapper + if "loss_discriminator" not in self.loss_names: + self.loss_names.append("loss_discriminator") super().define_metrics() + def NAFNetGan_loss_wrapper(self, output, targets): + """Extract pre-computed GAN losses from NAFNet. + + The model computes losses internally via :meth:`NAFNet.forward_loss`, + and this wrapper retrieves them so the training engine never + sees the discriminator. + + Parameters + ---------- + output : torch.Tensor or dict + Model predictions (dict with ``"pred"`` key). + targets : torch.Tensor + Ground-truth images. + + Returns + ------- + tuple + ``(loss_generator, loss_discriminator)``. + """ + if isinstance(output, dict): + pred = output["pred"] + else: + pred = output + loss_g, loss_d = self.model_without_ddp.forward_loss(pred, targets, self.cyclegan_loss) + return {"losses": [loss_g, loss_d]} + def metric_calculation( self, output: NDArray | torch.Tensor, @@ -242,7 +273,13 @@ def metric_calculation( with torch.no_grad(): for i, metric in enumerate(list_to_use): - val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous()) + # Nafnet for Gan With Supervised + if _targets.shape[1] == _output.shape[1]: + target_for_metric = _targets.contiguous() + # Normal N2Void + else: + target_for_metric = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous()) + val = metric(_output.contiguous(), target_for_metric) val = val.item() if not torch.isnan(val) else 0 out_metrics[list_names_to_use[i]] = val diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 3aa16ada..c76d9e42 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -17,6 +17,8 @@ from pytorch_msssim import SSIM import torch.nn.functional as F import torch.nn as nn +from torchvision import transforms +from torchvision.models import vgg16, VGG16_Weights from typing import Optional, List, Tuple, Dict, Union def jaccard_index_numpy(y_true, y_pred): @@ -446,7 +448,7 @@ def forward(self, inputs, targets): """ if isinstance(inputs, dict): inputs = inputs["pred"] - return self.loss(inputs, targets) + return {"losses": [self.loss(inputs, targets)]} class CrossEntropyLoss_wrapper: """ @@ -2454,4 +2456,228 @@ def forward( loss = loss / B iou = iou / B - return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals \ No newline at end of file + + return { + "losses": [loss + prediction.sum() * 0], + "metrics": {"IoU": float(iou)} + } + +class VGG(nn.Module): + """Perceptual loss based on VGG16 feature activations. + + This loss compares intermediate VGG feature maps of prediction and target + images using an L1 distance. It is commonly used as a perceptual term in + image-to-image GAN training. + + Notes + ----- + - Uses pretrained ``torchvision.models.vgg16`` features up to layer ``:16``. + - Supports both 2D `(B, C, H, W)` and 3D `(B, C, D, H, W)` tensors. + For 3D inputs, depth is folded into batch to reuse 2D VGG. + - Single-channel inputs are replicated to 3 channels before VGG. + + References + ---------- + - Johnson et al., "Perceptual Losses for Real-Time Style Transfer and + Super-Resolution", ECCV 2016. + https://arxiv.org/abs/1603.08155 + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + """ + + def __init__(self, device): + """Initialize VGG perceptual loss. + + Parameters + ---------- + device : torch.device + Device where VGG features and loss operations are executed. + """ + super().__init__() + self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval().to(device) + for param in self.vgg.parameters(): + param.requires_grad = False + self.loss = nn.L1Loss() + self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, pred, target): + """Compute perceptual distance between prediction and target. + + Parameters + ---------- + pred : torch.Tensor or dict + Predicted image tensor. If dict, prediction is taken from ``pred['pred']``. + target : torch.Tensor or dict + Target image tensor. If dict, target is taken from ``target['pred']``. + + Returns + ------- + torch.Tensor + Scalar perceptual loss value (L1 over VGG features). + """ + if isinstance(pred, dict): + pred = pred["pred"] + if isinstance(target, dict): + target = target["pred"] + + # If 3D, fold Depth (dim 2) into Batch (dim 0) -> (B*D, C, H, W) + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + + # 2D behavior remains identical + if pred.shape[1] == 1: + pred = pred.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + pred = self.preprocess(pred) + target = self.preprocess(target) + pred_vgg = self.vgg(pred) + target_vgg = self.vgg(target) + return self.loss(pred_vgg, target_vgg) + +class CycleGanLoss(nn.Module): + """Weighted composite loss for generator and discriminator training. + + This class combines multiple objectives for GAN-based image restoration: + + - Adversarial BCE term + - L1 reconstruction term + - MSE reconstruction term + - VGG perceptual term + - SSIM term + + Each term is controlled by configuration weights under + ``LOSS.CYCLEGAN``. Heavy components (VGG/SSIM modules) are created only + when their weight is greater than zero. + + References + ---------- + - Isola et al., "Image-to-Image Translation with Conditional Adversarial + Networks", CVPR 2017 (pix2pix). + https://arxiv.org/abs/1611.07004 + - Generator family inspiration (NAFNet/NAFSSR): + Chu et al., "NAFSSR: Stereo Image Super-Resolution Using NAFNet", + CVPR Workshops 2022. + https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/html/Chu_NAFSSR_Stereo_Image_Super-Resolution_Using_NAFNet_CVPRW_2022_paper.html + - Structural/perceptual metrics are implemented with torchmetrics + (e.g., ``torchmetrics.image.StructuralSimilarityIndexMeasure``). + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + """ + + def __init__(self, cfg, device): + """Initialize composed GAN loss from configuration. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Global configuration node. Uses ``cfg.LOSS.CYCLEGAN`` weights. + device : torch.device + Device where loss terms are computed. + """ + super().__init__() + self.device = device + self.w_gan = cfg.LOSS.CYCLEGAN.LAMBDA_GAN + self.w_l1 = cfg.LOSS.CYCLEGAN.LAMBDA_RECON + self.w_vgg = cfg.LOSS.CYCLEGAN.ALPHA_PERCEPTUAL + self.w_ssim = cfg.LOSS.CYCLEGAN.GAMMA_SSIM + self.w_mse = cfg.LOSS.CYCLEGAN.DELTA_MSE + + # Dont load the vgg if not + if self.w_vgg > 0: + self.vgg = VGG(device) + if self.w_ssim > 0: + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) + + # Standard lightweight losses are always initialized + self.l1 = nn.L1Loss() + self.mse = nn.MSELoss() + self.bce = nn.BCEWithLogitsLoss() + + def forward_generator(self, pred, target, d_fake): + """Compute weighted generator loss. + + Parameters + ---------- + pred : torch.Tensor or dict + Generator prediction. If dict, reads ``pred['pred']``. + target : torch.Tensor or dict + Ground-truth target. If dict, reads ``target['pred']``. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar generator loss as weighted sum of active terms. + """ + # Dict extraction + if isinstance(pred, dict): pred = pred["pred"] + if isinstance(target, dict): target = target["pred"] + + # NaN Band-aid + pred = torch.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=-1.0) + target = torch.nan_to_num(target, nan=0.0, posinf=1.0, neginf=-1.0) + + total_loss = torch.tensor(0.0, device=self.device) + + # 2. Dynamically build the loss based on config weights + if self.w_l1 > 0: + total_loss += self.w_l1 * self.l1(pred, target) + + if self.w_mse > 0: + total_loss += self.w_mse * self.mse(pred, target) + + if self.w_vgg > 0: + total_loss += self.w_vgg * self.vgg(pred, target) + + if self.w_ssim > 0: + # SSIM requires 4D tensors. Safely route 3D to 2D slices. + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred_ssim = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target_ssim = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + total_loss += self.w_ssim * (1.0 - self.ssim(pred_ssim, target_ssim)) + else: + total_loss += self.w_ssim * (1.0 - self.ssim(pred, target)) + + if self.w_gan > 0: + total_loss += self.w_gan * self.bce(d_fake, torch.ones_like(d_fake)) + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in generator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss + + def forward_discriminator(self, d_real, d_fake): + """Compute discriminator adversarial loss. + + Uses BCE with one-sided label smoothing for real logits. + + Parameters + ---------- + d_real : torch.Tensor + Discriminator logits for real samples. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar discriminator loss. + """ + # Calculate Adversarial Loss for Discriminator + real_loss = self.bce(d_real, torch.full_like(d_real, 0.9)) # Label smoothing (0.9 instead of 1.0) + fake_loss = self.bce(d_fake, torch.zeros_like(d_fake)) + total_loss = (real_loss + fake_loss) / 2.0 + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in discriminator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss \ No newline at end of file diff --git a/biapy/engine/self_supervised.py b/biapy/engine/self_supervised.py index b9162049..163d71d2 100644 --- a/biapy/engine/self_supervised.py +++ b/biapy/engine/self_supervised.py @@ -262,7 +262,7 @@ def define_metrics(self): def MaskedAutoencoderViT_loss_wrapper(self, output, targets): """Unravel MAE loss.""" # Targets not used because the loss has been already calculated - return output["loss"] + return {"losses": [output["loss"]]} def metric_calculation( self, diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index e4b05e3b..7f34c4fb 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -13,7 +13,7 @@ from typing import Callable, Optional from torch.utils.data import DataLoader from yacs.config import CfgNode as CN - +from torch.nn.utils import clip_grad_norm_ from biapy.utils.misc import MetricLogger, SmoothedValue, TensorboardLogger, all_reduce_mean from biapy.engine import Scheduler from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR @@ -29,15 +29,16 @@ def train_one_epoch( metric_function: Callable, prepare_targets: Callable, data_loader: DataLoader, - optimizer: Optimizer, + optimizer: list[Optimizer], device: torch.device, epoch: int, log_writer: Optional[TensorboardLogger] = None, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: list[Optional[Scheduler]] = None, verbose: bool = False, memory_bank: Optional[MemoryBank] = None, total_iters: int=0, contrast_warmup_iters: int=0, + loss_names: list[str] = None, ): """ Train the model for one epoch. @@ -61,7 +62,7 @@ def train_one_epoch( Function to prepare targets for loss/metrics. data_loader : DataLoader Training data loader. - optimizer : Optimizer + optimizer : List[Optimizer] Optimizer for model parameters. device : torch.device Device to use. @@ -69,7 +70,7 @@ def train_one_epoch( Current epoch number. log_writer : TensorboardLogger, optional Logger for TensorBoard. - lr_scheduler : Scheduler, optional + lr_scheduler : List[Scheduler] Learning rate scheduler. verbose : bool, optional Verbosity flag. @@ -89,28 +90,26 @@ def train_one_epoch( """ # Switch to training mode model.train(True) - - # Ensure correct order of each epoch info by adding loss first + lr_names = [name.replace("loss", "lr", 1) for name in loss_names] metric_logger = MetricLogger(delimiter=" ", verbose=verbose) - metric_logger.add_meter("loss", SmoothedValue()) + for loss_name in loss_names: + metric_logger.add_meter(loss_name, SmoothedValue()) # Set up the header for logging header = "Epoch: [{}]".format(epoch + 1) print_freq = 10 - optimizer.zero_grad() + for opt in optimizer: + opt.zero_grad() for step, (batch, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): # Apply warmup cosine decay scheduler if selected # (notice we use a per iteration (instead of per epoch) lr scheduler) - if ( - epoch % cfg.TRAIN.ACCUM_ITER == 0 - and cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine" - and lr_scheduler - and isinstance(lr_scheduler, WarmUpCosineDecayScheduler) - ): - lr_scheduler.adjust_learning_rate(optimizer, step / len(data_loader) + epoch) + if cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": + for sched, opt in zip(lr_scheduler, optimizer): + if sched and isinstance(sched, WarmUpCosineDecayScheduler): + sched.adjust_learning_rate(opt, step / len(data_loader) + epoch) # Gather inputs targets = prepare_targets(targets, batch) @@ -139,60 +138,67 @@ def train_one_epoch( 'segment_queue': memory_bank.segment_queue, } - loss = loss_function(outputs, targets, with_embed=with_embed) + result = loss_function(outputs, targets, with_embed=with_embed) memory_bank.dequeue_and_enqueue( outputs['key'], targets.detach(), ) else: - loss = loss_function(outputs, targets) + result = loss_function(outputs, targets) - # Separate metric if precalculated inside the loss (e.g. Embedding loss) - precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] + # Parse the loss result + if isinstance(result, dict): + losses = result.get("losses", []) + precalculated_metrics = result.get("metrics", {}) + else: + losses = [result] + precalculated_metrics = {} - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics - if precalculated_metric is None: + if not precalculated_metrics: metric_function(outputs, targets, metric_logger=metric_logger) else: - metric_logger.meters[precalculated_metric_name].update(precalculated_metric) + for m_name, m_val in precalculated_metrics.items(): + metric_logger.meters[m_name].update(m_val) # Forward pass scaling the loss - loss /= cfg.TRAIN.ACCUM_ITER - if (step + 1) % cfg.TRAIN.ACCUM_ITER == 0: - loss.backward() - optimizer.step() # update weight - optimizer.zero_grad() - if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler.step() + for i, loss_tensor in enumerate(losses): + loss_tensor.backward() + if cfg.TRAIN.GRADIENT_CLIP_NORM > 0: + params = [p for group in optimizer[i].param_groups for p in group["params"]] + clip_grad_norm_(params, max_norm=cfg.TRAIN.GRADIENT_CLIP_NORM) + optimizer[i].step() + if lr_scheduler[i] and isinstance(lr_scheduler[i], OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler[i].step() if device.type != "cpu": getattr(torch, device.type).synchronize() # Update loss in loggers - metric_logger.update(loss=loss_value) - loss_value_reduce = all_reduce_mean(loss_value) - if log_writer: - log_writer.update(loss=loss_value_reduce, head="loss") + for i, loss_tensor in enumerate(losses): + loss_name = loss_names[i] + val = loss_tensor.item() + metric_logger.update(**{loss_name: val}) + loss_value_reduce = all_reduce_mean(val) + if log_writer: + log_writer.update(head="loss", **{loss_name: loss_value_reduce}) # Update lr in loggers - max_lr = 0.0 - for group in optimizer.param_groups: - max_lr = max(max_lr, group["lr"]) - if step == 0: - metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) - metric_logger.update(lr=max_lr) - if log_writer: - log_writer.update(lr=max_lr, head="opt") - + for i, opt in enumerate(optimizer): + max_lr = 0.0 + for group in opt.param_groups: + max_lr = max(max_lr, group["lr"]) + if step == 0: + metric_logger.add_meter(lr_names[i], SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.update(**{lr_names[i]: max_lr}) + if log_writer: + log_writer.update(head="opt", **{lr_names[i]: max_lr}) # Gather the stats from all processes metric_logger.synchronize_between_processes() print("[Train] averaged stats:", metric_logger) @@ -209,8 +215,9 @@ def evaluate( prepare_targets: Callable, epoch: int, data_loader: DataLoader, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: list[Optional[Scheduler]] = None, memory_bank: Optional[MemoryBank] = None, + loss_names: list[str] = None, ): """ Evaluate the model on the validation set. @@ -248,7 +255,8 @@ def evaluate( """ # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - metric_logger.add_meter("loss", SmoothedValue()) + for loss_name in loss_names: + metric_logger.add_meter(loss_name, SmoothedValue()) header = "Epoch: [{}]".format(epoch + 1) # Switch to evaluation mode @@ -275,30 +283,35 @@ def evaluate( 'segment_queue': memory_bank.segment_queue, } - loss = loss_function(outputs, targets, with_embed=with_embed) + result = loss_function(outputs, targets, with_embed=with_embed) else: - loss = loss_function(outputs, targets) + result = loss_function(outputs, targets) # Separate metric if precalculated inside the loss (e.g. Embedding loss) - precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] - - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + if isinstance(result, dict): + losses = result.get("losses", []) + precalculated_metrics = result.get("metrics", {}) + else: + losses = [result] + precalculated_metrics = {} + + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics - if precalculated_metric is not None: - metric_logger.meters[precalculated_metric_name].update(precalculated_metric) + if precalculated_metrics: + for m_name, m_val in precalculated_metrics.items(): + metric_logger.meters[m_name].update(m_val) else: metric_function(outputs, targets, metric_logger=metric_logger) # Update loss in loggers - metric_logger.update(loss=loss) + for i, loss_tensor in enumerate(losses): + loss_name = loss_names[i] + metric_logger.update(**{loss_name: loss_tensor.item()}) # Gather the stats from all processes metric_logger.synchronize_between_processes() @@ -306,10 +319,10 @@ def evaluate( print("[Val] averaged stats:", metric_logger) # Apply reduceonplateau scheduler if the global validation has been reduced - if ( - lr_scheduler - and isinstance(lr_scheduler, ReduceLROnPlateau) - and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau" - ): - lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch) + if lr_scheduler and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + for i, sched in enumerate(lr_scheduler): + if sched and isinstance(sched, ReduceLROnPlateau): + loss_name = loss_names[i] + sched.step(metric_logger.meters[loss_name].global_avg, epoch=epoch) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 365580d6..489fbb2c 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -376,6 +376,21 @@ def build_model( ) model = MaskedAutoencoderViT(**args) # type: ignore callable_model = MaskedAutoencoderViT # type: ignore + elif modelname == "nafnet": + args = dict( + img_channel=cfg.DATA.PATCH_SIZE[-1], + width=cfg.MODEL.NAFNET.WIDTH, + middle_blk_num=cfg.MODEL.NAFNET.MIDDLE_BLK_NUM, + enc_blk_nums=cfg.MODEL.NAFNET.ENC_BLK_NUMS, + dec_blk_nums=cfg.MODEL.NAFNET.DEC_BLK_NUMS, + drop_out_rate=cfg.MODEL.DROPOUT_VALUES[0], + dw_expand=cfg.MODEL.NAFNET.DW_EXPAND, + ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND, + discriminator_arch=cfg.MODEL.NAFNET.ARCHITECTURE_D, + patchgan_base_filters=cfg.MODEL.NAFNET.PATCHGAN.BASE_FILTERS, + ) + callable_model = NAFNet # type: ignore + model = callable_model(**args) # type: ignore # Check the network created model.to(device) if cfg.PROBLEM.NDIM == "2D": diff --git a/biapy/models/nafnet.py b/biapy/models/nafnet.py new file mode 100644 index 00000000..0802adcb --- /dev/null +++ b/biapy/models/nafnet.py @@ -0,0 +1,413 @@ +"""NAFNet model components and GAN discriminator builder utilities. + +This module provides: + +1. Lightweight building blocks (`SimpleGate`, `LayerNorm2d`, `NAFBlock`) used + by NAFNet. +2. The `NAFNet` encoder-decoder model for image restoration / image-to-image + workflows. +3. A discriminator builder function used by GAN-based training setups in BiaPy. + +Compared with traditional restoration backbones, NAFNet simplifies nonlinear +design while preserving strong reconstruction quality via gated depthwise blocks +and residual scaling. + +Reference +--------- +`Simple Baselines for Image Restoration `_. + +Related Work +------------ +The generator design is also inspired by the NAFSSR family: +`NAFSSR: Stereo Image Super-Resolution Using NAFNet +. +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge +Citation +-------- +Chu, Xiaojie and Chen, Liangyu and Yu, Wenqing. "NAFSSR: Stereo Image +Super-Resolution Using NAFNet." CVPR Workshops, 2022. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from yacs.config import CfgNode as CN +from torchinfo import summary + +from biapy.models.patchgan import PatchGANDiscriminator + +class SimpleGate(nn.Module): + """Simple channel-gating operator used in NAF blocks. + + The input tensor is split into two equal channel groups and both parts are + multiplied element-wise. + """ + + def forward(self, x): + """Apply channel-wise gating. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)` where `C` must be divisible + by 2. + + Returns + ------- + torch.Tensor + Tensor with shape `(N, C/2, H, W)` obtained by multiplying both + channel chunks element-wise. + """ + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class LayerNorm2d(nn.Module): + """Layer normalization over channel dimension for 2D features. + + This normalization computes mean and variance across channels for each + spatial position and applies learned affine parameters. + """ + + def __init__(self, channels, eps=1e-6): + """Initialize layer normalization parameters. + + Parameters + ---------- + channels : int + Number of channels in the input tensor. + eps : float, optional + Numerical stability constant added to the variance. + """ + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + """Normalize each spatial position across channels. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Normalized tensor with same shape as input. + """ + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + self.eps).sqrt() + y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1) + return y + + +class NAFBlock(nn.Module): + """Core NAFNet residual block. + + The block combines: + 1. Layer normalization. + 2. Pointwise + depthwise convolutions. + 3. `SimpleGate` and simplified channel attention. + 4. A lightweight FFN branch. + 5. Two residual scaling parameters (`beta`, `gamma`). + """ + + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + """Initialize one NAF block. + + Parameters + ---------- + c : int + Number of input/output channels in the block. + DW_Expand : int, optional + Expansion ratio for the depthwise branch before gating. + FFN_Expand : int, optional + Expansion ratio for the feed-forward branch. + drop_out_rate : float, optional + Dropout probability used in both residual branches. + """ + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + """Apply the NAF block transformation. + + Parameters + ---------- + inp : torch.Tensor + Input feature map with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Output feature map with the same shape as `inp`. + """ + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +class NAFNet(nn.Module): + """NAFNet encoder-decoder architecture for image restoration. + + The model follows a U-shaped design with: + 1. Intro and ending convolutions. + 2. Multiple encoder stages with downsampling. + 3. Bottleneck NAF blocks. + 4. Decoder stages with PixelShuffle upsampling and skip connections. + """ + + def __init__( + self, + img_channel=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + drop_out_rate=0.0, + dw_expand=2, + ffn_expand=2, + discriminator_arch=None, + patchgan_base_filters=64, + ): + """Initialize a NAFNet model. + + Parameters + ---------- + img_channel : int, optional + Number of input/output image channels. + width : int, optional + Base number of channels. + middle_blk_num : int, optional + Number of NAF blocks in the bottleneck. + enc_blk_nums : list[int], optional + Number of NAF blocks per encoder stage. + dec_blk_nums : list[int], optional + Number of NAF blocks per decoder stage. + drop_out_rate : float, optional + Dropout probability used inside blocks. + dw_expand : int, optional + Expansion ratio for depthwise branch. + ffn_expand : int, optional + Expansion ratio for feed-forward branch. + + Notes + ----- + Spatial padding is handled in `check_image_size` to ensure dimensions are + divisible by the encoder downsampling factor. + """ + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + # Pass the new parameters into the NAFBlock + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + discriminator = None + if discriminator_arch == "patchgan": + discriminator = PatchGANDiscriminator( + in_channels=img_channel, + base_filters=patchgan_base_filters, + ) + + self.discriminator = discriminator + + @property + def param_groups(self): + """Return parameter groups for separate optimizers. + When a discriminator is present, returns ``[generator_params, discriminator_params]`` + so that :func:`prepare_optimizer` can assign a separate optimizer and learning + rate to each group. Without a discriminator, returns a single group. + """ + if self.discriminator is not None: + gen_params = [p for n, p in self.named_parameters() if not n.startswith("discriminator.")] + return [gen_params, list(self.discriminator.parameters())] + return [list(self.parameters())] + + def forward(self, inp): + """Run a forward pass through NAFNet. + + Parameters + ---------- + inp : torch.Tensor + Input image tensor with shape `(N, C, H, W)`. + + Notes + ----- + The input is internally padded to satisfy the downsampling factor and + then cropped back to original size at the end of the forward pass. + + Returns + ------- + torch.Tensor or dict + Restored image with original spatial size `(H, W)`. If the + discriminator is active, returns ``{"pred": tensor}`` so that the + output can be enriched with loss information by ``forward_loss``. + """ + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + pred = x[:, :, :H, :W] + + if self.discriminator is not None: + return {"pred": pred} + return pred + + def forward_loss(self, pred, targets, loss_fn): + """Compute GAN losses using the discriminator and the given loss function. + + Parameters + ---------- + pred : torch.Tensor + Generator prediction (restored image). + targets : torch.Tensor + Ground-truth clean image. + loss_fn : nn.Module + Loss module (e.g. ``CycleGanLoss``) providing + ``forward_generator`` and ``forward_discriminator``. + + Returns + ------- + tuple or None + ``(loss_generator, loss_discriminator)`` if the discriminator is + available, otherwise ``None``. + """ + if self.discriminator is None: + return None + + fake_img = torch.clamp(pred, 0, 1) + + for p in self.discriminator.parameters(): + p.requires_grad_(False) + d_fake_for_g = self.discriminator(fake_img) + loss_g = loss_fn.forward_generator(fake_img, targets, d_fake_for_g) + for p in self.discriminator.parameters(): + p.requires_grad_(True) + + d_real = self.discriminator(targets) + d_fake = self.discriminator(fake_img.detach()) + loss_d = loss_fn.forward_discriminator(d_real, d_fake) + + return (loss_g, loss_d) + + def check_image_size(self, x): + """Pad image so height/width are divisible by internal stride. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Padded tensor compatible with encoder/decoder downsampling. + """ + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x diff --git a/biapy/models/patchgan.py b/biapy/models/patchgan.py new file mode 100644 index 00000000..b7e736dd --- /dev/null +++ b/biapy/models/patchgan.py @@ -0,0 +1,95 @@ +"""PatchGAN discriminator model used in image-to-image GAN training. + +This module implements a convolutional discriminator that predicts realism at +the patch level instead of producing a single global score. Patch-level +classification is commonly used in conditional GAN pipelines because it +emphasizes local texture and edge consistency, which is especially useful in +restoration and translation tasks. + +Classes +------- +PatchGANDiscriminator + Multi-layer convolutional discriminator with strided downsampling blocks and + a final 1-channel logits map. + +Notes +----- +The output tensor shape is `(N, 1, H_patch, W_patch)`, where each spatial value +acts as a local real/fake logit for a receptive-field patch in the input image. + +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + +""" + +import torch.nn as nn + + +class PatchGANDiscriminator(nn.Module): + """PatchGAN discriminator based on strided convolutional blocks. + + Parameters + ---------- + in_channels : int, optional + Number of channels in the input image. + base_filters : int, optional + Number of filters in the first discriminator block. Each subsequent + block doubles this value. + + Notes + ----- + The architecture follows a typical PatchGAN design: + 1. Four convolutional downsampling blocks. + 2. Batch normalization on all blocks except the first one. + 3. LeakyReLU activations. + 4. Final convolution producing a patch-logits map. + """ + + def __init__(self, in_channels=1, base_filters=64): + super(PatchGANDiscriminator, self).__init__() + + def discriminator_block(in_filters, out_filters, normalization=True): + """Create one discriminator stage. + + Parameters + ---------- + in_filters : int + Number of input channels. + out_filters : int + Number of output channels. + normalization : bool, optional + Whether to include BatchNorm after convolution. + + Returns + ------- + list[nn.Module] + Layers composing one stage of the discriminator. + """ + layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] + if normalization: + layers.append(nn.BatchNorm2d(out_filters)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *discriminator_block(in_channels, base_filters, normalization=False), + *discriminator_block(base_filters, base_filters * 2), + *discriminator_block(base_filters * 2, base_filters * 4), + *discriminator_block(base_filters * 4, base_filters * 8), + nn.Conv2d(base_filters * 8, 1, 4, stride=1, padding=1) + ) + + def forward(self, img): + """Run a forward pass through the discriminator. + + Parameters + ---------- + img : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Patch-wise realism logits with shape `(N, 1, H_patch, W_patch)`. + """ + return self.model(img) \ No newline at end of file diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py index 91592daa..9628e6a0 100644 --- a/biapy/utils/misc.py +++ b/biapy/utils/misc.py @@ -325,7 +325,7 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, The current epoch number. model_without_ddp : nn.Module The model instance, typically the unwrapped model if using DistributedDataParallel. - optimizer : torch.optim.Optimizer + optimizer : List[torch.optim.Optimizer] The optimizer's state. model_build_kwargs : Optional[Dict], optional Keyword arguments used to build the model, useful for re-instantiating @@ -346,11 +346,15 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, to_save = { "model_build_kwargs": model_build_kwargs, "model": model_without_ddp.state_dict(), - "optimizer": optimizer.state_dict(), + "optimizer": [opt.state_dict() for opt in optimizer], "epoch": epoch, "cfg": cfg, "biapy_version": biapy_version, } + + # For Gan Models + if hasattr(model_without_ddp, 'discriminator'): + to_save["discriminator_state_dict"] = model_without_ddp.discriminator.state_dict() save_on_master(to_save, checkpoint_path) if len(checkpoint_paths) > 0: @@ -459,8 +463,8 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non The model instance (unwrapped if DDP is used) to load weights into. device : torch.device The device to map the loaded checkpoint to. - optimizer : Optional[torch.optim.Optimizer], optional - The optimizer instance to load state into. If None, optimizer state is not loaded. + optimizer : Optional[List[torch.optim.Optimizer]], optional + The list of optimizer instances to load state into. If None, optimizer state is not loaded. Defaults to None. just_extract_checkpoint_info : bool, optional If True, only the configuration (`cfg`) and BiaPy version from the checkpoint @@ -557,11 +561,23 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non model_without_ddp.load_state_dict(filtered_state_dict, strict=False) print("Model weights loaded!") + if "discriminator_state_dict" in checkpoint: + if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: + model_without_ddp.discriminator.load_state_dict(checkpoint["discriminator_state_dict"], strict=False) + print("Discriminator weights loaded!") # Load also opt, epoch and scaler info if "optimizer" in checkpoint and optimizer is not None and "optimizer" in cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT: - optimizer.load_state_dict(checkpoint["optimizer"], strict=False) - print("Optimizer info loaded!") + # Backward compatibility: checkpoints are not converted in check_configuration. + checkpoint_optimizer = checkpoint["optimizer"] + if isinstance(checkpoint_optimizer, dict): + checkpoint_optimizer = [checkpoint_optimizer] + + loaded_optimizers = 0 + for opt, opt_state in zip(optimizer, checkpoint_optimizer): + opt.load_state_dict(opt_state, strict=False) + loaded_optimizers += 1 + print(f"Optimizer info loaded for {loaded_optimizers}/{len(optimizer)} optimizer(s)!") start_epoch = 0 if "epoch" in checkpoint and "epoch" in cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT: diff --git a/biapy/utils/util.py b/biapy/utils/util.py index be1e6e6a..64efff38 100644 --- a/biapy/utils/util.py +++ b/biapy/utils/util.py @@ -34,7 +34,7 @@ from biapy.utils.misc import is_main_process -def create_plots(results, metrics, job_id, chartOutDir): +def create_plots(results, metrics, loss_names, job_id, chartOutDir): """ Create loss and main metric plots with the given results. @@ -50,6 +50,8 @@ def create_plots(results, metrics, job_id, chartOutDir): and its validation counterpart (e.g., 'val_jaccard_index'). metrics : List[str] A list of metric names (e.g., ["jaccard_index", "f1_score"]) present in `results`. + loss_names : List[str] + A list of loss function names (e.g., ["loss", "loss_discriminator"]) present in `results`. job_id : str A unique identifier for the job, used in plot titles and filenames. chartOutDir : str @@ -76,18 +78,20 @@ def create_plots(results, metrics, job_id, chartOutDir): os.environ["QT_QPA_PLATFORM"] = "offscreen" # Loss - plt.plot(results["loss"]) - if "val_loss" in results: - plt.plot(results["val_loss"]) - plt.title("Model JOBID=" + job_id + " loss") - plt.ylabel("Value") - plt.xlabel("Epoch") - if "val_loss" in results: - plt.legend(["Train loss", "Val. loss"], loc="upper left") - else: - plt.legend(["Train loss"], loc="upper left") - plt.savefig(os.path.join(chartOutDir, job_id + "_loss.png")) - plt.clf() + for loss_key in loss_names: + val_loss_key = f"val_{loss_key}" + plt.plot(results[loss_key]) + if val_loss_key in results: + plt.plot(results[val_loss_key]) + plt.title("Model JOBID=" + job_id + " " + loss_key) + plt.ylabel("Value") + plt.xlabel("Epoch") + if val_loss_key in results: + plt.legend([f"Train {loss_key}", f"Val. {loss_key}"], loc="upper left") + else: + plt.legend([f"Train {loss_key}"], loc="upper left") + plt.savefig(os.path.join(chartOutDir, job_id + "_" + loss_key + ".png")) + plt.clf() # Metric for i in range(len(metrics)):