Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions examples/janus/janus/train/mix_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import bisect

import mindspore as ms
import numpy as np
from mindspore.dataset import WeightedRandomSampler, DistributedSampler

from janus.models import VLChatProcessor
from janus.train.t2i_dataset import TextImageDataset
from janus.train.text_dataset import TextDataset
from janus.train.vqa_dataset import VqaDataset


class MixDataset:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's an extra space between class and the class name. According to PEP 8, there should be a single space.

Suggested change
class MixDataset:
class MixDataset:

"""
Mixed dataset that outputs pure text, multi-modal and text-to-image data.
"""
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
Comment on lines +18 to +24

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This manual implementation of cumsum can be simplified and made more efficient by using numpy.cumsum. Since numpy is already a dependency, this would be more idiomatic and concise. bisect works fine with NumPy arrays.

Suggested change
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def cumsum(sequence):
return np.cumsum([len(e) for e in sequence])


def __init__(self, datasets, default_image_shape=(1, 3, 384, 384), max_token_length=1024):
self.default_image_shape = default_image_shape
self.max_token_length = max_token_length
self.datasets = datasets
self.num_dataset = len(datasets)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The instance variable self.num_dataset is assigned but never used. It should be removed to avoid confusion and keep the code clean.

self.cumulative_sizes = self.cumsum(self.datasets)

def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx ==0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

ret = self.datasets[dataset_idx][sample_idx]

# add image and image_seq_mask item to pure text for batching
if dataset_idx == 0:
image = np.zeros(self.default_image_shape, np.float32)
image_seq_mask = np.zeros((self.max_token_length), dtype=np.bool)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The np.bool alias is deprecated since NumPy 1.20 and will be removed in a future version. You should use the standard Python bool type instead for forward compatibility.

Suggested change
image_seq_mask = np.zeros((self.max_token_length), dtype=np.bool)
image_seq_mask = np.zeros((self.max_token_length), dtype=bool)

ret += (image_seq_mask, image)

return ret

def __len__(self):
return self.cumulative_sizes[-1]


def create_mix_dataloader(
vl_chat_processor,
t2i_csv_path="./datasets/jade/csvfile/image_text_en.csv",
t2i_data_dir="./datasets",
text_dataset_name="pubmedqa",
text_data_dir="./datasets/PubMedQA",
vqa_dataset_name="medical-vqa",
vqa_data_dir="./datasets/medical-vqa",
max_token_length=1024,
image_size=384,
null_prompt_prob=0.0,
batch_size=1,
num_parallel_workers=1,
rank=0,
rank_size=1,
num_samples=100,
sample_ratios=(1, 5, 4)):

dataset_text = TextDataset(
dataset_name=text_dataset_name,
data_dir=text_data_dir,
vl_chat_processor=vl_chat_processor,
max_token_length=max_token_length,
num_samples=num_samples,
)

dataset_t2i = TextImageDataset(
csv_path=t2i_csv_path,
data_dir=t2i_data_dir,
vl_chat_processor=vl_chat_processor,
max_token_length=max_token_length,
image_size=image_size,
null_prompt_prob=null_prompt_prob,
num_samples=num_samples,
)

dataset_vqa = VqaDataset(
dataset_name=vqa_dataset_name,
data_dir=vqa_data_dir,
vl_chat_processor=vl_chat_processor,
max_token_length=max_token_length,
num_samples=num_samples,
)

datasets = [dataset_text, dataset_vqa, dataset_t2i] # keep the right order
mix_dataset = MixDataset(datasets=datasets,
default_image_shape=(1, 3, image_size, image_size),
max_token_length=max_token_length)

sample_weights = []
assert len(sample_ratios) == len(datasets)
for i in range(len(sample_ratios)):
weight = sample_ratios[i] * len(mix_dataset) / len(datasets[i])
sample_weights += [weight] * len(datasets[i])
Comment on lines +103 to +107

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loop for creating sample_weights can be written more concisely and efficiently using NumPy operations. Using np.repeat is more idiomatic for this kind of task and can be faster for large datasets.

Suggested change
sample_weights = []
assert len(sample_ratios) == len(datasets)
for i in range(len(sample_ratios)):
weight = sample_ratios[i] * len(mix_dataset) / len(datasets[i])
sample_weights += [weight] * len(datasets[i])
assert len(sample_ratios) == len(datasets)
weights_per_dataset = np.array([
r * len(mix_dataset) / len(d) for r, d in zip(sample_ratios, datasets)
])
counts = [len(d) for d in datasets]
sample_weights = np.repeat(weights_per_dataset, counts).tolist()


sampler = DistributedSampler(num_shards=rank_size, shard_id=rank, shuffle=False)
sampler.add_child(WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True))

dataloader = ms.dataset.GeneratorDataset(
source=mix_dataset,
sampler=sampler,
column_names=["task_type", "input_ids", "labels", "attention_mask", "image_seq_mask", "image"],
num_parallel_workers=num_parallel_workers,
python_multiprocessing=True,
)

dataloader = dataloader.batch(batch_size, drop_remainder=True)

return dataloader

if __name__ == "__main__":
pretrain_model_path = "./ckpts/Janus-Pro-1B"
vl_chat_processor = VLChatProcessor.from_pretrained(pretrain_model_path)
dataloader = create_mix_dataloader(vl_chat_processor, batch_size=2)
for data in dataloader.create_dict_iterator():
print(data)
break