def noise_loss(image, illumination, reflectance, noise):
    """Return the L2 norm of the illumination-weighted noise.

    The illumination acts only as a fixed weighting term: it is detached
    from the autograd graph, so this loss produces gradients for ``noise``
    but not for ``illumination``.

    Args:
        image: Unused here; accepted for signature compatibility with callers.
        illumination: Tensor used as a per-element weight on ``noise``
            (presumably broadcast-compatible with ``noise`` -- confirm).
        reflectance: Unused here; accepted for signature compatibility.
        noise: Tensor whose weighted magnitude is penalized.

    Returns:
        A scalar tensor: the 2-norm of ``illumination * noise`` taken over
        all elements.
    """
    # ``Tensor.detach()`` is not in-place; it returns a new tensor. The result
    # must be bound, otherwise gradients would flow back into the illumination.
    weight_illu = illumination.detach()
    loss = weight_illu * noise
    return torch.norm(loss, 2)