diff --git a/dataset/dataset_DINet_clip.py b/dataset/dataset_DINet_clip.py index 7d5e8c6..e869db1 100644 --- a/dataset/dataset_DINet_clip.py +++ b/dataset/dataset_DINet_clip.py @@ -3,6 +3,7 @@ import json import random import cv2 +import os from torch.utils.data import Dataset @@ -42,7 +43,13 @@ def __getitem__(self, index): reference_clip_list = [] for source_frame_index in range(2, 2 + 5): ## load source clip - source_image_data = cv2.imread(source_image_path_list[source_frame_index])[:, :, ::-1] + source_image_path = os.path.join(*source_image_path_list[source_frame_index].replace('\\', '/').split('/')) # fix path error + if not os.path.exists(source_image_path): + raise FileNotFoundError(f'{source_image_path} does not exist') + source_image_data = cv2.imread(source_image_path) + if source_image_data is None: + raise IOError(f'Failed to open{source_image_path}') + source_image_data =source_image_data[:, :, ::-1] source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h)) / 255.0 source_clip_list.append(source_image_data) source_image_mask = source_image_data.copy() @@ -62,8 +69,13 @@ def __getitem__(self, index): reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor][ 'frame_path_list'] reference_random_index = random.sample(range(9), 1)[0] - reference_frame_path = reference_frame_path_list[reference_random_index] - reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] + reference_frame_path = os.path.join(*reference_frame_path_list[reference_random_index].replace('\\', '/').split('/')) # fix path error + if not os.path.exists(reference_frame_path): + raise FileNotFoundError(f'{reference_frame_path} does not exist') + reference_frame_data = cv2.imread(reference_frame_path) + if reference_frame_data is None: + raise IOError(f'Failed to open{reference_frame_path}') + reference_frame_data = reference_frame_data[:, :, ::-1] reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h)) / 255.0 reference_frame_list.append(reference_frame_data) reference_clip_list.append(np.concatenate(reference_frame_list, 2)) diff --git a/dataset/dataset_DINet_frame.py b/dataset/dataset_DINet_frame.py index 4a74e01..cabbefa 100644 --- a/dataset/dataset_DINet_frame.py +++ b/dataset/dataset_DINet_frame.py @@ -3,6 +3,7 @@ import json import random import cv2 +import os from torch.utils.data import Dataset @@ -17,7 +18,7 @@ def get_data(json_name,augment_num): data_dic_name_list.append(video_name) random.shuffle(data_dic_name_list) print('finish loading') - return data_dic_name_list,data_dic + return data_dic_name_list, data_dic class DINetDataset(Dataset): @@ -39,7 +40,14 @@ def __getitem__(self, index): ## load source image source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] source_random_index = random.sample(range(2, 7), 1)[0] - source_image_data = cv2.imread(source_image_path_list[source_random_index])[:, :, ::-1] + ## modify the path + source_image_path = os.path.join(*source_image_path_list[source_random_index].replace('\\','/').split('/')) # fix the path problem + if not os.path.exists(source_image_path): + raise FileNotFoundError(f"{source_image_path} does not exist") + source_image_data = cv2.imread(source_image_path) + if source_image_data is None: + raise IOError(f"Failed to open {source_image_path}") + source_image_data = source_image_data[:, :, ::-1] source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h))/ 255.0 source_image_mask = source_image_data.copy() source_image_mask[self.radius:self.radius+self.mouth_region_size,self.radius_1_4:self.radius_1_4 +self.mouth_region_size ,:] = 0 @@ -52,8 +60,14 @@ def __getitem__(self, index): for reference_anchor in reference_anchor_list: reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor]['frame_path_list'] reference_random_index = random.sample(range(9), 1)[0] - reference_frame_path = reference_frame_path_list[reference_random_index] - reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] + ## modify the path + reference_frame_path = os.path.join(*reference_frame_path_list[reference_random_index].replace('\\','/').split('/')) # fix the path problem + if not os.path.exists(reference_frame_path): + raise FileNotFoundError(f"{reference_frame_path} does not exsit") + reference_frame_data = cv2.imread(reference_frame_path) + if reference_frame_data is None: + raise IOError(f"Failed to open {reference_frame_path}") + reference_frame_data = reference_frame_data[:, :, ::-1] reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h))/ 255.0 reference_frame_data_list.append(reference_frame_data) reference_clip_data = np.concatenate(reference_frame_data_list, 2)