-
Notifications
You must be signed in to change notification settings - Fork 15
support for encoder from gemma-3-4b-it #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f522551
d45c7f1
551ccfb
e32839d
6fd569e
f4f7477
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| import torch.nn as nn | ||
|
|
||
| from torch.utils.data import Sampler | ||
| from typing import List, Optional, Dict, Any, Union | ||
|
|
||
| from transformers import Trainer | ||
| from transformers.trainer import ( | ||
|
|
@@ -13,7 +14,6 @@ | |
| ) | ||
| import torch.nn as nn | ||
| ALL_LAYERNORM_LAYERS = [nn.LayerNorm] | ||
| from typing import List, Optional | ||
|
|
||
|
|
||
| def maybe_zero_3(param, ignore_status=False, name=None): | ||
|
|
@@ -133,6 +133,90 @@ def __iter__(self): | |
|
|
||
| class LLaVATrainer(Trainer): | ||
|
|
||
|
|
||
| def compute_loss(self, model, inputs, return_outputs=False, **kwargs): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we also make sure this |
||
| """ | ||
| Computes the loss. | ||
| This version uses a special manual loss calculation for gemma vision towers | ||
| and the standard internal loss for all other models, preserving original | ||
| Trainer functionality like label smoothing for the standard path. | ||
| """ | ||
| is_gemma_vision_tower = 'gemma' in getattr(self.model.config, 'mm_vision_tower', '') | ||
|
|
||
| if is_gemma_vision_tower: | ||
| # --- PATH 1: MANUAL LOSS CALCULATION FOR GEMMA (OUR FIX) --- | ||
| labels = inputs.pop("labels", None) | ||
| outputs = model(**inputs) | ||
| logits = outputs.logits | ||
|
|
||
| loss = None | ||
| if logits is not None and labels is not None: | ||
| logits_seq_len = logits.shape[1] | ||
| padded_labels = torch.full((labels.shape[0], logits_seq_len), -100, dtype=labels.dtype, device=labels.device) | ||
| padded_labels[:, :labels.shape[1]] = labels | ||
|
|
||
| shift_logits = logits[..., :-1, :].contiguous() | ||
| shift_labels = padded_labels[..., 1:].contiguous() | ||
|
|
||
| loss_fct = nn.CrossEntropyLoss() | ||
| shift_logits = shift_logits.view(-1, self.model.config.vocab_size) | ||
| shift_labels = shift_labels.view(-1) | ||
| shift_labels = shift_labels.to(shift_logits.device) | ||
| loss = loss_fct(shift_logits, shift_labels) | ||
|
|
||
| else: | ||
| # --- PATH 2: ORIGINAL TRANSFORMERS TRAINER LOGIC (FOR CLIP, ETC.) --- | ||
| num_items_in_batch = kwargs.get("num_items_in_batch", None) | ||
| if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: | ||
| labels = inputs.pop("labels") | ||
| else: | ||
| labels = None | ||
|
|
||
| if hasattr(self, "model_accepts_loss_kwargs") and self.model_accepts_loss_kwargs: | ||
| loss_kwargs = {} | ||
| if num_items_in_batch is not None: | ||
| loss_kwargs["num_items_in_batch"] = num_items_in_batch | ||
| inputs = {**inputs, **loss_kwargs} | ||
|
|
||
| outputs = model(**inputs) | ||
|
|
||
| if self.args.past_index >= 0: | ||
| self._past = outputs[self.args.past_index] | ||
|
|
||
| if labels is not None: | ||
| if hasattr(self, "accelerator"): | ||
| unwrapped_model = self.accelerator.unwrap_model(model) | ||
| else: | ||
| unwrapped_model = model | ||
|
|
||
| if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model.base_model, "model") and hasattr(unwrapped_model.base_model.model, "_get_name"): | ||
| model_name = unwrapped_model.base_model.model._get_name() | ||
| elif hasattr(unwrapped_model, "_get_name"): | ||
| model_name = unwrapped_model._get_name() | ||
| else: | ||
| model_name = unwrapped_model.__class__.__name__ | ||
|
|
||
| from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES | ||
|
|
||
| if self.compute_loss_func is not None: | ||
| loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) | ||
| elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): | ||
| loss = self.label_smoother(outputs, labels, shift_labels=True) | ||
| else: | ||
| loss = self.label_smoother(outputs, labels) | ||
| else: | ||
| if isinstance(outputs, dict) and "loss" not in outputs: | ||
| raise ValueError( | ||
| "The model did not return a loss from the inputs, only the following keys: " | ||
| f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." | ||
| ) | ||
| loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] | ||
|
|
||
| if hasattr(self, "model_accepts_loss_kwargs") and self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: | ||
| loss *= self.accelerator.num_processes | ||
|
|
||
| return (loss, outputs) if return_outputs else loss | ||
|
|
||
| def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]: | ||
| if train_dataset is None: | ||
| train_dataset = self.train_dataset | ||
|
|
@@ -151,12 +235,6 @@ def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sa | |
| return super()._get_train_sampler(train_dataset) | ||
|
|
||
| def create_optimizer(self): | ||
| """ | ||
| Setup the optimizer. | ||
|
|
||
| We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the | ||
| Trainer's init through `optimizers`, or subclass and override this method in a subclass. | ||
| """ | ||
| if is_sagemaker_mp_enabled(): | ||
| return super().create_optimizer() | ||
|
|
||
|
|
@@ -238,7 +316,6 @@ def _save_checkpoint(self, model, trial, metrics=None): | |
| run_dir = self._get_output_dir(trial=trial) | ||
| output_dir = os.path.join(run_dir, checkpoint_folder) | ||
|
|
||
| # Only save Adapter | ||
| keys_to_match = ['mm_projector', 'vision_resampler'] | ||
| if getattr(self.args, "use_im_start_end", False): | ||
| keys_to_match.extend(['embed_tokens', 'embed_in']) | ||
|
|
@@ -249,24 +326,19 @@ def _save_checkpoint(self, model, trial, metrics=None): | |
| self.model.config.save_pretrained(output_dir) | ||
| torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) | ||
| else: | ||
| #super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) | ||
| # fix for newer transformer | ||
| if metrics: | ||
| # log/save metrics manually if needed | ||
| logger.info(f"Metrics while saving checkpoints {metrics}") | ||
|
|
||
|
|
||
| super(LLaVATrainer, self)._save_checkpoint(model, trial) | ||
|
|
||
| def _save(self, output_dir: Optional[str] = None, state_dict=None): | ||
| if getattr(self.args, 'tune_mm_mlp_adapter', False): | ||
| pass | ||
| else: | ||
| # Patch generation_config to avoid HF validation error | ||
| if hasattr(self.model, "generation_config") and self.model.generation_config is not None: | ||
| gen_cfg = self.model.generation_config | ||
| if not gen_cfg.do_sample: | ||
| gen_cfg.temperature = None | ||
| gen_cfg.top_p = None | ||
| self.model.generation_config = gen_cfg | ||
| super(LLaVATrainer, self)._save(output_dir, state_dict) | ||
| super(LLaVATrainer, self)._save(output_dir, state_dict) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in your
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| #!/bin/bash | ||
|
|
||
| export PROMPT_VERSION=plain | ||
|
|
||
| deepspeed llava/train/train_mem.py \ | ||
| --deepspeed ./scripts/zero2.json \ | ||
| --model_name_or_path lmsys/vicuna-7b-v1.5 \ | ||
| --version $PROMPT_VERSION \ | ||
| --data_path /dev/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ | ||
| --image_folder /dev/data/images \ | ||
| --mm_projector_type mlp2x_gelu \ | ||
| --tune_mm_mlp_adapter True \ | ||
| --freeze_backbone True \ | ||
| --mm_vision_select_layer -2 \ | ||
| --vision_tower akshataa/gemma3-4b_it_siglip_encoder \ | ||
| --mm_use_im_start_end False \ | ||
| --mm_use_im_patch_token False \ | ||
| --bf16 True \ | ||
| --output_dir checkpoints/llava-$MODEL_VERSION-pretrain \ | ||
| --num_train_epochs 1 \ | ||
| --per_device_train_batch_size 1 \ | ||
| --per_device_eval_batch_size 1 \ | ||
| --gradient_accumulation_steps 4 \ | ||
| --save_strategy "steps" \ | ||
| --save_steps 24000 \ | ||
| --save_total_limit 1 \ | ||
| --learning_rate 2e-5 \ | ||
| --mm_projector_lr 2e-6 \ | ||
| --max_grad_norm 1.0 \ | ||
| --weight_decay 0. \ | ||
| --warmup_ratio 0.03 \ | ||
| --tf32 True \ | ||
| --lr_scheduler_type "cosine" \ | ||
| --logging_steps 1 \ | ||
| --model_max_length 2048 \ | ||
| --gradient_checkpointing True \ | ||
| --dataloader_num_workers 4 \ | ||
| --lazy_preprocess True \ | ||
| --report_to wandb |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| from llava.model.multimodal_encoder.siglip_encoder import SiglipVisionTower | ||
| from PIL import Image | ||
| import torch | ||
|
|
||
| class Args: | ||
| def __init__(self): | ||
| self.mm_vision_select_layer = -2 # Common choice: second-to-last layer | ||
| self.mm_vision_select_feature = 'patch' # or 'cls_patch' | ||
|
|
||
| args = Args() | ||
|
|
||
| local_ckpt_path = "gemma3_siglip_encoder" # Your extracted encoder path | ||
| tower = SiglipVisionTower( | ||
| vision_tower=local_ckpt_path, | ||
| args=args, | ||
| delay_load=False | ||
| ) | ||
|
|
||
| img = Image.new('RGB', (224, 224), color='red') # simple red square | ||
| pix = tower.image_processor(images=img, return_tensors='pt')['pixel_values'] | ||
|
|
||
| with torch.no_grad(): | ||
| feats = tower(pix) | ||
|
|
||
| print("Output shape:", feats.shape) # expect (1, num_patches, hidden_size) | ||
| print("Dummy feature shape:", tower.dummy_feature.shape) | ||
|
|
||
| assert feats.ndim == 3, f"Expected 3D tensor, got {feats.ndim}D" | ||
| assert tower.dummy_feature.ndim == 2, f"Expected 2D dummy feature, got {tower.dummy_feature.ndim}D" | ||
|
|
||
| # Additional checks | ||
| print(f"Hidden size: {feats.shape[-1]}") | ||
| print(f"Number of patches: {feats.shape[1]}") | ||
| print(f"Vision tower config: {tower.config}") | ||
| print("SigLIP encoder is working ✓") | ||
|
|
||
| try: | ||
| real_img = Image.open('sample.jpg') | ||
| real_pix = tower.image_processor(images=real_img, return_tensors='pt')['pixel_values'] | ||
| real_feats = tower(real_pix) | ||
| print("Real image shape:", real_feats.shape) | ||
| pass | ||
| except Exception as e: | ||
| print(f"Real image test failed: {e}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for reference, this is how
compute_losslooks like in HF trainer class https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618