From a2855621cdc348e786180cc1aa5654be62370e2a Mon Sep 17 00:00:00 2001 From: DheerajMadda <50489165+DheerajMadda@users.noreply.github.com> Date: Thu, 7 Sep 2023 09:53:39 +0530 Subject: [PATCH 1/4] Optimized and restructured BASNet model and loss 1) Optimized the loss function, removed for loop for calculating IoU and used torchmetrics for SSIM. 2) Restructured the BASNet model definition 3) Created a BASNet_Lite model --- network/__init__.py | 3 + network/loss/basenet_loss.py | 71 +++++++++ network/model/BASNet.py | 273 +++++++++++++++++++++++++++++++++++ network/model/BASNet_lite.py | 272 ++++++++++++++++++++++++++++++++++ 4 files changed, 619 insertions(+) create mode 100644 network/__init__.py create mode 100644 network/loss/basenet_loss.py create mode 100644 network/model/BASNet.py create mode 100644 network/model/BASNet_lite.py diff --git a/network/__init__.py b/network/__init__.py new file mode 100644 index 00000000..9d01a6cd --- /dev/null +++ b/network/__init__.py @@ -0,0 +1,3 @@ +from .model.BASNet import BASNet +from .model.BASNet_lite import BASNet_Lite +from .loss.basenet_loss import BASNetLoss \ No newline at end of file diff --git a/network/loss/basenet_loss.py b/network/loss/basenet_loss.py new file mode 100644 index 00000000..91b7eee3 --- /dev/null +++ b/network/loss/basenet_loss.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import torchmetrics as tmt + +class SSIM(nn.Module): + def __init__(self, device="cpu"): + super(SSIM, self).__init__() + self.ssim = tmt.image.StructuralSimilarityIndexMeasure( + data_range=1.0 + ).to(device) + + def forward(self, preds, targets): + return self.ssim(preds, targets) + +class IOU(nn.Module): + def __init__(self): + super(IOU, self).__init__() + self.smooth = 1.0e-9 + + def forward(self, preds, targets): + intersection = torch.sum(torch.abs(targets * preds), dim=[1,2,3]) + union = (torch.sum(targets, dim=[1,2,3]) + torch.sum(preds, dim=[1,2,3])) - intersection + iou = torch.mean((intersection + self.smooth) / (union + self.smooth)) + return iou + +class BASNetLoss(nn.Module): + """BASNet hybrid loss.""" + + def __init__(self, device="cpu"): + super(BASNetLoss, self).__init__() + self.bce_loss = nn.BCELoss() + self.ssim = SSIM(device=device) + self.iou = IOU() + self.smooth = 1.0e-9 + self._is_train = True + + def train(self): + self._is_train = True + + def eval(self): + self._is_train = False + + def hybrid_loss(self, y_pred, y_true): + bce_loss = self.bce_loss(y_pred, y_true) + + ssim_value = self.ssim(y_pred, y_true) + ssim_loss = 1 - ssim_value + self.smooth + + iou_value = self.iou(y_pred, y_true) + iou_loss = 1 - iou_value + + # Add all three losses + return bce_loss + ssim_loss + iou_loss + + def forward(self, sup8, sup1, sup2, sup3, sup4, sup5, sup6, sup7, target): + + loss8 = self.hybrid_loss(sup8, target) + + if self._is_train: + loss1 = self.hybrid_loss(sup1, target) + loss2 = self.hybrid_loss(sup2, target) + loss3 = self.hybrid_loss(sup3, target) + loss4 = self.hybrid_loss(sup4, target) + loss5 = self.hybrid_loss(sup5, target) + loss6 = self.hybrid_loss(sup6, target) + loss7 = self.hybrid_loss(sup7, target) + + loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + return loss, loss8 + + return loss8 diff --git a/network/model/BASNet.py b/network/model/BASNet.py new file mode 100644 index 00000000..626e3465 --- /dev/null +++ b/network/model/BASNet.py @@ -0,0 +1,273 @@ +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + + # First convolutional layer + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + # Second convolutional layer + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + # Downsample layer if exists + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + # First convolutional block + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + + # Second convolutional block + x = self.conv2(x) + x = self.bn2(x) + + if self.downsample: + residual = self.downsample(residual) + + x += residual + x = self.relu2(x) + + return x + +class PredictModule(nn.Module): + """ + Predict Module (Encoder-Decoder). + A unet based model. + """ + def __init__(self): + super(PredictModule, self).__init__() + + resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT) + + self.encoder = self._make_encoder(resnet) + self.bridge = self._make_bridge() + self.decoder = self._make_decoder() + self.side_outputs = self._make_side_outputs() + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear') + + def _make_encoder(self, resnet): + encoder_layers = [ + nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + # stage 1 + resnet.layer1 #224 + ), + # stage 2 + resnet.layer2, #112 + # stage 3 + resnet.layer3, #56 + # stage 4 + resnet.layer4, #28 + # stage 5 + nn.Sequential( + nn.MaxPool2d(2, 2, ceil_mode=True), #14 + BasicBlock(512, 512), + BasicBlock(512, 512), + BasicBlock(512, 512) + ), + # stage 6 + nn.Sequential( + nn.MaxPool2d(2,2,ceil_mode=True), #7 + BasicBlock(512, 512), + BasicBlock(512, 512), + BasicBlock(512, 512) + ) + ] + + return nn.ModuleList(encoder_layers) + + def _make_bridge(self): + bridge_layers = [ + self._conv_block(512, 512, 3, dilation=2, padding=2), + self._conv_block(512, 512, 3, dilation=2, padding=2), + self._conv_block(512, 512, 3, dilation=2, padding=2) + ] #7 + return nn.Sequential(*bridge_layers) + + def _make_decoder(self): + decoder_layers = [ + # stage 6d + nn.Sequential( + self._conv_block(1024, 512, 3, padding=1), #7 + self._conv_block(512, 512, 3, dilation=2, padding=2), + self._conv_block(512, 512, 3, dilation=2, padding=2) + ), + # stage 5d + nn.Sequential( + self._conv_block(1024, 512, 3, padding=1), #14 + self._conv_block(512, 512, 3, padding=1), + self._conv_block(512, 512, 3, padding=1) + ), + # stage 4d + nn.Sequential( + self._conv_block(1024, 512, 3, padding=1), #28 + self._conv_block(512, 512, 3, padding=1), + self._conv_block(512, 256, 3, padding=1) + ), + # stage 3d + nn.Sequential( + self._conv_block(512, 256, 3, padding=1), #56 + self._conv_block(256, 256, 3, padding=1), + self._conv_block(256, 128, 3, padding=1) + ), + # stage 2d + nn.Sequential( + self._conv_block(256, 128, 3, padding=1), #112 + self._conv_block(128, 128, 3, padding=1), + self._conv_block(128, 64, 3, padding=1) + ), + # stage 1d + nn.Sequential( + self._conv_block(128, 64, 3, padding=1), #224 + self._conv_block(64, 64, 3, padding=1), + self._conv_block(64, 64, 3, padding=1) + ) + ] + + return nn.ModuleList(decoder_layers) + + def _make_side_outputs(self): + + side_output_layers = [] + channels = [512, 512, 512, 256, 128, 64] # sup1 -> sup6 + upsample_scales = [32, 32, 16, 8, 4 ,2] # sup1 -> sup6 + + for channel, scale_factor in zip(channels, upsample_scales): + side_output_layers += [ + nn.Sequential( + nn.Conv2d(channel, 1, 3, padding=1), + nn.Upsample(scale_factor=scale_factor, mode='bilinear') + ) + ] + + side_output_layers += [nn.Conv2d(64, 1, 3, padding=1)] # sup7 + + return nn.ModuleList(side_output_layers) + + def _conv_block(self, in_channels, out_channels, kernel_size, **kwargs): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + # Encoder + encoder_outs = [] + for encoder_module in self.encoder: + x = encoder_module(x) + encoder_outs.append(x) + + # Bridge + x = self.bridge(x) + + # Decoder and Side Outputs + side_outputs = [self.side_outputs[0](x)] + + for idx, decoder_module in enumerate(self.decoder, start=1): + x = torch.cat((x, encoder_outs[-idx]), dim=1) + x = decoder_module(x) + sx = self.side_outputs[idx](x) + side_outputs.append(sx) + + if idx < 6: # not applying upsample for decoder stage 1d + x = self.upsample2(x) + + return side_outputs + +class RRM(nn.Module): + """ + Residual Refinement Module (RRM). + A unet based model. + """ + def __init__(self, in_ch, inc_ch): + super(RRM, self).__init__() + + # Initial convolution + self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1) + + # Encoder + self.enc1 = self._block(inc_ch, 64) + self.enc2 = self._block(64, 64) + self.enc3 = self._block(64, 64) + self.enc4 = self._block(64, 64) + + self.pool = nn.MaxPool2d(2, 2, ceil_mode=True) + + # Bridge + self.bridge = self._block(64, 64) + + # Decoder + self.dec4 = self._block(128, 64) + self.dec3 = self._block(128, 64) + self.dec2 = self._block(128, 64) + self.dec1 = self._block(128, 64) + + self.final = nn.Conv2d(64, 1, 3, padding=1) + + # Upsampling + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear') + + def _block(self, in_ch, out_ch): + return nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + # Encoder + hx1 = self.enc1(self.conv0(x)) + hx2 = self.enc2(self.pool(hx1)) + hx3 = self.enc3(self.pool(hx2)) + hx4 = self.enc4(self.pool(hx3)) + + # Bridge + hx5 = self.bridge(self.pool(hx4)) + + # Decoder with skip connections + d4 = self.dec4(torch.cat((self.upsample2(hx5), hx4), 1)) + d3 = self.dec3(torch.cat((self.upsample2(d4), hx3), 1)) + d2 = self.dec2(torch.cat((self.upsample2(d3), hx2), 1)) + d1 = self.dec1(torch.cat((self.upsample2(d2), hx1), 1)) + + # Final layer + residual = self.final(d1) + + return x + residual + +class BASNet(nn.Module): + def __init__(self): + super(BASNet, self).__init__() + + self.predict_module = PredictModule() + self.refine_module = RRM(1, 64) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + side_outputs = self.predict_module(x) + out = self.refine_module(side_outputs[-1]) + out = self.sigmoid(out) + + if self.training: + return (out, *[self.sigmoid(x) for x in side_outputs]) + return out diff --git a/network/model/BASNet_lite.py b/network/model/BASNet_lite.py new file mode 100644 index 00000000..af8b9e5f --- /dev/null +++ b/network/model/BASNet_lite.py @@ -0,0 +1,272 @@ +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + + # First convolutional layer + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + # Second convolutional layer + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + # Downsample layer if exists + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + # First convolutional block + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + + # Second convolutional block + x = self.conv2(x) + x = self.bn2(x) + + if self.downsample: + residual = self.downsample(residual) + + x += residual + x = self.relu2(x) + + return x + +class PredictModule(nn.Module): + """ + Predict Module (Encoder-Decoder). + A unet based model. + """ + def __init__(self): + super(PredictModule, self).__init__() + + mbv3_large = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT) + + self.encoder = self._make_encoder(mbv3_large) + self.bridge = self._make_bridge() + self.decoder = self._make_decoder() + self.side_outputs = self._make_side_outputs() + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear') + + def _make_encoder(self, mbv3_large): + encoder_layers = [ + nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.BatchNorm2d(16), + nn.ReLU(inplace=True), + # stage 1 + mbv3_large.features[1:4] #112 + ), + # stage 2 + mbv3_large.features[4:7], #56 + # stage 3 + mbv3_large.features[7:10], #28 + # stage 4 + mbv3_large.features[10:13], #28 + # stage 5 + nn.Sequential( + nn.MaxPool2d(2, 2, ceil_mode=True), #14 + BasicBlock(112, 112), + BasicBlock(112, 112), + BasicBlock(112, 112) + ), + # stage 6 + nn.Sequential( + nn.MaxPool2d(2,2,ceil_mode=True), #7 + BasicBlock(112, 112), + BasicBlock(112, 112), + BasicBlock(112, 112) + ) + ] + + return nn.ModuleList(encoder_layers) + + def _make_bridge(self): + bridge_layers = [ + self._conv_block(112, 112, 3, dilation=2, padding=2), + self._conv_block(112, 112, 3, dilation=2, padding=2), + self._conv_block(112, 112, 3, dilation=2, padding=2) + ] #7 + return nn.Sequential(*bridge_layers) + + def _make_decoder(self): + decoder_layers = [ + # stage 6d + nn.Sequential( + self._conv_block(224, 112, 3, padding=1), #7 + self._conv_block(112, 112, 3, dilation=2, padding=2), + self._conv_block(112, 112, 3, dilation=2, padding=2), + ), + # stage 5d + nn.Sequential( + self._conv_block(224, 112, 3, padding=1), #14 + self._conv_block(112, 112, 3, padding=1), + self._conv_block(112, 112, 3, padding=1) + ), + # stage 4d + nn.Sequential( + self._conv_block(224, 112, 3, padding=1), #28 + self._conv_block(112, 112, 3, padding=1), + self._conv_block(112, 56, 3, padding=1) + + ), + # stage 3d + nn.Sequential( + self._conv_block(136, 48, 3, padding=1), #28 + self._conv_block(48, 48, 3, padding=1), + self._conv_block(48, 32, 3, padding=1) + ), + # stage 2d + nn.Sequential( + self._conv_block(72, 24, 3, padding=1), #56 + self._conv_block(24, 24, 3, padding=1), + self._conv_block(24, 16, 3, padding=1) + ), + # stage 1d + nn.Sequential( + self._conv_block(40, 16, 3, padding=1), #112 + self._conv_block(16, 16, 3, padding=1), + self._conv_block(16, 16, 3, padding=1) + ) + ] + + return nn.ModuleList(decoder_layers) + + def _make_side_outputs(self): + + side_output_layers = [] + channels = [112, 112, 112, 56, 32, 16, 16] # sup1 -> sup7 + upsample_scales = [32, 32, 16, 8, 8 ,4, 2] # sup1 -> sup7 + + for channel, scale_factor in zip(channels, upsample_scales): + side_output_layers += [ + nn.Sequential( + nn.Conv2d(channel, 1, 3, padding=1), + nn.Upsample(scale_factor=scale_factor, mode='bilinear') + ) + ] + + return nn.ModuleList(side_output_layers) + + def _conv_block(self, in_channels, out_channels, kernel_size, **kwargs): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + # Encoder + encoder_outs = [] + for encoder_module in self.encoder: + x = encoder_module(x) + encoder_outs.append(x) + + # Bridge + x = self.bridge(x) + + # Decoder and Side Outputs + side_outputs = [self.side_outputs[0](x)] + + for idx, decoder_module in enumerate(self.decoder, start=1): + x = torch.cat((x, encoder_outs[-idx]), dim=1) + x = decoder_module(x) + sx = self.side_outputs[idx](x) + side_outputs.append(sx) + + if idx != 3: # not applying upsample for decoder stage 4d + x = self.upsample2(x) + + return side_outputs + +class RRM(nn.Module): + """ + Residual Refinement Module (RRM). + A unet based model. + """ + def __init__(self, in_ch, inc_ch): + super(RRM, self).__init__() + + # Initial convolution + self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1) + + # Encoder + self.enc1 = self._block(inc_ch, 16) + self.enc2 = self._block(16, 16) + self.enc3 = self._block(16, 16) + self.enc4 = self._block(16, 16) + + self.pool = nn.MaxPool2d(2, 2, ceil_mode=True) + + # Bridge + self.bridge = self._block(16, 16) + + # Decoder + self.dec4 = self._block(32, 16) + self.dec3 = self._block(32, 16) + self.dec2 = self._block(32, 16) + self.dec1 = self._block(32, 16) + + self.final = nn.Conv2d(16, 1, 3, padding=1) + + # Upsampling + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear') + + def _block(self, in_ch, out_ch): + return nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + # Encoder + hx1 = self.enc1(self.conv0(x)) + hx2 = self.enc2(self.pool(hx1)) + hx3 = self.enc3(self.pool(hx2)) + hx4 = self.enc4(self.pool(hx3)) + + # Bridge + hx5 = self.bridge(self.pool(hx4)) + + # Decoder with skip connections + d4 = self.dec4(torch.cat((self.upsample2(hx5), hx4), 1)) + d3 = self.dec3(torch.cat((self.upsample2(d4), hx3), 1)) + d2 = self.dec2(torch.cat((self.upsample2(d3), hx2), 1)) + d1 = self.dec1(torch.cat((self.upsample2(d2), hx1), 1)) + + # Final layer + residual = self.final(d1) + + return x + residual + +class BASNet_Lite(nn.Module): + def __init__(self): + super(BASNet_Lite, self).__init__() + + self.predict_module = PredictModule() + self.refine_module = RRM(1, 16) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + side_outputs = self.predict_module(x) + out = self.refine_module(side_outputs[-1]) + out = self.sigmoid(out) + + if self.training: + return (out, *[self.sigmoid(x) for x in side_outputs]) + return out From 666f71c93b1209721e9aa8cdc6dd6ecc11431f23 Mon Sep 17 00:00:00 2001 From: DheerajMadda <50489165+DheerajMadda@users.noreply.github.com> Date: Sun, 10 Sep 2023 16:56:59 +0530 Subject: [PATCH 2/4] Update __init__.py --- network/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/__init__.py b/network/__init__.py index 9d01a6cd..fa0b0b58 100644 --- a/network/__init__.py +++ b/network/__init__.py @@ -1,3 +1,3 @@ from .model.BASNet import BASNet from .model.BASNet_lite import BASNet_Lite -from .loss.basenet_loss import BASNetLoss \ No newline at end of file +from .loss.basenet_loss import BASNetLoss From 21bc4bd6e73b990aa79b233b60a54a13822b3144 Mon Sep 17 00:00:00 2001 From: DheerajMadda <50489165+DheerajMadda@users.noreply.github.com> Date: Sun, 10 Sep 2023 16:57:35 +0530 Subject: [PATCH 3/4] Update BASNet.py --- network/model/BASNet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/network/model/BASNet.py b/network/model/BASNet.py index 626e3465..5df15bf4 100644 --- a/network/model/BASNet.py +++ b/network/model/BASNet.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from torchvision import models -import torch.nn.functional as F def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" @@ -180,13 +179,18 @@ def forward(self, x): x = self.bridge(x) # Decoder and Side Outputs - side_outputs = [self.side_outputs[0](x)] + side_outputs = [self.side_outputs[0](x)] if self.training else [] for idx, decoder_module in enumerate(self.decoder, start=1): x = torch.cat((x, encoder_outs[-idx]), dim=1) x = decoder_module(x) - sx = self.side_outputs[idx](x) - side_outputs.append(sx) + if self.training: + sx = self.side_outputs[idx](x) + side_outputs.append(sx) + else: + if idx == 6: + sx = self.side_outputs[idx](x) + side_outputs.append(sx) if idx < 6: # not applying upsample for decoder stage 1d x = self.upsample2(x) From 9f49efed0178c9c00811203c792bf4a31203fc87 Mon Sep 17 00:00:00 2001 From: DheerajMadda <50489165+DheerajMadda@users.noreply.github.com> Date: Sun, 10 Sep 2023 16:58:14 +0530 Subject: [PATCH 4/4] Update BASNet_lite.py --- network/model/BASNet_lite.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/network/model/BASNet_lite.py b/network/model/BASNet_lite.py index af8b9e5f..17050b03 100644 --- a/network/model/BASNet_lite.py +++ b/network/model/BASNet_lite.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from torchvision import models -import torch.nn.functional as F def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" @@ -179,15 +178,20 @@ def forward(self, x): x = self.bridge(x) # Decoder and Side Outputs - side_outputs = [self.side_outputs[0](x)] + side_outputs = [self.side_outputs[0](x)] if self.training else [] for idx, decoder_module in enumerate(self.decoder, start=1): x = torch.cat((x, encoder_outs[-idx]), dim=1) x = decoder_module(x) - sx = self.side_outputs[idx](x) - side_outputs.append(sx) + if self.training: + sx = self.side_outputs[idx](x) + side_outputs.append(sx) + else: + if idx == 6: + sx = self.side_outputs[idx](x) + side_outputs.append(sx) - if idx != 3: # not applying upsample for decoder stage 4d + if idx not in (3, 6): # not applying upsample for decoder stage 4d and # stage 1d x = self.upsample2(x) return side_outputs