-
Notifications
You must be signed in to change notification settings - Fork 87
Janus pro mix dataset #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: janus
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,130 @@ | ||||||||||||||||||||||||
| import bisect | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import mindspore as ms | ||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||
| from mindspore.dataset import WeightedRandomSampler, DistributedSampler | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| from janus.models import VLChatProcessor | ||||||||||||||||||||||||
| from janus.train.t2i_dataset import TextImageDataset | ||||||||||||||||||||||||
| from janus.train.text_dataset import TextDataset | ||||||||||||||||||||||||
| from janus.train.vqa_dataset import VqaDataset | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class MixDataset: | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Mixed dataset that outputs pure text, multi-modal and text-to-image data. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||
| def cumsum(sequence): | ||||||||||||||||||||||||
| r, s = [], 0 | ||||||||||||||||||||||||
| for e in sequence: | ||||||||||||||||||||||||
| l = len(e) | ||||||||||||||||||||||||
| r.append(l + s) | ||||||||||||||||||||||||
| s += l | ||||||||||||||||||||||||
| return r | ||||||||||||||||||||||||
|
Comment on lines
+18
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This manual implementation of
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __init__(self, datasets, default_image_shape=(1, 3, 384, 384), max_token_length=1024): | ||||||||||||||||||||||||
| self.default_image_shape = default_image_shape | ||||||||||||||||||||||||
| self.max_token_length = max_token_length | ||||||||||||||||||||||||
| self.datasets = datasets | ||||||||||||||||||||||||
| self.num_dataset = len(datasets) | ||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||
| self.cumulative_sizes = self.cumsum(self.datasets) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __getitem__(self, idx): | ||||||||||||||||||||||||
| dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) | ||||||||||||||||||||||||
| if dataset_idx ==0: | ||||||||||||||||||||||||
| sample_idx = idx | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| ret = self.datasets[dataset_idx][sample_idx] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # add image and image_seq_mask item to pure text for batching | ||||||||||||||||||||||||
| if dataset_idx == 0: | ||||||||||||||||||||||||
| image = np.zeros(self.default_image_shape, np.float32) | ||||||||||||||||||||||||
| image_seq_mask = np.zeros((self.max_token_length), dtype=np.bool) | ||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||
| ret += (image_seq_mask, image) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return ret | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __len__(self): | ||||||||||||||||||||||||
| return self.cumulative_sizes[-1] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def create_mix_dataloader( | ||||||||||||||||||||||||
| vl_chat_processor, | ||||||||||||||||||||||||
| t2i_csv_path="./datasets/jade/csvfile/image_text_en.csv", | ||||||||||||||||||||||||
| t2i_data_dir="./datasets", | ||||||||||||||||||||||||
| text_dataset_name="pubmedqa", | ||||||||||||||||||||||||
| text_data_dir="./datasets/PubMedQA", | ||||||||||||||||||||||||
| vqa_dataset_name="medical-vqa", | ||||||||||||||||||||||||
| vqa_data_dir="./datasets/medical-vqa", | ||||||||||||||||||||||||
| max_token_length=1024, | ||||||||||||||||||||||||
| image_size=384, | ||||||||||||||||||||||||
| null_prompt_prob=0.0, | ||||||||||||||||||||||||
| batch_size=1, | ||||||||||||||||||||||||
| num_parallel_workers=1, | ||||||||||||||||||||||||
| rank=0, | ||||||||||||||||||||||||
| rank_size=1, | ||||||||||||||||||||||||
| num_samples=100, | ||||||||||||||||||||||||
| sample_ratios=(1, 5, 4)): | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| dataset_text = TextDataset( | ||||||||||||||||||||||||
| dataset_name=text_dataset_name, | ||||||||||||||||||||||||
| data_dir=text_data_dir, | ||||||||||||||||||||||||
| vl_chat_processor=vl_chat_processor, | ||||||||||||||||||||||||
| max_token_length=max_token_length, | ||||||||||||||||||||||||
| num_samples=num_samples, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| dataset_t2i = TextImageDataset( | ||||||||||||||||||||||||
| csv_path=t2i_csv_path, | ||||||||||||||||||||||||
| data_dir=t2i_data_dir, | ||||||||||||||||||||||||
| vl_chat_processor=vl_chat_processor, | ||||||||||||||||||||||||
| max_token_length=max_token_length, | ||||||||||||||||||||||||
| image_size=image_size, | ||||||||||||||||||||||||
| null_prompt_prob=null_prompt_prob, | ||||||||||||||||||||||||
| num_samples=num_samples, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| dataset_vqa = VqaDataset( | ||||||||||||||||||||||||
| dataset_name=vqa_dataset_name, | ||||||||||||||||||||||||
| data_dir=vqa_data_dir, | ||||||||||||||||||||||||
| vl_chat_processor=vl_chat_processor, | ||||||||||||||||||||||||
| max_token_length=max_token_length, | ||||||||||||||||||||||||
| num_samples=num_samples, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| datasets = [dataset_text, dataset_vqa, dataset_t2i] # keep the right order | ||||||||||||||||||||||||
| mix_dataset = MixDataset(datasets=datasets, | ||||||||||||||||||||||||
| default_image_shape=(1, 3, image_size, image_size), | ||||||||||||||||||||||||
| max_token_length=max_token_length) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| sample_weights = [] | ||||||||||||||||||||||||
| assert len(sample_ratios) == len(datasets) | ||||||||||||||||||||||||
| for i in range(len(sample_ratios)): | ||||||||||||||||||||||||
| weight = sample_ratios[i] * len(mix_dataset) / len(datasets[i]) | ||||||||||||||||||||||||
| sample_weights += [weight] * len(datasets[i]) | ||||||||||||||||||||||||
|
Comment on lines
+103
to
+107
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loop for creating
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| sampler = DistributedSampler(num_shards=rank_size, shard_id=rank, shuffle=False) | ||||||||||||||||||||||||
| sampler.add_child(WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| dataloader = ms.dataset.GeneratorDataset( | ||||||||||||||||||||||||
| source=mix_dataset, | ||||||||||||||||||||||||
| sampler=sampler, | ||||||||||||||||||||||||
| column_names=["task_type", "input_ids", "labels", "attention_mask", "image_seq_mask", "image"], | ||||||||||||||||||||||||
| num_parallel_workers=num_parallel_workers, | ||||||||||||||||||||||||
| python_multiprocessing=True, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| dataloader = dataloader.batch(batch_size, drop_remainder=True) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return dataloader | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||
| pretrain_model_path = "./ckpts/Janus-Pro-1B" | ||||||||||||||||||||||||
| vl_chat_processor = VLChatProcessor.from_pretrained(pretrain_model_path) | ||||||||||||||||||||||||
| dataloader = create_mix_dataloader(vl_chat_processor, batch_size=2) | ||||||||||||||||||||||||
| for data in dataloader.create_dict_iterator(): | ||||||||||||||||||||||||
| print(data) | ||||||||||||||||||||||||
| break | ||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an extra space between
classand the class name. According to PEP 8, there should be a single space.