Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions dataset/dataset_DINet_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import random
import cv2
import os

from torch.utils.data import Dataset

Expand Down Expand Up @@ -42,7 +43,13 @@ def __getitem__(self, index):
reference_clip_list = []
for source_frame_index in range(2, 2 + 5):
## load source clip
source_image_data = cv2.imread(source_image_path_list[source_frame_index])[:, :, ::-1]
source_image_path = os.path.join(*source_image_path_list[source_frame_index].replace('\\', '/').split('/')) # fix path error
if not os.path.exists(source_image_path):
raise FileNotFoundError(f'{source_image_path} does not exist')
source_image_data = cv2.imread(source_image_path)
if source_image_data is None:
raise IOError(f'Failed to open{source_image_path}')
source_image_data =source_image_data[:, :, ::-1]
source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h)) / 255.0
source_clip_list.append(source_image_data)
source_image_mask = source_image_data.copy()
Expand All @@ -62,8 +69,13 @@ def __getitem__(self, index):
reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor][
'frame_path_list']
reference_random_index = random.sample(range(9), 1)[0]
reference_frame_path = reference_frame_path_list[reference_random_index]
reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1]
reference_frame_path = os.path.join(*reference_frame_path_list[reference_random_index].replace('\\', '/').split('/')) # fix path error
if not os.path.exists(reference_frame_path):
raise FileNotFoundError(f'{reference_frame_path} does not exist')
reference_frame_data = cv2.imread(reference_frame_path)
if reference_frame_data is None:
raise IOError(f'Failed to open{reference_frame_path}')
reference_frame_data = reference_frame_data[:, :, ::-1]
reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h)) / 255.0
reference_frame_list.append(reference_frame_data)
reference_clip_list.append(np.concatenate(reference_frame_list, 2))
Expand Down
22 changes: 18 additions & 4 deletions dataset/dataset_DINet_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import random
import cv2
import os

from torch.utils.data import Dataset

Expand All @@ -17,7 +18,7 @@ def get_data(json_name,augment_num):
data_dic_name_list.append(video_name)
random.shuffle(data_dic_name_list)
print('finish loading')
return data_dic_name_list,data_dic
return data_dic_name_list, data_dic


class DINetDataset(Dataset):
Expand All @@ -39,7 +40,14 @@ def __getitem__(self, index):
## load source image
source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
source_random_index = random.sample(range(2, 7), 1)[0]
source_image_data = cv2.imread(source_image_path_list[source_random_index])[:, :, ::-1]
## modify the path
source_image_path = os.path.join(*source_image_path_list[source_random_index].replace('\\','/').split('/')) # fix the path problem
if not os.path.exists(source_image_path):
raise FileNotFoundError(f"{source_image_path} does not exist")
source_image_data = cv2.imread(source_image_path)
if source_image_data is None:
raise IOError(f"Failed to open {source_image_path}")
source_image_data = source_image_data[:, :, ::-1]
source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h))/ 255.0
source_image_mask = source_image_data.copy()
source_image_mask[self.radius:self.radius+self.mouth_region_size,self.radius_1_4:self.radius_1_4 +self.mouth_region_size ,:] = 0
Expand All @@ -52,8 +60,14 @@ def __getitem__(self, index):
for reference_anchor in reference_anchor_list:
reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor]['frame_path_list']
reference_random_index = random.sample(range(9), 1)[0]
reference_frame_path = reference_frame_path_list[reference_random_index]
reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1]
## modify the path
reference_frame_path = os.path.join(*reference_frame_path_list[reference_random_index].replace('\\','/').split('/')) # fix the path problem
if not os.path.exists(reference_frame_path):
raise FileNotFoundError(f"{reference_frame_path} does not exsit")
reference_frame_data = cv2.imread(reference_frame_path)
if reference_frame_data is None:
raise IOError(f"Failed to open {reference_frame_path}")
reference_frame_data = reference_frame_data[:, :, ::-1]
reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h))/ 255.0
reference_frame_data_list.append(reference_frame_data)
reference_clip_data = np.concatenate(reference_frame_data_list, 2)
Expand Down