Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graph_net/agent/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

### 基础依赖
```bash
pip install torch transformers accelerate
pip install huggingface_hub>=0.20.0
pip install "torch>=2.9.0" accelerate
pip install "transformers>=5.8.1" "huggingface_hub>=1.15.0"
```

## 环境配置
Expand Down
22 changes: 22 additions & 0 deletions graph_net/agent/code_generator/llm_code_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@

**多模态类**(clip/blip/flava/align 等):
同时传文本分支(input_ids + attention_mask)和视觉分支(pixel_values)
- Qwen-VL(qwen2_vl/qwen2_5_vl/qwen3_vl):需要 image_grid_thw,且 pixel_values 是展平 patch,形如 (4, channels * temporal_patch_size * patch_size * patch_size),不是 BCHW。
- LLaVA/Gemma3:若报图像特征和图像 token 数不匹配,必须把 input_ids 前缀覆盖为 image_token_index,覆盖数量分别按视觉 patch 数或 mm_tokens_per_image。
- InternVL:通常需要 image_flags,并把 input_ids 前缀覆盖为 img_context_token_id(或 image_token_index),数量为 num_image_token。

**音频类**(wav2vec2/hubert/whisper/clap/unispeech 等):
- wav2vec2/hubert/unispeech:input_values = torch.randn(1, 16000).to(device)
Expand Down Expand Up @@ -94,6 +97,10 @@
| 报错关键词 | 修复方法 |
|---|---|
| "You have to specify pixel_values" | 补充 pixel_values 输入 |
| "grid_thw" / "image_grid_thw" | Qwen-VL 需要 image_grid_thw=torch.tensor([[1,2,2]]),pixel_values 用展平 patch shape |
| "Image features and image tokens do not match" | LLaVA/Gemma3 需要把 input_ids 前缀覆盖为 image_token_index,数量匹配图像 token 数 |
| "tokens: 0, features:" | input_ids 中缺少图像 token,按对应 VLM 家族写入 image_token_index/img_context_token_id 前缀 |
| "image_flags" | InternVL 需要 image_flags=torch.ones((1,1), dtype=torch.long) |
| "index out of range in self" / embedding 越界 | input_ids 上界 > vocab_size,改为 min(vocab_size-1, 30000) |
| "NoneType has no attribute" | 对应输入字段为 None,补充正确 tensor |
| "running_mean should contain X elements" | BatchNorm channel 维度不对,检查 input_features shape 的 channel 轴 |
Expand Down Expand Up @@ -372,7 +379,14 @@ def _extract_key_fields(model_dir: Path) -> str:
"audio_config",
"vision_config",
"text_config",
"image_token_index",
"mm_tokens_per_image",
"vision_feature_select_strategy",
"img_context_token_id",
"num_image_token",
"patch_size",
"temporal_patch_size",
"spatial_merge_size",
"num_mel_bins",
"chunk_length",
# MoE routing (field names vary across models)
Expand Down Expand Up @@ -407,6 +421,14 @@ def _extract_key_fields(model_dir: Path) -> str:
"vocab_size",
"image_size",
"num_channels",
"patch_size",
"temporal_patch_size",
"spatial_merge_size",
"image_token_index",
"mm_tokens_per_image",
"vision_feature_select_strategy",
"img_context_token_id",
"num_image_token",
"num_mel_bins",
"hidden_size",
"num_local_experts",
Expand Down
233 changes: 186 additions & 47 deletions graph_net/agent/code_generator/template_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
"""Template-based code generator implementation"""

import json
from pathlib import Path
from typing import Optional

from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata
from graph_net.agent.code_generator.base import BaseCodeGenerator
from graph_net.agent.model_type_utils import (
VLM_FAMILY_GEMMA3,
VLM_FAMILY_INTERNVL,
VLM_FAMILY_LLAVA,
VLM_FAMILY_QWEN,
get_vlm_family,
select_auto_model_class,
)
from graph_net.agent.utils.exceptions import CodeGenerationError

# Constants for safe vocab size calculation
Expand Down Expand Up @@ -122,70 +131,200 @@ def _generate_model_loader(
model_path = str(model_dir).replace("\\", "/")
arch = model_metadata.architecture_type

if arch == "seq2seq":
return (
f"from transformers import AutoConfig, AutoModelForSeq2SeqLM\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModelForSeq2SeqLM.from_config(_config)"
)
elif arch == "diffusion":
if arch == "diffusion":
return (
f"from diffusers import UNet2DConditionModel\n"
f'_config = UNet2DConditionModel.load_config("{model_path}")\n'
f"model = UNet2DConditionModel.from_config(_config)"
)
else:
# text, moe, vision, multimodal, audio, None → AutoModel
# If model_type is not present in config.json (e.g. prajjwal1/bert-tiny),
# inject the inferred model_type so AutoConfig can resolve the class.
model_type = model_metadata.model_type
if model_type:
return (
f"import json as _json, os as _os, tempfile as _tmp\n"
f"from transformers import AutoConfig, AutoModel\n"
f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n'
f'if "model_type" not in _raw:\n'
f' _raw["model_type"] = "{model_type}"\n'
f" _td = _tmp.mkdtemp()\n"
f' _json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n'
f" _config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n"
f"else:\n"
f' _config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModel.from_config(_config)"
)
else:
return (
f"from transformers import AutoConfig, AutoModel\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = AutoModel.from_config(_config)"
)

model_class = select_auto_model_class(
model_metadata.model_type, model_metadata.architecture_type
)
model_type = model_metadata.model_type
if model_type and self._config_missing_model_type(model_dir):
return (
f"import json as _json, os as _os, tempfile as _tmp\n"
f"from transformers import AutoConfig, {model_class}\n"
f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n'
f'_raw["model_type"] = "{model_type}"\n'
f"_td = _tmp.mkdtemp()\n"
f'_json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n'
f"_config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n"
f"model = {model_class}.from_config(_config)"
)

return (
f"from transformers import AutoConfig, {model_class}\n"
f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n'
f"model = {model_class}.from_config(_config)"
)

@staticmethod
def _config_missing_model_type(model_dir: Path) -> bool:
config_path = model_dir / "config.json"
if not config_path.exists():
return False
try:
return "model_type" not in json.loads(
config_path.read_text(encoding="utf-8")
)
except (OSError, json.JSONDecodeError):
return False

def _generate_input_code(self, model_metadata: ModelMetadata) -> str:
    """Generate input tensor construction code based on model metadata.

    The VLM family is looked up from ``model_metadata.model_type`` via
    ``get_vlm_family``. A family with a dedicated generator is routed to
    it; any other value is handled by the generic generator.
    """
    family_builders = {
        VLM_FAMILY_QWEN: self._generate_qwen_vlm_input_code,
        VLM_FAMILY_LLAVA: self._generate_llava_input_code,
        VLM_FAMILY_GEMMA3: self._generate_gemma3_input_code,
        VLM_FAMILY_INTERNVL: self._generate_internvl_input_code,
    }
    vlm_family = get_vlm_family(model_metadata.model_type)
    try:
        build_inputs = family_builders[vlm_family]
    except KeyError:
        # No family-specific generator registered for this value.
        build_inputs = self._generate_generic_input_code
    return build_inputs(model_metadata)

def _generate_generic_input_code(self, model_metadata: ModelMetadata) -> str:
"""Generate generic input tensor construction code."""
lines = ["inputs = {"]

for name, shape in model_metadata.input_shapes.items():
dtype = model_metadata.input_dtypes.get(name, "int64")
torch_dtype = self._get_torch_dtype(dtype)
shape_tuple = f"({', '.join(map(str, shape))})"

if dtype == "int64":
if "input_ids" in name.lower() or "decoder_input_ids" in name.lower():
safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
value = (
f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, "
f"dtype={torch_dtype}).to(device)"
)
else:
value = f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)"
else:
value = f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)"

value = self._generate_tensor_value(name, shape, dtype, model_metadata)
lines.append(f' "{name}": {value},')

lines.append("}")
return "\n".join(lines)

def _generate_tensor_value(
self,
name: str,
shape: list,
dtype: str,
model_metadata: ModelMetadata,
) -> str:
torch_dtype = self._get_torch_dtype(dtype)
shape_tuple = self._shape_tuple(shape)

if dtype == "int64":
if "input_ids" in name.lower() or "decoder_input_ids" in name.lower():
safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
return (
f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, "
f"dtype={torch_dtype}).to(device)"
)
return f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)"
return f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)"

def _generate_qwen_vlm_input_code(self, model_metadata: ModelMetadata) -> str:
lines = []
for name, shape in model_metadata.input_shapes.items():
dtype = model_metadata.input_dtypes.get(name, "int64")
if name == "image_grid_thw":
lines.append(
"image_grid_thw = torch.tensor([[1, 2, 2]], dtype=torch.long).to(device)"
)
elif name == "pixel_values":
lines.append(
f"pixel_values = torch.randn({self._shape_tuple(shape)}, dtype=torch.float32).to(device)"
)
elif name == "input_ids":
safe_vocab_size = self._calculate_safe_vocab_size(model_metadata)
lines.append(
f"input_ids = torch.randint(0, {safe_vocab_size}, {self._shape_tuple(shape)}, dtype=torch.int64).to(device)"
)
elif name == "attention_mask":
lines.append(
f"attention_mask = torch.ones({self._shape_tuple(shape)}, dtype=torch.int64).to(device)"
)
else:
lines.append(
f"{name} = {self._generate_tensor_value(name, shape, dtype, model_metadata)}"
)
lines.extend(self._input_dict_lines(model_metadata))
return "\n".join(lines)

def _generate_llava_input_code(self, model_metadata: ModelMetadata) -> str:
lines = self._generate_vlm_base_assignment_lines(model_metadata)
seq_len = self._input_seq_len(model_metadata)
image_shape = model_metadata.input_shapes.get("pixel_values", [1, 3, 224, 224])
image_size = image_shape[-1]
lines.extend(
[
'image_token_index = int(getattr(_config, "image_token_index", 32000))',
'patch_size = int(getattr(getattr(_config, "vision_config", _config), "patch_size", 14))',
f"image_size = {image_size}",
"num_image_tokens = (image_size // patch_size) ** 2",
'if getattr(_config, "vision_feature_select_strategy", None) == "full":',
" num_image_tokens += 1",
f"num_image_tokens = min(num_image_tokens, {seq_len})",
"input_ids[:, :num_image_tokens] = image_token_index",
]
)
lines.extend(self._input_dict_lines(model_metadata))
return "\n".join(lines)

def _generate_gemma3_input_code(self, model_metadata: ModelMetadata) -> str:
lines = self._generate_vlm_base_assignment_lines(model_metadata)
seq_len = self._input_seq_len(model_metadata)
lines.extend(
[
'image_token_index = int(getattr(_config, "image_token_index", 262144))',
'num_image_tokens = int(getattr(_config, "mm_tokens_per_image", 256))',
f"num_image_tokens = min(num_image_tokens, {seq_len})",
"input_ids[:, :num_image_tokens] = image_token_index",
]
)
lines.extend(self._input_dict_lines(model_metadata))
return "\n".join(lines)

def _generate_internvl_input_code(self, model_metadata: ModelMetadata) -> str:
lines = self._generate_vlm_base_assignment_lines(model_metadata)
seq_len = self._input_seq_len(model_metadata)
if "image_flags" in model_metadata.input_shapes:
lines.append(
"image_flags = torch.ones((1, 1), dtype=torch.long).to(device)"
)
lines.extend(
[
'num_image_token = int(getattr(_config, "num_image_token", 256))',
'image_token_id = int(getattr(_config, "img_context_token_id", getattr(_config, "image_token_index", 0)))',
f"num_image_token = min(num_image_token, {seq_len})",
"input_ids[:, :num_image_token] = image_token_id",
]
)
lines.extend(self._input_dict_lines(model_metadata))
return "\n".join(lines)

def _generate_vlm_base_assignment_lines(
self, model_metadata: ModelMetadata
) -> list:
lines = []
for name, shape in model_metadata.input_shapes.items():
dtype = model_metadata.input_dtypes.get(name, "int64")
if name == "image_flags":
continue
lines.append(
f"{name} = {self._generate_tensor_value(name, shape, dtype, model_metadata)}"
)
return lines

@staticmethod
def _shape_tuple(shape: list) -> str:
return f"({', '.join(map(str, shape))})"

@staticmethod
def _input_seq_len(model_metadata: ModelMetadata) -> int:
return int(model_metadata.input_shapes.get("input_ids", [1, 0])[1])

@staticmethod
def _input_dict_lines(model_metadata: ModelMetadata) -> list:
lines = ["inputs = {"]
for name in model_metadata.input_shapes:
lines.append(f' "{name}": {name},')
lines.append("}")
return lines

def _get_torch_dtype(self, dtype: str) -> str:
"""Convert dtype string to torch dtype"""
if dtype == "int64":
Expand Down
Loading
Loading