From 1c6dc1185111b73fd996e7f3e232410ac696e4e0 Mon Sep 17 00:00:00 2001 From: Sohan Anisetty Date: Sat, 28 Aug 2021 14:28:15 +0530 Subject: [PATCH] Added SSIM support, quality of life improvements in invert.py Added support for SSIM loss. SSIM loss improves reconstruction of faces under extreme conditions of pose, illumination etc. SSIM makes the output smoother which may be an unwanted attribute. Added support for directly taking images from directory in invert.py instead of making a .list file. Quality of life improvements in the example codes for inversion and manipulation. --- README.md | 17 +- diffuse.py | 6 +- examples/test.list | 2 +- interpolate.py | 6 +- invert.py | 61 +++- manipulate.py | 6 +- mix_style.py | 6 +- models/base_module.py | 600 +++++++++++++++---------------- models/model_settings.py | 6 + models/stylegan_generator.py | 15 + pytorch_ssim/__init__.py | 73 ++++ utils/inverter.py | 667 ++++++++++++++++++----------------- 12 files changed, 810 insertions(+), 655 deletions(-) create mode 100644 pytorch_ssim/__init__.py diff --git a/README.md b/README.md index 09328da..37f3333 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,14 @@ Please download the pre-trained models from the following links and save them to ```bash MODEL_NAME='styleganinv_ffhq256' IMAGE_LIST='examples/test.list' -python invert.py $MODEL_NAME $IMAGE_LIST +python invert.py --model_name $MODEL_NAME --image_list $IMAGE_LIST +``` + +```bash +MODEL_NAME='styleganinv_ffhq256' +IMAGE_DIR='examples/images/' + +python invert.py --model_name $MODEL_NAME --test_dir $IMAGE_DIR ``` **NOTE:** We find that 100 iterations are good enough for inverting an image, which takes about 8s (on P40). But users can always use more iterations (much slower) for a more precise reconstruction. 
@@ -47,7 +54,7 @@ python invert.py $MODEL_NAME $IMAGE_LIST MODEL_NAME='styleganinv_ffhq256' TARGET_LIST='examples/target.list' CONTEXT_LIST='examples/context.list' -python diffuse.py $MODEL_NAME $TARGET_LIST $CONTEXT_LIST +python diffuse.py --model_name $MODEL_NAME --target_list $TARGET_LIST --context_list $CONTEXT_LIST ``` NOTE: The diffusion process is highly similar to image inversion. The main difference is that only the target patch is used to compute loss for **masked** optimization. @@ -57,7 +64,7 @@ NOTE: The diffusion process is highly similar to image inversion. The main diffe ```bash SRC_DIR='results/inversion/test' DST_DIR='results/inversion/test' -python interpolate.py $MODEL_NAME $SRC_DIR $DST_DIR +python interpolate.py --model_name $MODEL_NAME --src_dir $SRC_DIR --dst_dir $DST_DIR ``` ### Manipulation @@ -65,7 +72,7 @@ python interpolate.py $MODEL_NAME $SRC_DIR $DST_DIR ```bash IMAGE_DIR='results/inversion/test' BOUNDARY='boundaries/expression.npy' -python manipulate.py $MODEL_NAME $IMAGE_DIR $BOUNDARY +python manipulate.py --model_name $MODEL_NAME --image_dir $IMAGE_DIR --boundary_path $BOUNDARY ``` **NOTE:** Boundaries are obtained using [InterFaceGAN](https://github.com/genforce/interfacegan). 
@@ -75,7 +82,7 @@ python manipulate.py $MODEL_NAME $IMAGE_DIR $BOUNDARY ```bash STYLE_DIR='results/inversion/test' CONTENT_DIR='results/inversion/test' -python mix_style.py $MODEL_NAME $STYLE_DIR $CONTENT_DIR +python mix_style.py --model_name $MODEL_NAME --style_dir $STYLE_DIR --content_dir $CONTENT_DIR ``` ## BibTeX diff --git a/diffuse.py b/diffuse.py index fc1e37e..542e5aa 100644 --- a/diffuse.py +++ b/diffuse.py @@ -24,10 +24,10 @@ def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, help='Name of the GAN model.') - parser.add_argument('target_list', type=str, + parser.add_argument('--model_name', type=str, help='Name of the GAN model.') + parser.add_argument('--target_list', type=str, help='List of target images to diffuse from.') - parser.add_argument('context_list', type=str, + parser.add_argument('--context_list', type=str, help='List of context images to diffuse to.') parser.add_argument('-o', '--output_dir', type=str, default='', help='Directory to save the results. 
If not specified, ' diff --git a/examples/test.list b/examples/test.list index 46fd527..e3879b8 100644 --- a/examples/test.list +++ b/examples/test.list @@ -16,4 +16,4 @@ examples/000015.png examples/000016.png examples/000017.png examples/000018.png -examples/000019.png +examples/000019.png \ No newline at end of file diff --git a/interpolate.py b/interpolate.py index 882dfcf..bbe430f 100644 --- a/interpolate.py +++ b/interpolate.py @@ -23,11 +23,11 @@ def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, help='Name of the GAN model.') - parser.add_argument('src_dir', type=str, + parser.add_argument('--model_name', type=str, help='Name of the GAN model.') + parser.add_argument('--src_dir', type=str, help='Source directory, which includes original images, ' 'inverted codes, and image list.') - parser.add_argument('dst_dir', type=str, + parser.add_argument('--dst_dir', type=str, help='Target directory, which includes original images, ' 'inverted codes, and image list.') parser.add_argument('-o', '--output_dir', type=str, default='', diff --git a/invert.py b/invert.py index 169e6d9..389e5b6 100644 --- a/invert.py +++ b/invert.py @@ -21,9 +21,12 @@ def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, help='Name of the GAN model.') - parser.add_argument('image_list', type=str, + parser.add_argument('--model_name', type=str, help='Name of the GAN model.') + parser.add_argument('--image_list', type=str, default = '', help='List of images to invert.') + + parser.add_argument('--test_dir', type=str, default = '', + help='directory of images to invert.') parser.add_argument('-o', '--output_dir', type=str, default='', help='Directory to save the results. 
If not specified, ' '`./results/inversion/${IMAGE_LIST}` ' @@ -38,6 +41,11 @@ def parse_args(): parser.add_argument('--loss_weight_feat', type=float, default=5e-5, help='The perceptual loss scale for optimization. ' '(default: 5e-5)') + + parser.add_argument('--loss_weight_ssim', type=float, default=1.0, + help='The perceptual loss scale for optimization. ' + '(default: 1)') + parser.add_argument('--loss_weight_enc', type=float, default=2.0, help='The encoder loss scale for optimization.' '(default: 2.0)') @@ -52,9 +60,24 @@ def main(): """Main function.""" args = parse_args() os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id - assert os.path.exists(args.image_list) - image_list_name = os.path.splitext(os.path.basename(args.image_list))[0] + if args.image_list != '' and args.test_dir == '': + assert os.path.exists(args.image_list) + image_list_name = os.path.splitext(os.path.basename(args.image_list))[0] + elif args.test_dir != '' and args.image_list == '' : + assert os.path.exists(args.test_dir) + image_list_name = os.path.splitext(os.path.basename(args.test_dir))[0] + else: + raise Exception("Use either --image_list or --test_dir. 
Using both arguments at the same time not supported.") + + + MODEL_DIR = os.path.join('models', 'pretrain') + os.makedirs(MODEL_DIR, exist_ok=True) + if(all(x not in os.listdir(MODEL_DIR) for x in ["styleganinv_ffhq256_encoder.pth" , "styleganinv_ffhq256_generator.pth" , "vgg16.pth"])): + raise Exception("styleganinv_ffhq256_encoder.pth , styleganinv_ffhq256_generator.pth and vgg16.pth missing") + output_dir = args.output_dir or f'results/inversion/{image_list_name}' + if not os.path.exists(output_dir): + os.makedirs(output_dir) logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger') logger.info(f'Loading model.') @@ -65,15 +88,27 @@ def main(): reconstruction_loss_weight=1.0, perceptual_loss_weight=args.loss_weight_feat, regularization_loss_weight=args.loss_weight_enc, + loss_weight_ssim = args.loss_weight_ssim, logger=logger) image_size = inverter.G.resolution # Load image list. logger.info(f'Loading image list.') image_list = [] - with open(args.image_list, 'r') as f: - for line in f: - image_list.append(line.strip()) + if args.image_list !='': + + with open(args.image_list, 'r') as f: + for line in f: + image_list.append(line.strip()) + + if args.test_dir !='': + for root, dirs, files in os.walk(args.test_dir): + for file in files: + image_list.append(file) + + + #print(len(image_list)) + logger.info(f'loaded {len(image_list)} images') # Initialize visualizer. 
save_interval = args.num_iterations // args.num_results @@ -90,10 +125,15 @@ def main(): logger.info(f'Start inversion.') latent_codes = [] for img_idx in tqdm(range(len(image_list)), leave=False): - image_path = image_list[img_idx] - image_name = os.path.splitext(os.path.basename(image_path))[0] + if args.image_list !='': + image_path = image_list[img_idx] + image_name = os.path.splitext(os.path.basename(image_path))[0] + elif args.test_dir !='': + image_path = os.path.join( args.test_dir, image_list[img_idx]) + image_name = os.path.splitext(os.path.basename(image_list[img_idx]))[0] + image = resize_image(load_image(image_path), (image_size, image_size)) - code, viz_results = inverter.easy_invert(image, num_viz=args.num_results) + code, viz_results , ssim_loss = inverter.easy_invert(np.array(image), num_viz=args.num_results) latent_codes.append(code) save_image(f'{output_dir}/{image_name}_ori.png', image) save_image(f'{output_dir}/{image_name}_enc.png', viz_results[1]) @@ -103,6 +143,7 @@ def main(): for viz_idx, viz_img in enumerate(viz_results[1:]): visualizer.set_cell(img_idx, viz_idx + 2, image=viz_img) + # Save results. 
os.system(f'cp {args.image_list} {output_dir}/image_list.txt') np.save(f'{output_dir}/inverted_codes.npy', diff --git a/manipulate.py b/manipulate.py index c4a70dd..a3ed33f 100644 --- a/manipulate.py +++ b/manipulate.py @@ -20,11 +20,11 @@ def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, help='Name of the GAN model.') - parser.add_argument('image_dir', type=str, + parser.add_argument('--model_name', type=str, help='Name of the GAN model.') + parser.add_argument('--image_dir', type=str, help='Image directory, which includes original images, ' 'inverted codes, and image list.') - parser.add_argument('boundary_path', type=str, + parser.add_argument('--boundary_path', type=str, help='Path to the boundary for semantic manipulation.') parser.add_argument('-o', '--output_dir', type=str, default='', help='Directory to save the results. If not specified, ' diff --git a/mix_style.py b/mix_style.py index 1c2cdfc..d04a40a 100644 --- a/mix_style.py +++ b/mix_style.py @@ -23,11 +23,11 @@ def parse_args(): """Parses arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('model_name', type=str, help='Name of the GAN model.') - parser.add_argument('style_dir', type=str, + parser.add_argument('--model_name', type=str, help='Name of the GAN model.') + parser.add_argument('--style_dir', type=str, help='Style directory, which includes original images, ' 'inverted codes, and image list.') - parser.add_argument('content_dir', type=str, + parser.add_argument('--content_dir', type=str, help='Content directory, which includes original images, ' 'inverted codes, and image list.') parser.add_argument('-o', '--output_dir', type=str, default='', diff --git a/models/base_module.py b/models/base_module.py index 04a9105..5e84ac1 100644 --- a/models/base_module.py +++ b/models/base_module.py @@ -1,300 +1,300 @@ -# python 3.7 -"""Contains the base class for modules in a GAN model. 
- -Commonly, GAN consists of two components, i.e., generator and discriminator. -In practice, however, more modules can be added, such as encoder. -""" - -import os.path -import sys -import logging -import numpy as np - -import torch - -from . import model_settings - -__all__ = ['BaseModule'] - -DTYPE_NAME_TO_TORCH_TENSOR_TYPE = { - 'float16': torch.HalfTensor, - 'float32': torch.FloatTensor, - 'float64': torch.DoubleTensor, - 'int8': torch.CharTensor, - 'int16': torch.ShortTensor, - 'int32': torch.IntTensor, - 'int64': torch.LongTensor, - 'uint8': torch.ByteTensor, - 'bool': torch.BoolTensor, -} - - -def get_temp_logger(logger_name='logger'): - """Gets a temporary logger. - - This logger will print all levels of messages onto the screen. - - Args: - logger_name: Name of the logger. - - Returns: - A `logging.Logger`. - - Raises: - ValueError: If the input `logger_name` is empty. - """ - if not logger_name: - raise ValueError(f'Input `logger_name` should not be empty!') - - logger = logging.getLogger(logger_name) - if not logger.hasHandlers(): - logger.setLevel(logging.DEBUG) - formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s') - sh = logging.StreamHandler(stream=sys.stdout) - sh.setLevel(logging.DEBUG) - sh.setFormatter(formatter) - logger.addHandler(sh) - - return logger - - -class BaseModule(object): - """Base class for modules in GANs, like generator and discriminator. - - NOTE: The module should be defined with pytorch, and used for inference only. - """ - - def __init__(self, model_name, module_name, logger=None): - """Initializes with specific settings. - - The GAN model should be first registered in `model_settings.py` with proper - settings. Among them, some attributes are necessary, including: - - (1) resolution: Resolution of the synthesis. - (2) image_channels: Number of channels of the synthesis. (default: 3) - (3) channel_order: Channel order of the raw synthesis. 
(default: `RGB`) - (4) min_val: Minimum value of the raw synthesis. (default -1.0) - (5) max_val: Maximum value of the raw synthesis. (default 1.0) - - Args: - model_name: Name with which the GAN model is registered. - module_name: Name of the module, like `generator` or `discriminator`. - logger: Logger for recording log messages. If set as `None`, a default - logger, which prints messages from all levels onto the screen, will be - created. (default: None) - - Raises: - AttributeError: If some necessary attributes are missing. - """ - self.model_name = model_name - self.module_name = module_name - self.logger = logger or get_temp_logger(model_name) - - # Parse settings. - for key, val in model_settings.MODEL_POOL[model_name].items(): - setattr(self, key, val) - self.use_cuda = model_settings.USE_CUDA and torch.cuda.is_available() - self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE - self.ram_size = model_settings.MAX_IMAGES_ON_RAM - self.net = None - self.run_device = 'cuda' if self.use_cuda else 'cpu' - self.cpu_device = 'cpu' - - # Check necessary settings. - self.check_attr('gan_type') # Should be specified in derived classes. - self.check_attr('resolution') - self.image_channels = getattr(self, 'image_channels', 3) - assert self.image_channels in [1, 3] - self.channel_order = getattr(self, 'channel_order', 'RGB').upper() - assert self.channel_order in ['RGB', 'BGR'] - self.min_val = getattr(self, 'min_val', -1.0) - self.max_val = getattr(self, 'max_val', 1.0) - - # Get paths. - self.weight_path = model_settings.get_weight_path( - f'{model_name}_{module_name}') - - # Build graph and load pre-trained weights. - self.logger.info(f'Build network for module `{self.module_name}` in ' - f'model `{self.model_name}`.') - self.model_specific_vars = [] - self.build() - if os.path.isfile(self.weight_path): - self.load() - else: - self.logger.warning(f'No pre-trained weights will be loaded!') - - # Change to inference mode and GPU mode if needed. 
- assert self.net - self.net.eval().to(self.run_device) - - def check_attr(self, attr_name): - """Checks the existence of a particular attribute. - - Args: - attr_name: Name of the attribute to check. - - Raises: - AttributeError: If the target attribute is missing. - """ - if not hasattr(self, attr_name): - raise AttributeError(f'Field `{attr_name}` is missing for ' - f'module `{self.module_name}` in ' - f'model `{self.model_name}`!') - - def build(self): - """Builds the graph.""" - raise NotImplementedError(f'Should be implemented in derived class!') - - def load(self): - """Loads pre-trained weights.""" - self.logger.info(f'Loading pytorch weights from `{self.weight_path}`.') - state_dict = torch.load(self.weight_path) - for var_name in self.model_specific_vars: - state_dict[var_name] = self.net.state_dict()[var_name] - self.net.load_state_dict(state_dict) - self.logger.info(f'Successfully loaded!') - - def to_tensor(self, array): - """Converts a `numpy.ndarray` to `torch.Tensor` on running device. - - Args: - array: The input array to convert. - - Returns: - A `torch.Tensor` whose dtype is determined by that of the input array. - - Raises: - ValueError: If the array is with neither `torch.Tensor` type nor - `numpy.ndarray` type. - """ - dtype = type(array) - if isinstance(array, torch.Tensor): - tensor = array - elif isinstance(array, np.ndarray): - tensor_type = DTYPE_NAME_TO_TORCH_TENSOR_TYPE[array.dtype.name] - tensor = torch.from_numpy(array).type(tensor_type) - else: - raise ValueError(f'Unsupported input type `{dtype}`!') - tensor = tensor.to(self.run_device) - return tensor - - def get_value(self, tensor): - """Gets value of a `torch.Tensor`. - - Args: - tensor: The input tensor to get value from. - - Returns: - A `numpy.ndarray`. - - Raises: - ValueError: If the tensor is with neither `torch.Tensor` type nor - `numpy.ndarray` type. 
- """ - dtype = type(tensor) - if isinstance(tensor, np.ndarray): - return tensor - if isinstance(tensor, torch.Tensor): - return tensor.to(self.cpu_device).detach().numpy() - raise ValueError(f'Unsupported input type `{dtype}`!') - - def get_ont_hot_labels(self, num, labels=None): - """Gets ont-hot labels for conditional generation. - - Args: - num: Number of labels to generate. - labels: Input labels as reference to generate one-hot labels. If set as - `None`, label `0` will be used by default. (default: None) - - Returns: - Returns `None` if `self.label_size` is 0, otherwise, a `numpy.ndarray` - with shape [num, self.label_size] and dtype `np.float32`. - """ - self.check_attr('label_size') - if self.label_size == 0: - return None - - if labels is None: - labels = 0 - labels = np.array(labels).reshape(-1) - if labels.size == 1: - labels = np.tile(labels, (num,)) - assert labels.shape == (num,) - for label in labels: - if label >= self.label_size or label < 0: - raise ValueError(f'Label should be smaller than {self.label_size}, ' - f'but {label} is received!') - - one_hot = np.zeros((num, self.label_size), dtype=np.int32) - one_hot[np.arange(num), labels] = 1 - return one_hot - - def get_batch_inputs(self, inputs, batch_size=None): - """Gets inputs within mini-batch. - - This function yields at most `self.batch_size` inputs at a time. - - Args: - inputs: Input data to form mini-batch. - batch_size: Batch size. If not specified, `self.batch_size` will be used. - (default: None) - """ - total_num = inputs.shape[0] - batch_size = batch_size or self.batch_size - for i in range(0, total_num, batch_size): - yield inputs[i:i + batch_size] - - def batch_run(self, inputs, run_fn): - """Runs model with mini-batch. - - This function splits the inputs into mini-batches, run the model with each - mini-batch, and then concatenate the outputs from all mini-batches together. 
- - NOTE: The output of `run_fn` can only be `numpy.ndarray` or a dictionary - whose values are all `numpy.ndarray`. - - Args: - inputs: The input samples to run with. - run_fn: A callable function. - - Returns: - Same type as the output of `run_fn`. - - Raises: - ValueError: If the output type of `run_fn` is not supported. - """ - if inputs.shape[0] > self.ram_size: - self.logger.warning(f'Number of inputs on RAM is larger than ' - f'{self.ram_size}. Please use ' - f'`self.get_batch_inputs()` to split the inputs! ' - f'Otherwise, it may encounter OOM problem!') - - results = {} - temp_key = '__temp_key__' - for batch_inputs in self.get_batch_inputs(inputs): - batch_outputs = run_fn(batch_inputs) - if isinstance(batch_outputs, dict): - for key, val in batch_outputs.items(): - if not isinstance(val, np.ndarray): - raise ValueError(f'Each item of the model output should be with ' - f'type `numpy.ndarray`, but type `{type(val)}` is ' - f'received for key `{key}`!') - if key not in results: - results[key] = [val] - else: - results[key].append(val) - elif isinstance(batch_outputs, np.ndarray): - if temp_key not in results: - results[temp_key] = [batch_outputs] - else: - results[temp_key].append(batch_outputs) - else: - raise ValueError(f'The model output can only be with type ' - f'`numpy.ndarray`, or a dictionary of ' - f'`numpy.ndarray`, but type `{type(batch_outputs)}` ' - f'is received!') - - for key, val in results.items(): - results[key] = np.concatenate(val, axis=0) - return results if temp_key not in results else results[temp_key] +# python 3.7 +"""Contains the base class for modules in a GAN model. + +Commonly, GAN consists of two components, i.e., generator and discriminator. +In practice, however, more modules can be added, such as encoder. +""" + +import os.path +import sys +import logging +import numpy as np + +import torch + +from . 
import model_settings + +__all__ = ['BaseModule'] + +DTYPE_NAME_TO_TORCH_TENSOR_TYPE = { + 'float16': torch.HalfTensor, + 'float32': torch.FloatTensor, + 'float64': torch.DoubleTensor, + 'int8': torch.CharTensor, + 'int16': torch.ShortTensor, + 'int32': torch.IntTensor, + 'int64': torch.LongTensor, + 'uint8': torch.ByteTensor, + 'bool': torch.BoolTensor, +} + + +def get_temp_logger(logger_name='logger'): + """Gets a temporary logger. + + This logger will print all levels of messages onto the screen. + + Args: + logger_name: Name of the logger. + + Returns: + A `logging.Logger`. + + Raises: + ValueError: If the input `logger_name` is empty. + """ + if not logger_name: + raise ValueError(f'Input `logger_name` should not be empty!') + + logger = logging.getLogger(logger_name) + if not logger.hasHandlers(): + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s') + sh = logging.StreamHandler(stream=sys.stdout) + sh.setLevel(logging.DEBUG) + sh.setFormatter(formatter) + logger.addHandler(sh) + + return logger + + +class BaseModule(object): + """Base class for modules in GANs, like generator and discriminator. + + NOTE: The module should be defined with pytorch, and used for inference only. + """ + + def __init__(self, model_name, module_name, logger=None): + """Initializes with specific settings. + + The GAN model should be first registered in `model_settings.py` with proper + settings. Among them, some attributes are necessary, including: + + (1) resolution: Resolution of the synthesis. + (2) image_channels: Number of channels of the synthesis. (default: 3) + (3) channel_order: Channel order of the raw synthesis. (default: `RGB`) + (4) min_val: Minimum value of the raw synthesis. (default -1.0) + (5) max_val: Maximum value of the raw synthesis. (default 1.0) + + Args: + model_name: Name with which the GAN model is registered. + module_name: Name of the module, like `generator` or `discriminator`. 
+ logger: Logger for recording log messages. If set as `None`, a default + logger, which prints messages from all levels onto the screen, will be + created. (default: None) + + Raises: + AttributeError: If some necessary attributes are missing. + """ + self.model_name = model_name + self.module_name = module_name + self.logger = logger or get_temp_logger(model_name) + + # Parse settings. + for key, val in model_settings.MODEL_POOL[model_name].items(): + setattr(self, key, val) + self.use_cuda = model_settings.USE_CUDA and torch.cuda.is_available() + self.batch_size = model_settings.MAX_IMAGES_ON_DEVICE + self.ram_size = model_settings.MAX_IMAGES_ON_RAM + self.net = None + self.run_device = 'cuda' if self.use_cuda else 'cpu' + self.cpu_device = 'cpu' + + # Check necessary settings. + self.check_attr('gan_type') # Should be specified in derived classes. + self.check_attr('resolution') + self.image_channels = getattr(self, 'image_channels', 3) + assert self.image_channels in [1, 3] + self.channel_order = getattr(self, 'channel_order', 'RGB').upper() + assert self.channel_order in ['RGB', 'BGR'] + self.min_val = getattr(self, 'min_val', -1.0) + self.max_val = getattr(self, 'max_val', 1.0) + + # Get paths. + self.weight_path = model_settings.get_weight_path( + f'{model_name}_{module_name}') + + # Build graph and load pre-trained weights. + self.logger.info(f'Build network for module `{self.module_name}` in ' + f'model `{self.model_name}`.') + self.model_specific_vars = [] + self.build() + if os.path.isfile(self.weight_path): + self.load() + else: + self.logger.warning(f'No pre-trained weights will be loaded!') + + # Change to inference mode and GPU mode if needed. + assert self.net + self.net.eval().to(self.run_device) + + def check_attr(self, attr_name): + """Checks the existence of a particular attribute. + + Args: + attr_name: Name of the attribute to check. + + Raises: + AttributeError: If the target attribute is missing. 
+ """ + if not hasattr(self, attr_name): + raise AttributeError(f'Field `{attr_name}` is missing for ' + f'module `{self.module_name}` in ' + f'model `{self.model_name}`!') + + def build(self): + """Builds the graph.""" + raise NotImplementedError(f'Should be implemented in derived class!') + + def load(self): + """Loads pre-trained weights.""" + self.logger.info(f'Loading pytorch weights from `{self.weight_path}`.') + state_dict = torch.load(self.weight_path) + for var_name in self.model_specific_vars: + state_dict[var_name] = self.net.state_dict()[var_name] + self.net.load_state_dict(state_dict) + self.logger.info(f'Successfully loaded!') + + def to_tensor(self, array): + """Converts a `numpy.ndarray` to `torch.Tensor` on running device. + + Args: + array: The input array to convert. + + Returns: + A `torch.Tensor` whose dtype is determined by that of the input array. + + Raises: + ValueError: If the array is with neither `torch.Tensor` type nor + `numpy.ndarray` type. + """ + dtype = type(array) + if isinstance(array, torch.Tensor): + tensor = array + elif isinstance(array, np.ndarray): + tensor_type = DTYPE_NAME_TO_TORCH_TENSOR_TYPE[array.dtype.name] + tensor = torch.from_numpy(array).type(tensor_type) + else: + raise ValueError(f'Unsupported input type `{dtype}`!') + tensor = tensor.to(self.run_device) + return tensor + + def get_value(self, tensor): + """Gets value of a `torch.Tensor`. + + Args: + tensor: The input tensor to get value from. + + Returns: + A `numpy.ndarray`. + + Raises: + ValueError: If the tensor is with neither `torch.Tensor` type nor + `numpy.ndarray` type. + """ + dtype = type(tensor) + if isinstance(tensor, np.ndarray): + return tensor + if isinstance(tensor, torch.Tensor): + return tensor.to(self.cpu_device).detach().numpy() + raise ValueError(f'Unsupported input type `{dtype}`!') + + def get_ont_hot_labels(self, num, labels=None): + """Gets ont-hot labels for conditional generation. + + Args: + num: Number of labels to generate. 
+ labels: Input labels as reference to generate one-hot labels. If set as + `None`, label `0` will be used by default. (default: None) + + Returns: + Returns `None` if `self.label_size` is 0, otherwise, a `numpy.ndarray` + with shape [num, self.label_size] and dtype `np.float32`. + """ + self.check_attr('label_size') + if self.label_size == 0: + return None + + if labels is None: + labels = 0 + labels = np.array(labels).reshape(-1) + if labels.size == 1: + labels = np.tile(labels, (num,)) + assert labels.shape == (num,) + for label in labels: + if label >= self.label_size or label < 0: + raise ValueError(f'Label should be smaller than {self.label_size}, ' + f'but {label} is received!') + + one_hot = np.zeros((num, self.label_size), dtype=np.int32) + one_hot[np.arange(num), labels] = 1 + return one_hot + + def get_batch_inputs(self, inputs, batch_size=None): + """Gets inputs within mini-batch. + + This function yields at most `self.batch_size` inputs at a time. + + Args: + inputs: Input data to form mini-batch. + batch_size: Batch size. If not specified, `self.batch_size` will be used. + (default: None) + """ + total_num = inputs.shape[0] + batch_size = batch_size or self.batch_size + for i in range(0, total_num, batch_size): + yield inputs[i:i + batch_size] + + def batch_run(self, inputs, run_fn): + """Runs model with mini-batch. + + This function splits the inputs into mini-batches, run the model with each + mini-batch, and then concatenate the outputs from all mini-batches together. + + NOTE: The output of `run_fn` can only be `numpy.ndarray` or a dictionary + whose values are all `numpy.ndarray`. + + Args: + inputs: The input samples to run with. + run_fn: A callable function. + + Returns: + Same type as the output of `run_fn`. + + Raises: + ValueError: If the output type of `run_fn` is not supported. + """ + if inputs.shape[0] > self.ram_size: + self.logger.warning(f'Number of inputs on RAM is larger than ' + f'{self.ram_size}. 
Please use ' + f'`self.get_batch_inputs()` to split the inputs! ' + f'Otherwise, it may encounter OOM problem!') + + results = {} + temp_key = '__temp_key__' + for batch_inputs in self.get_batch_inputs(inputs): + batch_outputs = run_fn(batch_inputs) + if isinstance(batch_outputs, dict): + for key, val in batch_outputs.items(): + if not isinstance(val, np.ndarray): + raise ValueError(f'Each item of the model output should be with ' + f'type `numpy.ndarray`, but type `{type(val)}` is ' + f'received for key `{key}`!') + if key not in results: + results[key] = [val] + else: + results[key].append(val) + elif isinstance(batch_outputs, np.ndarray): + if temp_key not in results: + results[temp_key] = [batch_outputs] + else: + results[temp_key].append(batch_outputs) + else: + raise ValueError(f'The model output can only be with type ' + f'`numpy.ndarray`, or a dictionary of ' + f'`numpy.ndarray`, but type `{type(batch_outputs)}` ' + f'is received!') + + for key, val in results.items(): + results[key] = np.concatenate(val, axis=0) + return results if temp_key not in results else results[temp_key] diff --git a/models/model_settings.py b/models/model_settings.py index b6bc732..42f554c 100644 --- a/models/model_settings.py +++ b/models/model_settings.py @@ -14,6 +14,12 @@ 'final_tanh': True, 'use_bn': True, }, + 'styleganinv_ffhq256B': { + 'resolution': 256, + 'repeat_w': False, + 'final_tanh': True, + 'use_bn': True, + }, 'styleganinv_bedroom256': { 'resolution': 256, 'repeat_w': False, diff --git a/models/stylegan_generator.py b/models/stylegan_generator.py index 7089e0c..9ba39b3 100644 --- a/models/stylegan_generator.py +++ b/models/stylegan_generator.py @@ -138,6 +138,21 @@ def preprocess(self, latent_codes, latent_space_type='z', **kwargs): return latent_codes.astype(np.float32) + + def synthesizeImages(self,latent_codes,latent_space_type='Z'): + zs = latent_codes + zs = zs.to(self.run_device) + ws = self.net.mapping(zs) + ws = ws.to(self.run_device) + wps = 
self.net.truncation(ws) + wps = wps.to(self.run_device) + + images = self.net.synthesis(wps) + images = images.to(self.run_device) + + return images + + def _synthesize(self, latent_codes, latent_space_type='z', diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/__init__.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = 
create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/utils/inverter.py b/utils/inverter.py index dc029d2..017a9ea 100644 --- a/utils/inverter.py +++ b/utils/inverter.py @@ -1,327 +1,340 @@ -# python 3.7 -"""Utility functions to invert a given image back to a latent code.""" - -from tqdm import tqdm -import cv2 -import numpy as np - -import torch - -from models.stylegan_generator import StyleGANGenerator -from models.stylegan_encoder import StyleGANEncoder -from models.perceptual_model import PerceptualModel - -__all__ = ['StyleGANInverter'] - - -def _softplus(x): - """Implements the softplus function.""" - return torch.nn.functional.softplus(x, beta=1, threshold=10000) - -def _get_tensor_value(tensor): - """Gets the value of a torch Tensor.""" - return tensor.cpu().detach().numpy() - - -class StyleGANInverter(object): - """Defines the class for StyleGAN inversion. - - Even having the encoder, the output latent code is not good enough to recover - the target image satisfyingly. To this end, this class optimize the latent - code based on gradient descent algorithm. In the optimization process, - following loss functions will be considered: - - (1) Pixel-wise reconstruction loss. 
(required) - (2) Perceptual loss. (optional, but recommended) - (3) Regularization loss from encoder. (optional, but recommended for in-domain - inversion) - - NOTE: The encoder can be missing for inversion, in which case the latent code - will be randomly initialized and the regularization loss will be ignored. - """ - - def __init__(self, - model_name, - learning_rate=1e-2, - iteration=100, - reconstruction_loss_weight=1.0, - perceptual_loss_weight=5e-5, - regularization_loss_weight=2.0, - logger=None): - """Initializes the inverter. - - NOTE: Only Adam optimizer is supported in the optimization process. - - Args: - model_name: Name of the model on which the inverted is based. The model - should be first registered in `models/model_settings.py`. - logger: Logger to record the log message. - learning_rate: Learning rate for optimization. (default: 1e-2) - iteration: Number of iterations for optimization. (default: 100) - reconstruction_loss_weight: Weight for reconstruction loss. Should always - be a positive number. (default: 1.0) - perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual - loss. (default: 5e-5) - regularization_loss_weight: Weight for regularization loss from encoder. - This is essential for in-domain inversion. However, this loss will - automatically ignored if the generative model does not include a valid - encoder. 0 disables regularization loss. 
(default: 2.0) - """ - self.logger = logger - self.model_name = model_name - self.gan_type = 'stylegan' - - self.G = StyleGANGenerator(self.model_name, self.logger) - self.E = StyleGANEncoder(self.model_name, self.logger) - self.F = PerceptualModel(min_val=self.G.min_val, max_val=self.G.max_val) - self.encode_dim = [self.G.num_layers, self.G.w_space_dim] - self.run_device = self.G.run_device - assert list(self.encode_dim) == list(self.E.encode_dim) - - assert self.G.gan_type == self.gan_type - assert self.E.gan_type == self.gan_type - - self.learning_rate = learning_rate - self.iteration = iteration - self.loss_pix_weight = reconstruction_loss_weight - self.loss_feat_weight = perceptual_loss_weight - self.loss_reg_weight = regularization_loss_weight - assert self.loss_pix_weight > 0 - - - def preprocess(self, image): - """Preprocesses a single image. - - This function assumes the input numpy array is with shape [height, width, - channel], channel order `RGB`, and pixel range [0, 255]. - - The returned image is with shape [channel, new_height, new_width], where - `new_height` and `new_width` are specified by the given generative model. - The channel order of returned image is also specified by the generative - model. The pixel range is shifted to [min_val, max_val], where `min_val` and - `max_val` are also specified by the generative model. 
- """ - if not isinstance(image, np.ndarray): - raise ValueError(f'Input image should be with type `numpy.ndarray`!') - if image.dtype != np.uint8: - raise ValueError(f'Input image should be with dtype `numpy.uint8`!') - - if image.ndim != 3 or image.shape[2] not in [1, 3]: - raise ValueError(f'Input should be with shape [height, width, channel], ' - f'where channel equals to 1 or 3!\n' - f'But {image.shape} is received!') - if image.shape[2] == 1 and self.G.image_channels == 3: - image = np.tile(image, (1, 1, 3)) - if image.shape[2] != self.G.image_channels: - raise ValueError(f'Number of channels of input image, which is ' - f'{image.shape[2]}, is not supported by the current ' - f'inverter, which requires {self.G.image_channels} ' - f'channels!') - - if self.G.image_channels == 3 and self.G.channel_order == 'BGR': - image = image[:, :, ::-1] - if image.shape[1:3] != [self.G.resolution, self.G.resolution]: - image = cv2.resize(image, (self.G.resolution, self.G.resolution)) - image = image.astype(np.float32) - image = image / 255.0 * (self.G.max_val - self.G.min_val) + self.G.min_val - image = image.astype(np.float32).transpose(2, 0, 1) - - return image - - def get_init_code(self, image): - """Gets initial latent codes as the start point for optimization. - - The input image is assumed to have already been preprocessed, meaning to - have shape [self.G.image_channels, self.G.resolution, self.G.resolution], - channel order `self.G.channel_order`, and pixel range [self.G.min_val, - self.G.max_val]. - """ - x = image[np.newaxis] - x = self.G.to_tensor(x.astype(np.float32)) - z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim)) - return z.astype(np.float32) - - def invert(self, image, num_viz=0): - """Inverts the given image to a latent code. - - Basically, this function is based on gradient descent algorithm. - - Args: - image: Target image to invert, which is assumed to have already been - preprocessed. 
- num_viz: Number of intermediate outputs to visualize. (default: 0) - - Returns: - A two-element tuple. First one is the inverted code. Second one is a list - of intermediate results, where first image is the input image, second - one is the reconstructed result from the initial latent code, remainings - are from the optimization process every `self.iteration // num_viz` - steps. - """ - x = image[np.newaxis] - x = self.G.to_tensor(x.astype(np.float32)) - x.requires_grad = False - init_z = self.get_init_code(image) - z = torch.Tensor(init_z).to(self.run_device) - z.requires_grad = True - - optimizer = torch.optim.Adam([z], lr=self.learning_rate) - - viz_results = [] - viz_results.append(self.G.postprocess(_get_tensor_value(x))[0]) - x_init_inv = self.G.net.synthesis(z) - viz_results.append(self.G.postprocess(_get_tensor_value(x_init_inv))[0]) - pbar = tqdm(range(1, self.iteration + 1), leave=True) - for step in pbar: - loss = 0.0 - - # Reconstruction loss. - x_rec = self.G.net.synthesis(z) - loss_pix = torch.mean((x - x_rec) ** 2) - loss = loss + loss_pix * self.loss_pix_weight - log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}' - - # Perceptual loss. - if self.loss_feat_weight: - x_feat = self.F.net(x) - x_rec_feat = self.F.net(x_rec) - loss_feat = torch.mean((x_feat - x_rec_feat) ** 2) - loss = loss + loss_feat * self.loss_feat_weight - log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}' - - # Regularization loss. - if self.loss_reg_weight: - z_rec = self.E.net(x_rec).view(1, *self.encode_dim) - loss_reg = torch.mean((z - z_rec) ** 2) - loss = loss + loss_reg * self.loss_reg_weight - log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}' - - log_message += f', loss: {_get_tensor_value(loss):.3f}' - pbar.set_description_str(log_message) - if self.logger: - self.logger.debug(f'Step: {step:05d}, ' - f'lr: {self.learning_rate:.2e}, ' - f'{log_message}') - - # Do optimization. 
- optimizer.zero_grad() - loss.backward() - optimizer.step() - - if num_viz > 0 and step % (self.iteration // num_viz) == 0: - viz_results.append(self.G.postprocess(_get_tensor_value(x_rec))[0]) - - return _get_tensor_value(z), viz_results - - def easy_invert(self, image, num_viz=0): - """Wraps functions `preprocess()` and `invert()` together.""" - return self.invert(self.preprocess(image), num_viz) - - def diffuse(self, - target, - context, - center_x, - center_y, - crop_x, - crop_y, - num_viz=0): - """Diffuses the target image to a context image. - - Basically, this function is a motified version of `self.invert()`. More - concretely, the encoder regularizer is removed from the objectives and the - reconstruction loss is computed from the masked region. - - Args: - target: Target image (foreground). - context: Context image (background). - center_x: The x-coordinate of the crop center. - center_y: The y-coordinate of the crop center. - crop_x: The crop size along the x-axis. - crop_y: The crop size along the y-axis. - num_viz: Number of intermediate outputs to visualize. (default: 0) - - Returns: - A two-element tuple. First one is the inverted code. Second one is a list - of intermediate results, where first image is the direct copy-paste - image, second one is the reconstructed result from the initial latent - code, remainings are from the optimization process every - `self.iteration // num_viz` steps. 
- """ - image_shape = (self.G.image_channels, self.G.resolution, self.G.resolution) - mask = np.zeros((1, *image_shape), dtype=np.float32) - xx = center_x - crop_x // 2 - yy = center_y - crop_y // 2 - mask[:, :, yy:yy + crop_y, xx:xx + crop_x] = 1.0 - - target = target[np.newaxis] - if context.ndim == 3: - context = self.preprocess(context)[np.newaxis] - else: - contexts = [] - for i in range(context.shape[0]): - contexts.append(self.preprocess(context[i])) - context = np.asarray(contexts) - x = target * mask + context * (1 - mask) - x = self.G.to_tensor(x.astype(np.float32)) - x.requires_grad = False - mask = self.G.to_tensor(mask.astype(np.float32)) - mask.requires_grad = False - - init_z = _get_tensor_value(self.E.net(x).view(-1, *self.encode_dim)) - init_z = init_z.astype(np.float32) - z = torch.Tensor(init_z).to(self.run_device) - z.requires_grad = True - - optimizer = torch.optim.Adam([z], lr=self.learning_rate) - - copy_and_paste = self.G.postprocess(_get_tensor_value(x)) - x_init_inv = self.G.net.synthesis(z) - encoder_out = self.G.postprocess(_get_tensor_value(x_init_inv)) - viz_results = {} - for it in range(context.shape[0]): - viz_results[it] = [] - viz_results[it].append(copy_and_paste[it]) - viz_results[it].append(encoder_out[it]) - - pbar = tqdm(range(1, self.iteration + 1), leave=True) - for step in pbar: - loss = 0.0 - - # Reconstruction loss. - x_rec = self.G.net.synthesis(z) - loss_pix = torch.mean(((x - x_rec) * mask) ** 2, dim=[1, 2, 3]) - loss = loss + loss_pix * self.loss_pix_weight - log_message = f'loss_pix: {np.mean(_get_tensor_value(loss_pix)):.3f}' - - # Perceptual loss. 
- if self.loss_feat_weight: - x_feat = self.F.net(x * mask) - x_rec_feat = self.F.net(x_rec * mask) - loss_feat = torch.mean((x_feat - x_rec_feat) ** 2, dim=[1, 2, 3]) - loss = loss + loss_feat * self.loss_feat_weight - log_message += f', loss_feat: {np.mean(_get_tensor_value(loss_feat)):.3f}' - - log_message += f', loss: {np.mean(_get_tensor_value(loss)):.3f}' - pbar.set_description_str(log_message) - if self.logger: - self.logger.debug(f'Step: {step:05d}, ' - f'lr: {self.learning_rate:.2e}, ' - f'{log_message}') - - # Do optimization. - optimizer.zero_grad() - loss.backward(torch.ones_like(loss)) - optimizer.step() - - if num_viz > 0 and step % (self.iteration // num_viz) == 0: - rec_res = self.G.postprocess(_get_tensor_value(x_rec)) - for it in range(rec_res.shape[0]): - viz_results[it].append(rec_res[it]) - - return _get_tensor_value(z), viz_results - - def easy_diffuse(self, target, context, *args, **kwargs): - """Wraps functions `preprocess()` and `diffuse()` together.""" - return self.diffuse(self.preprocess(target), - context, - *args, **kwargs) +# python 3.7 +"""Utility functions to invert a given image back to a latent code.""" + +from tqdm import tqdm +import cv2 +import numpy as np + +import torch + +from models.stylegan_generator import StyleGANGenerator +from models.stylegan_encoder import StyleGANEncoder +from models.perceptual_model import PerceptualModel +import pytorch_ssim + +__all__ = ['StyleGANInverter'] + + +def _softplus(x): + """Implements the softplus function.""" + return torch.nn.functional.softplus(x, beta=1, threshold=10000) + +def _get_tensor_value(tensor): + """Gets the value of a torch Tensor.""" + return tensor.cpu().detach().numpy() + + +class StyleGANInverter(object): + """Defines the class for StyleGAN inversion. + + Even having the encoder, the output latent code is not good enough to recover + the target image satisfyingly. To this end, this class optimize the latent + code based on gradient descent algorithm. 
In the optimization process, + following loss functions will be considered: + + (1) Pixel-wise reconstruction loss. (required) + (2) Perceptual loss. (optional, but recommended) + (3) Regularization loss from encoder. (optional, but recommended for in-domain + inversion) + + NOTE: The encoder can be missing for inversion, in which case the latent code + will be randomly initialized and the regularization loss will be ignored. + """ + + def __init__(self, + model_name, + learning_rate=1e-2, + iteration=100, + reconstruction_loss_weight=1.0, + perceptual_loss_weight=5e-5, + regularization_loss_weight=2.0, + loss_weight_ssim = 1.0, + logger=None): + """Initializes the inverter. + + NOTE: Only Adam optimizer is supported in the optimization process. + + Args: + model_name: Name of the model on which the inverted is based. The model + should be first registered in `models/model_settings.py`. + logger: Logger to record the log message. + learning_rate: Learning rate for optimization. (default: 1e-2) + iteration: Number of iterations for optimization. (default: 100) + reconstruction_loss_weight: Weight for reconstruction loss. Should always + be a positive number. (default: 1.0) + perceptual_loss_weight: Weight for perceptual loss. 0 disables perceptual + loss. (default: 5e-5) + regularization_loss_weight: Weight for regularization loss from encoder. + This is essential for in-domain inversion. However, this loss will + automatically ignored if the generative model does not include a valid + encoder. 0 disables regularization loss. 
(default: 2.0) + """ + self.logger = logger + self.model_name = model_name + self.gan_type = 'stylegan' + + self.G = StyleGANGenerator(self.model_name, self.logger) + self.E = StyleGANEncoder(self.model_name, self.logger) + self.F = PerceptualModel(min_val=self.G.min_val, max_val=self.G.max_val) + self.encode_dim = [self.G.num_layers, self.G.w_space_dim] + self.run_device = self.G.run_device + assert list(self.encode_dim) == list(self.E.encode_dim) + + assert self.G.gan_type == self.gan_type + assert self.E.gan_type == self.gan_type + + self.learning_rate = learning_rate + self.iteration = iteration + self.loss_pix_weight = reconstruction_loss_weight + self.loss_feat_weight = perceptual_loss_weight + self.loss_reg_weight = regularization_loss_weight + self.loss_weight_ssim = loss_weight_ssim + assert self.loss_pix_weight > 0 + + + def preprocess(self, image): + """Preprocesses a single image. + + This function assumes the input numpy array is with shape [height, width, + channel], channel order `RGB`, and pixel range [0, 255]. + + The returned image is with shape [channel, new_height, new_width], where + `new_height` and `new_width` are specified by the given generative model. + The channel order of returned image is also specified by the generative + model. The pixel range is shifted to [min_val, max_val], where `min_val` and + `max_val` are also specified by the generative model. 
+ """ + if not isinstance(image, np.ndarray): + raise ValueError(f'Input image should be with type `numpy.ndarray`!') + if image.dtype != np.uint8: + raise ValueError(f'Input image should be with dtype `numpy.uint8`!') + + if image.ndim != 3 or image.shape[2] not in [1, 3]: + raise ValueError(f'Input should be with shape [height, width, channel], ' + f'where channel equals to 1 or 3!\n' + f'But {image.shape} is received!') + if image.shape[2] == 1 and self.G.image_channels == 3: + image = np.tile(image, (1, 1, 3)) + if image.shape[2] != self.G.image_channels: + raise ValueError(f'Number of channels of input image, which is ' + f'{image.shape[2]}, is not supported by the current ' + f'inverter, which requires {self.G.image_channels} ' + f'channels!') + + if self.G.image_channels == 3 and self.G.channel_order == 'BGR': + image = image[:, :, ::-1] + if image.shape[1:3] != [self.G.resolution, self.G.resolution]: + image = cv2.resize(image, (self.G.resolution, self.G.resolution)) + image = image.astype(np.float32) + image = image / 255.0 * (self.G.max_val - self.G.min_val) + self.G.min_val + image = image.astype(np.float32).transpose(2, 0, 1) + + return image + + def get_init_code(self, image): + """Gets initial latent codes as the start point for optimization. + + The input image is assumed to have already been preprocessed, meaning to + have shape [self.G.image_channels, self.G.resolution, self.G.resolution], + channel order `self.G.channel_order`, and pixel range [self.G.min_val, + self.G.max_val]. + """ + x = image[np.newaxis] + x = self.G.to_tensor(x.astype(np.float32)) + z = _get_tensor_value(self.E.net(x).view(1, *self.encode_dim)) + return z.astype(np.float32) + + def invert(self, image, num_viz=0): + """Inverts the given image to a latent code. + + Basically, this function is based on gradient descent algorithm. + + Args: + image: Target image to invert, which is assumed to have already been + preprocessed. 
+ num_viz: Number of intermediate outputs to visualize. (default: 0) + + Returns: + A two-element tuple. First one is the inverted code. Second one is a list + of intermediate results, where first image is the input image, second + one is the reconstructed result from the initial latent code, remainings + are from the optimization process every `self.iteration // num_viz` + steps. + """ + x = image[np.newaxis] + x = self.G.to_tensor(x.astype(np.float32)) + x.requires_grad = False + init_z = self.get_init_code(image) + z = torch.Tensor(init_z).to(self.run_device) + z.requires_grad = True + + optimizer = torch.optim.Adam([z], lr=self.learning_rate) + + viz_results = [] + viz_results.append(self.G.postprocess(_get_tensor_value(x))[0]) + x_init_inv = self.G.net.synthesis(z) + viz_results.append(self.G.postprocess(_get_tensor_value(x_init_inv))[0]) + pbar = tqdm(range(1, self.iteration + 1), leave=True) + for step in pbar: + loss = 0.0 + + # Reconstruction loss. + x_rec = self.G.net.synthesis(z) + loss_pix = torch.mean((x - x_rec) ** 2) + loss = loss + loss_pix * self.loss_pix_weight + log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}' + + # SSIM loss. + ssim_loss = pytorch_ssim.SSIM() + x_rec = self.G.net.synthesis(z) + ssim_out = -ssim_loss(x, x_rec) + + loss = loss + ssim_out * self.loss_weight_ssim + log_message += f', loss_ssim: {(- ssim_out.item()):.3f}' + + # Perceptual loss. + if self.loss_feat_weight: + x_feat = self.F.net(x) + x_rec_feat = self.F.net(x_rec) + loss_feat = torch.mean((x_feat - x_rec_feat) ** 2) + loss = loss + loss_feat * self.loss_feat_weight + log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}' + + # Regularization loss. 
+ if self.loss_reg_weight: + z_rec = self.E.net(x_rec).view(1, *self.encode_dim) + loss_reg = torch.mean((z - z_rec) ** 2) + loss = loss + loss_reg * self.loss_reg_weight + log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}' + + + + log_message += f', loss: {_get_tensor_value(loss):.3f}' + pbar.set_description_str(log_message) + if self.logger: + self.logger.debug(f'Step: {step:05d}, ' + f'lr: {self.learning_rate:.2e}, ' + f'{log_message}') + + # Do optimization. + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if num_viz > 0 and step % (self.iteration // num_viz) == 0: + viz_results.append(self.G.postprocess(_get_tensor_value(x_rec))[0]) + + return _get_tensor_value(z), viz_results, - ssim_out.item() + + def easy_invert(self, image, num_viz=0): + """Wraps functions `preprocess()` and `invert()` together.""" + return self.invert(self.preprocess(image), num_viz) + + def diffuse(self, + target, + context, + center_x, + center_y, + crop_x, + crop_y, + num_viz=0): + """Diffuses the target image to a context image. + + Basically, this function is a motified version of `self.invert()`. More + concretely, the encoder regularizer is removed from the objectives and the + reconstruction loss is computed from the masked region. + + Args: + target: Target image (foreground). + context: Context image (background). + center_x: The x-coordinate of the crop center. + center_y: The y-coordinate of the crop center. + crop_x: The crop size along the x-axis. + crop_y: The crop size along the y-axis. + num_viz: Number of intermediate outputs to visualize. (default: 0) + + Returns: + A two-element tuple. First one is the inverted code. Second one is a list + of intermediate results, where first image is the direct copy-paste + image, second one is the reconstructed result from the initial latent + code, remainings are from the optimization process every + `self.iteration // num_viz` steps. 
+ """ + image_shape = (self.G.image_channels, self.G.resolution, self.G.resolution) + mask = np.zeros((1, *image_shape), dtype=np.float32) + xx = center_x - crop_x // 2 + yy = center_y - crop_y // 2 + mask[:, :, yy:yy + crop_y, xx:xx + crop_x] = 1.0 + + target = target[np.newaxis] + if context.ndim == 3: + context = self.preprocess(context)[np.newaxis] + else: + contexts = [] + for i in range(context.shape[0]): + contexts.append(self.preprocess(context[i])) + context = np.asarray(contexts) + x = target * mask + context * (1 - mask) + x = self.G.to_tensor(x.astype(np.float32)) + x.requires_grad = False + mask = self.G.to_tensor(mask.astype(np.float32)) + mask.requires_grad = False + + init_z = _get_tensor_value(self.E.net(x).view(-1, *self.encode_dim)) + init_z = init_z.astype(np.float32) + z = torch.Tensor(init_z).to(self.run_device) + z.requires_grad = True + + optimizer = torch.optim.Adam([z], lr=self.learning_rate) + + copy_and_paste = self.G.postprocess(_get_tensor_value(x)) + x_init_inv = self.G.net.synthesis(z) + encoder_out = self.G.postprocess(_get_tensor_value(x_init_inv)) + viz_results = {} + for it in range(context.shape[0]): + viz_results[it] = [] + viz_results[it].append(copy_and_paste[it]) + viz_results[it].append(encoder_out[it]) + + pbar = tqdm(range(1, self.iteration + 1), leave=True) + for step in pbar: + loss = 0.0 + + # Reconstruction loss. + x_rec = self.G.net.synthesis(z) + loss_pix = torch.mean(((x - x_rec) * mask) ** 2, dim=[1, 2, 3]) + loss = loss + loss_pix * self.loss_pix_weight + log_message = f'loss_pix: {np.mean(_get_tensor_value(loss_pix)):.3f}' + + # Perceptual loss. 
+ if self.loss_feat_weight: + x_feat = self.F.net(x * mask) + x_rec_feat = self.F.net(x_rec * mask) + loss_feat = torch.mean((x_feat - x_rec_feat) ** 2, dim=[1, 2, 3]) + loss = loss + loss_feat * self.loss_feat_weight + log_message += f', loss_feat: {np.mean(_get_tensor_value(loss_feat)):.3f}' + + log_message += f', loss: {np.mean(_get_tensor_value(loss)):.3f}' + pbar.set_description_str(log_message) + if self.logger: + self.logger.debug(f'Step: {step:05d}, ' + f'lr: {self.learning_rate:.2e}, ' + f'{log_message}') + + # Do optimization. + optimizer.zero_grad() + loss.backward(torch.ones_like(loss)) + optimizer.step() + + if num_viz > 0 and step % (self.iteration // num_viz) == 0: + rec_res = self.G.postprocess(_get_tensor_value(x_rec)) + for it in range(rec_res.shape[0]): + viz_results[it].append(rec_res[it]) + + return _get_tensor_value(z), viz_results + + def easy_diffuse(self, target, context, *args, **kwargs): + """Wraps functions `preprocess()` and `diffuse()` together.""" + return self.diffuse(self.preprocess(target), + context, + *args, **kwargs)