From 2fda19c6390e8f2ce3d0ebb3e4f8f3eb12d14e3f Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 20 May 2026 09:32:34 +0800 Subject: [PATCH 1/3] refactor(agent): shorten generated loader code Document the current agent dependency versions and only emit model_type patching logic when the downloaded config actually needs it. Co-Authored-By: Claude Opus 4.6 --- graph_net/agent/README.md | 4 +- .../code_generator/template_generator.py | 40 +++++++++++-------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/graph_net/agent/README.md b/graph_net/agent/README.md index 2fcf1aa7c..990f8083d 100644 --- a/graph_net/agent/README.md +++ b/graph_net/agent/README.md @@ -6,8 +6,8 @@ ### 基础依赖 ```bash -pip install torch transformers accelerate -pip install huggingface_hub>=0.20.0 +pip install "torch>=2.9.0" accelerate +pip install "transformers>=5.8.1" "huggingface_hub>=1.15.0" ``` ## 环境配置 diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index e2e051b31..2e2cb9d1d 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -1,5 +1,6 @@ """Template-based code generator implementation""" +import json from pathlib import Path from typing import Optional @@ -136,30 +137,37 @@ def _generate_model_loader( ) else: # text, moe, vision, multimodal, audio, None → AutoModel - # If model_type is not present in config.json (e.g. prajjwal1/bert-tiny), - # inject the inferred model_type so AutoConfig can resolve the class. model_type = model_metadata.model_type - if model_type: + if model_type and self._config_missing_model_type(model_dir): return ( f"import json as _json, os as _os, tempfile as _tmp\n" f"from transformers import AutoConfig, AutoModel\n" f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n' - f'if "model_type" not in _raw:\n' - f' _raw["model_type"] = "{model_type}"\n' - f" _td = _tmp.mkdtemp()\n" - f' _json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n' - f" _config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n" - f"else:\n" - f' _config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' - f"model = AutoModel.from_config(_config)" - ) - else: - return ( - f"from transformers import AutoConfig, AutoModel\n" - f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' + f'_raw["model_type"] = "{model_type}"\n' + f"_td = _tmp.mkdtemp()\n" + f'_json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n' + f"_config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n" f"model = AutoModel.from_config(_config)" ) + return ( + f"from transformers import AutoConfig, AutoModel\n" + f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' + f"model = AutoModel.from_config(_config)" + ) + + @staticmethod + def _config_missing_model_type(model_dir: Path) -> bool: + config_path = model_dir / "config.json" + if not config_path.exists(): + return False + try: + return "model_type" not in json.loads( + config_path.read_text(encoding="utf-8") + ) + except (OSError, json.JSONDecodeError): + return False + def _generate_input_code(self, model_metadata: ModelMetadata) -> str: """Generate input tensor construction code based on model metadata""" lines = ["inputs = {"] From d24dfa82430452e12fba578f815573a8d0328890 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 20 May 2026 10:53:22 +0800 Subject: [PATCH 2/3] fix(agent): improve VLM input construction Add exact-family VLM metadata and generated input handling so Qwen-VL, LLaVA, Gemma3, and InternVL scripts include required image-token metadata while keeping generic multimodal behavior unchanged. Co-Authored-By: Claude Opus 4.6 --- .../agent/code_generator/llm_code_fixer.py | 22 +++ .../code_generator/template_generator.py | 164 +++++++++++++++-- .../config_metadata_analyzer.py | 174 +++++++++++++++++- graph_net/agent/vlm_model_types.py | 27 +++ 4 files changed, 368 insertions(+), 19 deletions(-) create mode 100644 graph_net/agent/vlm_model_types.py diff --git a/graph_net/agent/code_generator/llm_code_fixer.py b/graph_net/agent/code_generator/llm_code_fixer.py index 2e0fcef2c..c067848d2 100644 --- a/graph_net/agent/code_generator/llm_code_fixer.py +++ b/graph_net/agent/code_generator/llm_code_fixer.py @@ -57,6 +57,9 @@ **多模态类**(clip/blip/flava/align 等): 同时传文本分支(input_ids + attention_mask)和视觉分支(pixel_values) + - Qwen-VL(qwen2_vl/qwen2_5_vl/qwen3_vl):需要 image_grid_thw,且 pixel_values 是展平 patch,形如 (4, channels * temporal_patch_size * patch_size * patch_size),不是 BCHW。 + - LLaVA/Gemma3:若报图像特征和图像 token 数不匹配,必须把 input_ids 前缀覆盖为 image_token_index,覆盖数量分别按视觉 patch 数或 mm_tokens_per_image。 + - InternVL:通常需要 image_flags,并把 input_ids 前缀覆盖为 img_context_token_id(或 image_token_index),数量为 num_image_token。 **音频类**(wav2vec2/hubert/whisper/clap/unispeech 等): - wav2vec2/hubert/unispeech:input_values = torch.randn(1, 16000).to(device) @@ -94,6 +97,10 @@ | 报错关键词 | 修复方法 | |---|---| | "You have to specify pixel_values" | 补充 pixel_values 输入 | +| "grid_thw" / "image_grid_thw" | Qwen-VL 需要 image_grid_thw=torch.tensor([[1,2,2]]),pixel_values 用展平 patch shape | +| "Image features and image tokens do not match" | LLaVA/Gemma3 需要把 input_ids 前缀覆盖为 image_token_index,数量匹配图像 token 数 | +| "tokens: 0, features:" | input_ids 中缺少图像 token,按对应 VLM 家族写入 image_token_index/img_context_token_id 前缀 | +| "image_flags" | InternVL 需要 image_flags=torch.ones((1,1), dtype=torch.long) | | "index out of range in self" / embedding 越界 | input_ids 上界 > vocab_size,改为 min(vocab_size-1, 30000) | | "NoneType has no attribute" | 对应输入字段为 None,补充正确 tensor | | "running_mean should contain X elements" | BatchNorm channel 维度不对,检查 input_features shape 的 channel 轴 | @@ -372,7 +379,14 @@ def _extract_key_fields(model_dir: Path) -> str: "audio_config", "vision_config", "text_config", + "image_token_index", + "mm_tokens_per_image", + "vision_feature_select_strategy", + "img_context_token_id", + "num_image_token", "patch_size", + "temporal_patch_size", + "spatial_merge_size", "num_mel_bins", "chunk_length", # MoE routing (field names vary across models) @@ -407,6 +421,14 @@ def _extract_key_fields(model_dir: Path) -> str: "vocab_size", "image_size", "num_channels", + "patch_size", + "temporal_patch_size", + "spatial_merge_size", + "image_token_index", + "mm_tokens_per_image", + "vision_feature_select_strategy", + "img_context_token_id", + "num_image_token", "num_mel_bins", "hidden_size", "num_local_experts", diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index 2e2cb9d1d..59227ef86 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -7,6 +7,13 @@ from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata from graph_net.agent.code_generator.base import BaseCodeGenerator from graph_net.agent.utils.exceptions import CodeGenerationError +from graph_net.agent.vlm_model_types import ( + VLM_FAMILY_GEMMA3, + VLM_FAMILY_INTERNVL, + VLM_FAMILY_LLAVA, + VLM_FAMILY_QWEN, + get_vlm_family, +) # Constants for safe vocab size calculation DEFAULT_VOCAB_SIZE = 30522 @@ -170,30 +177,157 @@ def _config_missing_model_type(model_dir: Path) -> bool: def _generate_input_code(self, model_metadata: ModelMetadata) -> str: """Generate input tensor construction code based on model metadata""" + family = get_vlm_family(model_metadata.model_type) + generators = { + VLM_FAMILY_QWEN: self._generate_qwen_vlm_input_code, + VLM_FAMILY_LLAVA: self._generate_llava_input_code, + VLM_FAMILY_GEMMA3: self._generate_gemma3_input_code, + VLM_FAMILY_INTERNVL: self._generate_internvl_input_code, + } + generator = generators.get(family, self._generate_generic_input_code) + return generator(model_metadata) + + def _generate_generic_input_code(self, model_metadata: ModelMetadata) -> str: + """Generate generic input tensor construction code.""" lines = ["inputs = {"] for name, shape in model_metadata.input_shapes.items(): dtype = model_metadata.input_dtypes.get(name, "int64") - torch_dtype = self._get_torch_dtype(dtype) - shape_tuple = f"({', '.join(map(str, shape))})" - - if dtype == "int64": - if "input_ids" in name.lower() or "decoder_input_ids" in name.lower(): - safe_vocab_size = self._calculate_safe_vocab_size(model_metadata) - value = ( - f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, " - f"dtype={torch_dtype}).to(device)" - ) - else: - value = f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)" - else: - value = f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)" - + value = self._generate_tensor_value(name, shape, dtype, model_metadata) lines.append(f' "{name}": {value},') lines.append("}") return "\n".join(lines) + def _generate_tensor_value( + self, + name: str, + shape: list, + dtype: str, + model_metadata: ModelMetadata, + ) -> str: + torch_dtype = self._get_torch_dtype(dtype) + shape_tuple = self._shape_tuple(shape) + + if dtype == "int64": + if "input_ids" in name.lower() or "decoder_input_ids" in name.lower(): + safe_vocab_size = self._calculate_safe_vocab_size(model_metadata) + return ( + f"torch.randint(0, {safe_vocab_size}, {shape_tuple}, " + f"dtype={torch_dtype}).to(device)" + ) + return f"torch.ones({shape_tuple}, dtype={torch_dtype}).to(device)" + return f"torch.randn({shape_tuple}, dtype={torch_dtype}).to(device)" + + def _generate_qwen_vlm_input_code(self, model_metadata: ModelMetadata) -> str: + lines = [] + for name, shape in model_metadata.input_shapes.items(): + dtype = model_metadata.input_dtypes.get(name, "int64") + if name == "image_grid_thw": + lines.append( + "image_grid_thw = torch.tensor([[1, 2, 2]], dtype=torch.long).to(device)" + ) + elif name == "pixel_values": + lines.append( + f"pixel_values = torch.randn({self._shape_tuple(shape)}, dtype=torch.float32).to(device)" + ) + elif name == "input_ids": + safe_vocab_size = self._calculate_safe_vocab_size(model_metadata) + lines.append( + f"input_ids = torch.randint(0, {safe_vocab_size}, {self._shape_tuple(shape)}, dtype=torch.int64).to(device)" + ) + elif name == "attention_mask": + lines.append( + f"attention_mask = torch.ones({self._shape_tuple(shape)}, dtype=torch.int64).to(device)" + ) + else: + lines.append( + f"{name} = {self._generate_tensor_value(name, shape, dtype, model_metadata)}" + ) + lines.extend(self._input_dict_lines(model_metadata)) + return "\n".join(lines) + + def _generate_llava_input_code(self, model_metadata: ModelMetadata) -> str: + lines = self._generate_vlm_base_assignment_lines(model_metadata) + seq_len = self._input_seq_len(model_metadata) + image_shape = model_metadata.input_shapes.get("pixel_values", [1, 3, 224, 224]) + image_size = image_shape[-1] + lines.extend( + [ + 'image_token_index = int(getattr(_config, "image_token_index", 32000))', + 'patch_size = int(getattr(getattr(_config, "vision_config", _config), "patch_size", 14))', + f"image_size = {image_size}", + "num_image_tokens = (image_size // patch_size) ** 2", + 'if getattr(_config, "vision_feature_select_strategy", None) == "full":', + " num_image_tokens += 1", + f"num_image_tokens = min(num_image_tokens, {seq_len})", + "input_ids[:, :num_image_tokens] = image_token_index", + ] + ) + lines.extend(self._input_dict_lines(model_metadata)) + return "\n".join(lines) + + def _generate_gemma3_input_code(self, model_metadata: ModelMetadata) -> str: + lines = self._generate_vlm_base_assignment_lines(model_metadata) + seq_len = self._input_seq_len(model_metadata) + lines.extend( + [ + 'image_token_index = int(getattr(_config, "image_token_index", 262144))', + 'num_image_tokens = int(getattr(_config, "mm_tokens_per_image", 256))', + f"num_image_tokens = min(num_image_tokens, {seq_len})", + "input_ids[:, :num_image_tokens] = image_token_index", + ] + ) + lines.extend(self._input_dict_lines(model_metadata)) + return "\n".join(lines) + + def _generate_internvl_input_code(self, model_metadata: ModelMetadata) -> str: + lines = self._generate_vlm_base_assignment_lines(model_metadata) + seq_len = self._input_seq_len(model_metadata) + if "image_flags" in model_metadata.input_shapes: + lines.append( + "image_flags = torch.ones((1, 1), dtype=torch.long).to(device)" + ) + lines.extend( + [ + 'num_image_token = int(getattr(_config, "num_image_token", 256))', + 'image_token_id = int(getattr(_config, "img_context_token_id", getattr(_config, "image_token_index", 0)))', + f"num_image_token = min(num_image_token, {seq_len})", + "input_ids[:, :num_image_token] = image_token_id", + ] + ) + lines.extend(self._input_dict_lines(model_metadata)) + return "\n".join(lines) + + def _generate_vlm_base_assignment_lines( + self, model_metadata: ModelMetadata + ) -> list: + lines = [] + for name, shape in model_metadata.input_shapes.items(): + dtype = model_metadata.input_dtypes.get(name, "int64") + if name == "image_flags": + continue + lines.append( + f"{name} = {self._generate_tensor_value(name, shape, dtype, model_metadata)}" + ) + return lines + + @staticmethod + def _shape_tuple(shape: list) -> str: + return f"({', '.join(map(str, shape))})" + + @staticmethod + def _input_seq_len(model_metadata: ModelMetadata) -> int: + return int(model_metadata.input_shapes.get("input_ids", [1, 0])[1]) + + @staticmethod + def _input_dict_lines(model_metadata: ModelMetadata) -> list: + lines = ["inputs = {"] + for name in model_metadata.input_shapes: + lines.append(f' "{name}": {name},') + lines.append("}") + return lines + def _get_torch_dtype(self, dtype: str) -> str: """Convert dtype string to torch dtype""" if dtype == "int64": diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index 3e6213306..714b1fc8e 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -10,11 +10,22 @@ GraphExtractionErrorCategory, MetadataAnalysisError, ) +from graph_net.agent.vlm_model_types import ( + VLM_FAMILY_GEMMA3, + VLM_FAMILY_INTERNVL, + VLM_FAMILY_LLAVA, + VLM_FAMILY_QWEN, + get_vlm_family, + is_image_token_vlm, +) # Cap sequence length to avoid OOM: attention is O(n²), graph extraction # only needs a short sequence to trace the computation graph. _MAX_SEQ_LEN = 128 +# Larger cap used only for known VLM families that require image-token prefixes. +_MAX_VLM_SEQ_LEN = 512 +_VLM_TEXT_TAIL = 16 # Cap image size to avoid OOM on high-resolution configs. _MAX_IMAGE_SIZE = 512 _EMBEDDING_WEIGHT_KEYS = [ @@ -32,6 +43,30 @@ def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any: return getattr(cfg, key, default) +def _get_vision_config(cfg_obj: Any, cfg_dict: Dict) -> Any: + return _cfg_get(cfg_obj, "vision_config") or cfg_dict.get("vision_config", {}) + + +def _get_vision_field(cfg_obj: Any, cfg_dict: Dict, key: str, default: Any) -> Any: + for vis_cfg in (_cfg_get(cfg_obj, "vision_config"), cfg_dict.get("vision_config")): + value = _cfg_get(vis_cfg, key, None) + if value is not None: + return value + return _cfg_get(cfg_obj, key) or cfg_dict.get(key, default) + + +def _normalize_square_size(raw_size: Any, default: int = 224) -> int: + if raw_size is None: + raw_size = default + if isinstance(raw_size, (list, tuple)): + raw_size = raw_size[0] + return int(raw_size) + + +def _cap_seq_len(raw_len: Any, cap: int = _MAX_SEQ_LEN) -> int: + return min(int(raw_len), cap) + + class ConfigMetadataAnalyzer(BaseMetadataAnalyzer): """Analyzer that extracts metadata from config.json, using transformers AutoConfig when available to leverage rich config object properties for architecture detection. @@ -191,6 +226,8 @@ def _classify_architecture(cfg_obj: Any, cfg_dict: Dict) -> Optional[str]: return "multimodal" except ImportError: pass + if is_image_token_vlm(model_type): + return "multimodal" # Fallback: check sub_configs / dict keys for vision+text pair if cfg_obj is not None: sub_configs = getattr(cfg_obj, "sub_configs", {}) @@ -348,8 +385,24 @@ def _inputs_multimodal( raw_len = _cfg_get(cfg_obj, "max_position_embeddings") or cfg_dict.get( "max_position_embeddings", 512 ) - seq_len = min(int(raw_len), _MAX_SEQ_LEN) + seq_len = _cap_seq_len(raw_len) + + model_type = _cfg_get(cfg_obj, "model_type") or cfg_dict.get("model_type") + family = get_vlm_family(model_type) + vlm_handlers = { + VLM_FAMILY_QWEN: lambda: self._inputs_qwen_vlm(cfg_obj, cfg_dict, seq_len), + VLM_FAMILY_LLAVA: lambda: self._inputs_llava(cfg_obj, cfg_dict), + VLM_FAMILY_GEMMA3: lambda: self._inputs_gemma3(cfg_obj, cfg_dict), + VLM_FAMILY_INTERNVL: lambda: self._inputs_internvl(cfg_obj, cfg_dict), + } + if family in vlm_handlers: + return vlm_handlers[family]() + + return self._inputs_generic_multimodal(cfg_obj, cfg_dict, seq_len) + def _inputs_generic_multimodal( + self, cfg_obj: Any, cfg_dict: Dict, seq_len: int + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: # Vision branch — prefer sub vision_config vis_cfg = _cfg_get(cfg_obj, "vision_config") or cfg_dict.get( "vision_config", {} @@ -367,9 +420,7 @@ def _inputs_multimodal( num_channels = _cfg_get(cfg_obj, "num_channels") or cfg_dict.get( "num_channels", 3 ) - if isinstance(raw_size, (list, tuple)): - raw_size = raw_size[0] - image_size = min(int(raw_size), _MAX_IMAGE_SIZE) + image_size = min(_normalize_square_size(raw_size), _MAX_IMAGE_SIZE) shapes = { "input_ids": [1, seq_len], @@ -383,6 +434,121 @@ def _inputs_multimodal( } return shapes, dtypes + def _inputs_qwen_vlm( + self, cfg_obj: Any, cfg_dict: Dict, seq_len: int + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + patch_size = int(_get_vision_field(cfg_obj, cfg_dict, "patch_size", 14)) + temporal_patch_size = int( + _get_vision_field(cfg_obj, cfg_dict, "temporal_patch_size", 2) + ) + num_channels = int(_get_vision_field(cfg_obj, cfg_dict, "num_channels", 3)) + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + "pixel_values": [ + 4, + num_channels * temporal_patch_size * patch_size * patch_size, + ], + "image_grid_thw": [1, 3], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "pixel_values": "float32", + "image_grid_thw": "int64", + } + return shapes, dtypes + + def _inputs_llava( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + patch_size = int(_get_vision_field(cfg_obj, cfg_dict, "patch_size", 14)) + raw_size = min( + _normalize_square_size( + _get_vision_field(cfg_obj, cfg_dict, "image_size", 224) + ), + 112, + ) + image_size = max(patch_size, (raw_size // patch_size) * patch_size) + num_channels = int(_get_vision_field(cfg_obj, cfg_dict, "num_channels", 3)) + num_image_tokens = (image_size // patch_size) ** 2 + if ( + _cfg_get(cfg_obj, "vision_feature_select_strategy") + or cfg_dict.get("vision_feature_select_strategy") + ) == "full": + num_image_tokens += 1 + seq_len = min( + max(num_image_tokens + _VLM_TEXT_TAIL, _MAX_SEQ_LEN), _MAX_VLM_SEQ_LEN + ) + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + "pixel_values": [1, num_channels, image_size, image_size], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "pixel_values": "float32", + } + return shapes, dtypes + + def _inputs_gemma3( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + mm_tokens = int( + _cfg_get(cfg_obj, "mm_tokens_per_image") + or cfg_dict.get("mm_tokens_per_image", 256) + ) + seq_len = min(max(mm_tokens + _VLM_TEXT_TAIL, _MAX_SEQ_LEN), _MAX_VLM_SEQ_LEN) + image_size = min( + _normalize_square_size( + _get_vision_field(cfg_obj, cfg_dict, "image_size", 224) + ), + _MAX_IMAGE_SIZE, + ) + num_channels = int(_get_vision_field(cfg_obj, cfg_dict, "num_channels", 3)) + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + "pixel_values": [1, num_channels, image_size, image_size], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "pixel_values": "float32", + } + return shapes, dtypes + + def _inputs_internvl( + self, cfg_obj: Any, cfg_dict: Dict + ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: + num_image_token = int( + _cfg_get(cfg_obj, "num_image_token") or cfg_dict.get("num_image_token", 256) + ) + seq_len = min( + max(num_image_token + _VLM_TEXT_TAIL, _MAX_SEQ_LEN), _MAX_VLM_SEQ_LEN + ) + image_size = min( + _normalize_square_size( + _get_vision_field(cfg_obj, cfg_dict, "image_size", 224) + ), + _MAX_IMAGE_SIZE, + ) + num_channels = int(_get_vision_field(cfg_obj, cfg_dict, "num_channels", 3)) + shapes = { + "input_ids": [1, seq_len], + "attention_mask": [1, seq_len], + "pixel_values": [1, num_channels, image_size, image_size], + "image_flags": [1, 1], + } + dtypes = { + "input_ids": "int64", + "attention_mask": "int64", + "pixel_values": "float32", + "image_flags": "int64", + } + return shapes, dtypes + def _inputs_audio( self, cfg_obj: Any, cfg_dict: Dict ) -> Tuple[Dict[str, List[int]], Dict[str, str]]: diff --git a/graph_net/agent/vlm_model_types.py b/graph_net/agent/vlm_model_types.py new file mode 100644 index 000000000..e636e1037 --- /dev/null +++ b/graph_net/agent/vlm_model_types.py @@ -0,0 +1,27 @@ +"""Shared exact VLM model_type family mapping for agent input construction.""" + +from typing import Optional + +VLM_FAMILY_QWEN = "qwen_vlm" +VLM_FAMILY_LLAVA = "llava" +VLM_FAMILY_GEMMA3 = "gemma3" +VLM_FAMILY_INTERNVL = "internvl" + +VLM_MODEL_TYPE_TO_FAMILY = { + "qwen2_vl": VLM_FAMILY_QWEN, + "qwen2_5_vl": VLM_FAMILY_QWEN, + "qwen3_vl": VLM_FAMILY_QWEN, + "llava": VLM_FAMILY_LLAVA, + "gemma3": VLM_FAMILY_GEMMA3, + "internvl": VLM_FAMILY_INTERNVL, + "internvl_chat": VLM_FAMILY_INTERNVL, +} +IMAGE_TOKEN_VLM_MODEL_TYPES = frozenset(VLM_MODEL_TYPE_TO_FAMILY) + + +def get_vlm_family(model_type: Optional[str]) -> Optional[str]: + return VLM_MODEL_TYPE_TO_FAMILY.get((model_type or "").lower()) + + +def is_image_token_vlm(model_type: Optional[str]) -> bool: + return get_vlm_family(model_type) is not None From 859a01eade31d10ab7dea44f0903ebcdaa608277 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 20 May 2026 17:02:35 +0800 Subject: [PATCH 3/3] refactor(agent): centralize model type helpers Consolidate VLM family mapping, transformer auto mapping checks, and AutoModel class selection into a shared model_type_utils module so metadata analysis and code generation use the same model-type decisions. Co-Authored-By: Claude Opus 4.6 --- .../code_generator/template_generator.py | 49 ++++----- .../config_metadata_analyzer.py | 63 +++-------- graph_net/agent/model_type_utils.py | 100 ++++++++++++++++++ graph_net/agent/vlm_model_types.py | 27 ----- 4 files changed, 139 insertions(+), 100 deletions(-) create mode 100644 graph_net/agent/model_type_utils.py delete mode 100644 graph_net/agent/vlm_model_types.py diff --git a/graph_net/agent/code_generator/template_generator.py b/graph_net/agent/code_generator/template_generator.py index 59227ef86..6462519ec 100644 --- a/graph_net/agent/code_generator/template_generator.py +++ b/graph_net/agent/code_generator/template_generator.py @@ -6,14 +6,15 @@ from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata from graph_net.agent.code_generator.base import BaseCodeGenerator -from graph_net.agent.utils.exceptions import CodeGenerationError -from graph_net.agent.vlm_model_types import ( +from graph_net.agent.model_type_utils import ( VLM_FAMILY_GEMMA3, VLM_FAMILY_INTERNVL, VLM_FAMILY_LLAVA, VLM_FAMILY_QWEN, get_vlm_family, + select_auto_model_class, ) +from graph_net.agent.utils.exceptions import CodeGenerationError # Constants for safe vocab size calculation DEFAULT_VOCAB_SIZE = 30522 @@ -130,39 +131,35 @@ def _generate_model_loader( model_path = str(model_dir).replace("\\", "/") arch = model_metadata.architecture_type - if arch == "seq2seq": - return ( - f"from transformers import AutoConfig, AutoModelForSeq2SeqLM\n" - f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' - f"model = AutoModelForSeq2SeqLM.from_config(_config)" - ) - elif arch == "diffusion": + if arch == "diffusion": return ( f"from diffusers import UNet2DConditionModel\n" f'_config = UNet2DConditionModel.load_config("{model_path}")\n' f"model = UNet2DConditionModel.from_config(_config)" ) - else: - # text, moe, vision, multimodal, audio, None → AutoModel - model_type = model_metadata.model_type - if model_type and self._config_missing_model_type(model_dir): - return ( - f"import json as _json, os as _os, tempfile as _tmp\n" - f"from transformers import AutoConfig, AutoModel\n" - f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n' - f'_raw["model_type"] = "{model_type}"\n' - f"_td = _tmp.mkdtemp()\n" - f'_json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n' - f"_config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n" - f"model = AutoModel.from_config(_config)" - ) + model_class = select_auto_model_class( + model_metadata.model_type, model_metadata.architecture_type + ) + model_type = model_metadata.model_type + if model_type and self._config_missing_model_type(model_dir): return ( - f"from transformers import AutoConfig, AutoModel\n" - f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' - f"model = AutoModel.from_config(_config)" + f"import json as _json, os as _os, tempfile as _tmp\n" + f"from transformers import AutoConfig, {model_class}\n" + f'_raw = _json.load(open(_os.path.join("{model_path}", "config.json")))\n' + f'_raw["model_type"] = "{model_type}"\n' + f"_td = _tmp.mkdtemp()\n" + f'_json.dump(_raw, open(_os.path.join(_td, "config.json"), "w"))\n' + f"_config = AutoConfig.from_pretrained(_td, trust_remote_code=True)\n" + f"model = {model_class}.from_config(_config)" ) + return ( + f"from transformers import AutoConfig, {model_class}\n" + f'_config = AutoConfig.from_pretrained("{model_path}", trust_remote_code=True)\n' + f"model = {model_class}.from_config(_config)" + ) + @staticmethod def _config_missing_model_type(model_dir: Path) -> bool: config_path = model_dir / "config.json" diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index 714b1fc8e..cc06c8eba 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -6,17 +6,18 @@ from graph_net.agent.metadata_analyzer.base import BaseMetadataAnalyzer from graph_net.agent.metadata_analyzer.model_metadata import ModelMetadata -from graph_net.agent.utils.exceptions import ( - GraphExtractionErrorCategory, - MetadataAnalysisError, -) -from graph_net.agent.vlm_model_types import ( +from graph_net.agent.model_type_utils import ( VLM_FAMILY_GEMMA3, VLM_FAMILY_INTERNVL, VLM_FAMILY_LLAVA, VLM_FAMILY_QWEN, get_vlm_family, - is_image_token_vlm, + is_audio_model_type, + is_multimodal_model_type, +) +from graph_net.agent.utils.exceptions import ( + GraphExtractionErrorCategory, + MetadataAnalysisError, ) @@ -184,49 +185,17 @@ def _classify_architecture(cfg_obj: Any, cfg_dict: Dict) -> Optional[str]: # 2. Audio models # Use the union of transformers' audio task mapping tables — no hardcoded list. - try: - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, - MODEL_FOR_CTC_MAPPING_NAMES, - MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, - MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, - MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, - ) - - all_audio: set = ( - set(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) - | set(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES) - | set(MODEL_FOR_CTC_MAPPING_NAMES) - | set(MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES) - | set(MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES) - ) - if model_type in all_audio: - return "audio" - except ImportError: - # Attribute-based fallback - if _cfg_get(cfg_obj, "num_mel_bins") or cfg_dict.get("num_mel_bins"): - return "audio" - if _cfg_get(cfg_obj, "feat_extract_norm") or cfg_dict.get( - "feat_extract_norm" - ): - return "audio" + if is_audio_model_type(model_type): + return "audio" + # Attribute-based fallback + if _cfg_get(cfg_obj, "num_mel_bins") or cfg_dict.get("num_mel_bins"): + return "audio" + if _cfg_get(cfg_obj, "feat_extract_norm") or cfg_dict.get("feat_extract_norm"): + return "audio" # 3. Multimodal VLMs - # Use transformers' multimodal task mapping tables — no hardcoded list. - try: - from transformers.models.auto.modeling_auto import ( - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, - MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES, - ) - - all_multimodal: set = set(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) | set( - MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES - ) - if model_type in all_multimodal: - return "multimodal" - except ImportError: - pass - if is_image_token_vlm(model_type): + # Use transformers' multimodal task mapping tables plus exact VLM family map. + if is_multimodal_model_type(model_type): return "multimodal" # Fallback: check sub_configs / dict keys for vision+text pair if cfg_obj is not None: diff --git a/graph_net/agent/model_type_utils.py b/graph_net/agent/model_type_utils.py new file mode 100644 index 000000000..13de59f03 --- /dev/null +++ b/graph_net/agent/model_type_utils.py @@ -0,0 +1,100 @@ +"""Shared model_type helpers for agent metadata analysis and code generation.""" + +from typing import Optional, Tuple + +VLM_FAMILY_QWEN = "qwen_vlm" +VLM_FAMILY_LLAVA = "llava" +VLM_FAMILY_GEMMA3 = "gemma3" +VLM_FAMILY_INTERNVL = "internvl" + +VLM_MODEL_TYPE_TO_FAMILY = { + "qwen2_vl": VLM_FAMILY_QWEN, + "qwen2_5_vl": VLM_FAMILY_QWEN, + "qwen3_vl": VLM_FAMILY_QWEN, + "llava": VLM_FAMILY_LLAVA, + "gemma3": VLM_FAMILY_GEMMA3, + "internvl": VLM_FAMILY_INTERNVL, + "internvl_chat": VLM_FAMILY_INTERNVL, +} +IMAGE_TOKEN_VLM_MODEL_TYPES = frozenset(VLM_MODEL_TYPE_TO_FAMILY) + +AUDIO_AUTO_MAPPING_NAMES = ( + "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", + "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", + "MODEL_FOR_CTC_MAPPING_NAMES", + "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", + "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", +) +IMAGE_TEXT_TO_TEXT_AUTO_MAPPING_NAMES = ("MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES",) +MULTIMODAL_LM_AUTO_MAPPING_NAMES = ("MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES",) +OBJECT_DETECTION_AUTO_MAPPING_NAMES = ("MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES",) +CAUSAL_LM_AUTO_MAPPING_NAMES = ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES",) +MASKED_LM_AUTO_MAPPING_NAMES = ("MODEL_FOR_MASKED_LM_MAPPING_NAMES",) + + +def normalize_model_type(model_type: Optional[str]) -> str: + return (model_type or "").lower() + + +def get_vlm_family(model_type: Optional[str]) -> Optional[str]: + return VLM_MODEL_TYPE_TO_FAMILY.get(normalize_model_type(model_type)) + + +def is_image_token_vlm(model_type: Optional[str]) -> bool: + return get_vlm_family(model_type) is not None + + +def model_type_in_auto_mapping( + model_type: Optional[str], mapping_names: Tuple[str, ...] +) -> bool: + normalized = normalize_model_type(model_type) + if not normalized: + return False + try: + from transformers.models.auto import modeling_auto + except ImportError: + return False + return any( + normalized in getattr(modeling_auto, mapping_name, {}) + for mapping_name in mapping_names + ) + + +def is_audio_model_type(model_type: Optional[str]) -> bool: + return model_type_in_auto_mapping(model_type, AUDIO_AUTO_MAPPING_NAMES) + + +def is_multimodal_model_type(model_type: Optional[str]) -> bool: + return is_image_token_vlm(model_type) or model_type_in_auto_mapping( + model_type, + IMAGE_TEXT_TO_TEXT_AUTO_MAPPING_NAMES + MULTIMODAL_LM_AUTO_MAPPING_NAMES, + ) + + +def is_object_detection_model_type(model_type: Optional[str]) -> bool: + return model_type_in_auto_mapping(model_type, OBJECT_DETECTION_AUTO_MAPPING_NAMES) + + +def is_causal_lm_model_type(model_type: Optional[str]) -> bool: + return model_type_in_auto_mapping( + model_type, CAUSAL_LM_AUTO_MAPPING_NAMES + ) and not model_type_in_auto_mapping(model_type, MASKED_LM_AUTO_MAPPING_NAMES) + + +def select_auto_model_class( + model_type: Optional[str], architecture_type: Optional[str] +) -> str: + normalized_arch = architecture_type or "" + if normalized_arch == "seq2seq": + return "AutoModelForSeq2SeqLM" + if get_vlm_family(model_type) or model_type_in_auto_mapping( + model_type, IMAGE_TEXT_TO_TEXT_AUTO_MAPPING_NAMES + ): + return "AutoModelForImageTextToText" + if model_type_in_auto_mapping(model_type, MULTIMODAL_LM_AUTO_MAPPING_NAMES): + return "AutoModelForMultimodalLM" + if is_object_detection_model_type(model_type): + return "AutoModelForObjectDetection" + if normalized_arch in {"text", "moe"} and is_causal_lm_model_type(model_type): + return "AutoModelForCausalLM" + return "AutoModel" diff --git a/graph_net/agent/vlm_model_types.py b/graph_net/agent/vlm_model_types.py deleted file mode 100644 index e636e1037..000000000 --- a/graph_net/agent/vlm_model_types.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Shared exact VLM model_type family mapping for agent input construction.""" - -from typing import Optional - -VLM_FAMILY_QWEN = "qwen_vlm" -VLM_FAMILY_LLAVA = "llava" -VLM_FAMILY_GEMMA3 = "gemma3" -VLM_FAMILY_INTERNVL = "internvl" - -VLM_MODEL_TYPE_TO_FAMILY = { - "qwen2_vl": VLM_FAMILY_QWEN, - "qwen2_5_vl": VLM_FAMILY_QWEN, - "qwen3_vl": VLM_FAMILY_QWEN, - "llava": VLM_FAMILY_LLAVA, - "gemma3": VLM_FAMILY_GEMMA3, - "internvl": VLM_FAMILY_INTERNVL, - "internvl_chat": VLM_FAMILY_INTERNVL, -} -IMAGE_TOKEN_VLM_MODEL_TYPES = frozenset(VLM_MODEL_TYPE_TO_FAMILY) - - -def get_vlm_family(model_type: Optional[str]) -> Optional[str]: - return VLM_MODEL_TYPE_TO_FAMILY.get((model_type or "").lower()) - - -def is_image_token_vlm(model_type: Optional[str]) -> bool: - return get_vlm_family(model_type) is not None