Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion llava/model/llava_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .multimodal_projector.builder import build_vision_projector

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

import torch.nn.functional as F
from llava.mm_utils import get_anyres_image_grid_shape


Expand Down Expand Up @@ -137,9 +137,14 @@ def get_model(self):
def get_vision_tower(self):
return self.get_model().get_vision_tower()


def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)

if 'gemma' in getattr(self.get_model().config, 'mm_vision_tower', ''):
image_features = torch.clamp(image_features, min=-10.0, max=10.0)

return image_features

def prepare_inputs_labels_for_multimodal(
Expand Down
4 changes: 2 additions & 2 deletions llava/model/multimodal_encoder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
from .siglip_encoder import SiglipVisionTower
from .aimv2_encoder import Aimv2VisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
use_s2 = getattr(vision_tower_cfg, "s2", False)

if vision_tower and ("apple/aimv2" in vision_tower or "aim-v2" in vision_tower.lower() or "aimv2" in vision_tower.lower()):
return Aimv2VisionTower(vision_tower, args=vision_tower_cfg, **kwargs)


if "siglip" in vision_tower.lower():
if "siglip" in vision_tower.lower() or "gemma" in vision_tower.lower() :
return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

if os.path.exists(vision_tower) or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
Expand Down
100 changes: 86 additions & 14 deletions llava/train/llava_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn as nn

from torch.utils.data import Sampler
from typing import List, Optional, Dict, Any, Union

from transformers import Trainer
from transformers.trainer import (
Expand All @@ -13,7 +14,6 @@
)
import torch.nn as nn
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
from typing import List, Optional


def maybe_zero_3(param, ignore_status=False, name=None):
Expand Down Expand Up @@ -133,6 +133,90 @@ def __iter__(self):

class LLaVATrainer(Trainer):


def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also make sure this compute_loss is bypassed in case of other encoders. I am still not sure why exactly we need a custom loss here. Is it to handle sequence mismatch problem? If so why is it not a problem in other encoders? Or is there any other reason? So unless those are clear, lets make sure we do this custom compute_loss only for gemma3-siglip

"""
Computes the loss.
This version uses a special manual loss calculation for gemma vision towers
and the standard internal loss for all other models, preserving original
Trainer functionality like label smoothing for the standard path.
"""
is_gemma_vision_tower = 'gemma' in getattr(self.model.config, 'mm_vision_tower', '')

if is_gemma_vision_tower:
# --- PATH 1: MANUAL LOSS CALCULATION FOR GEMMA (OUR FIX) ---
labels = inputs.pop("labels", None)
outputs = model(**inputs)
logits = outputs.logits

loss = None
if logits is not None and labels is not None:
logits_seq_len = logits.shape[1]
padded_labels = torch.full((labels.shape[0], logits_seq_len), -100, dtype=labels.dtype, device=labels.device)
padded_labels[:, :labels.shape[1]] = labels

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = padded_labels[..., 1:].contiguous()

loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

else:
# --- PATH 2: ORIGINAL TRANSFORMERS TRAINER LOGIC (FOR CLIP, ETC.) ---
num_items_in_batch = kwargs.get("num_items_in_batch", None)
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None

if hasattr(self, "model_accepts_loss_kwargs") and self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}

outputs = model(**inputs)

if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]

if labels is not None:
if hasattr(self, "accelerator"):
unwrapped_model = self.accelerator.unwrap_model(model)
else:
unwrapped_model = model

if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model.base_model, "model") and hasattr(unwrapped_model.base_model.model, "_get_name"):
model_name = unwrapped_model.base_model.model._get_name()
elif hasattr(unwrapped_model, "_get_name"):
model_name = unwrapped_model._get_name()
else:
model_name = unwrapped_model.__class__.__name__

from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

if self.compute_loss_func is not None:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

if hasattr(self, "model_accepts_loss_kwargs") and self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes

return (loss, outputs) if return_outputs else loss

def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
if train_dataset is None:
train_dataset = self.train_dataset
Expand All @@ -151,12 +235,6 @@ def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sa
return super()._get_train_sampler(train_dataset)

def create_optimizer(self):
"""
Setup the optimizer.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()

Expand Down Expand Up @@ -238,7 +316,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)

# Only save Adapter
keys_to_match = ['mm_projector', 'vision_resampler']
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(['embed_tokens', 'embed_in'])
Expand All @@ -249,24 +326,19 @@ def _save_checkpoint(self, model, trial, metrics=None):
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
else:
#super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
# fix for newer transformer
if metrics:
# log/save metrics manually if needed
logger.info(f"Metrics while saving checkpoints {metrics}")


super(LLaVATrainer, self)._save_checkpoint(model, trial)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, 'tune_mm_mlp_adapter', False):
pass
else:
# Patch generation_config to avoid HF validation error
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
gen_cfg = self.model.generation_config
if not gen_cfg.do_sample:
gen_cfg.temperature = None
gen_cfg.top_p = None
self.model.generation_config = gen_cfg
super(LLaVATrainer, self)._save(output_dir, state_dict)
super(LLaVATrainer, self)._save(output_dir, state_dict)
3 changes: 2 additions & 1 deletion llava/train/train.py
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in your if 'siglip' in self.data_args.image_processor.image_processor_type.lower(): line, you need to update to

if 'siglip' or 'gemma' in self.data_args.image_processor.image_processor_type.lower():

Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,7 @@ def make_inputs_require_grad(module, input, output):
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)


data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True

Expand Down Expand Up @@ -1007,4 +1008,4 @@ def make_inputs_require_grad(module, input, output):


if __name__ == "__main__":
train()
train()
39 changes: 39 additions & 0 deletions scripts/v1_5/pretrain_llava_gemma3_it.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

export PROMPT_VERSION=plain

deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path lmsys/vicuna-7b-v1.5 \
--version $PROMPT_VERSION \
--data_path /dev/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder /dev/data/images \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--freeze_backbone True \
--mm_vision_select_layer -2 \
--vision_tower akshataa/gemma3-4b_it_siglip_encoder \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir checkpoints/llava-$MODEL_VERSION-pretrain \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--save_strategy "steps" \
--save_steps 24000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--mm_projector_lr 2e-6 \
--max_grad_norm 1.0 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--tf32 True \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
2 changes: 1 addition & 1 deletion test_aimv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ class Args: pass
features = tower(pixel)
print("Output shape :", features.shape) # expect (1, 196, 768)
print("Dummy shape :", tower.dummy_feature.shape)
print("AIM v2 encoder is working")
print("Gemma encoder is working")
44 changes: 44 additions & 0 deletions test_gemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from llava.model.multimodal_encoder.siglip_encoder import SiglipVisionTower
from PIL import Image
import torch

class Args:
def __init__(self):
self.mm_vision_select_layer = -2 # Common choice: second-to-last layer
self.mm_vision_select_feature = 'patch' # or 'cls_patch'

args = Args()

local_ckpt_path = "gemma3_siglip_encoder" # Your extracted encoder path
tower = SiglipVisionTower(
vision_tower=local_ckpt_path,
args=args,
delay_load=False
)

img = Image.new('RGB', (224, 224), color='red') # simple red square
pix = tower.image_processor(images=img, return_tensors='pt')['pixel_values']

with torch.no_grad():
feats = tower(pix)

print("Output shape:", feats.shape) # expect (1, num_patches, hidden_size)
print("Dummy feature shape:", tower.dummy_feature.shape)

assert feats.ndim == 3, f"Expected 3D tensor, got {feats.ndim}D"
assert tower.dummy_feature.ndim == 2, f"Expected 2D dummy feature, got {tower.dummy_feature.ndim}D"

# Additional checks
print(f"Hidden size: {feats.shape[-1]}")
print(f"Number of patches: {feats.shape[1]}")
print(f"Vision tower config: {tower.config}")
print("SigLIP encoder is working ✓")

try:
real_img = Image.open('sample.jpg')
real_pix = tower.image_processor(images=real_img, return_tensors='pt')['pixel_values']
real_feats = tower(real_pix)
print("Real image shape:", real_feats.shape)
pass
except Exception as e:
print(f"Real image test failed: {e}")
Loading