From 5a83875d77c887652cd37c61e6710e25faad2f1c Mon Sep 17 00:00:00 2001 From: Tai An Date: Thu, 14 May 2026 12:10:22 -0700 Subject: [PATCH] fix(utils/weight_conversion): handle str target_modules/target_parameters in MoE config conversion Closes #3229. --- src/peft/utils/transformers_weight_conversion.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/peft/utils/transformers_weight_conversion.py b/src/peft/utils/transformers_weight_conversion.py index 3350175481..f7d042df98 100644 --- a/src/peft/utils/transformers_weight_conversion.py +++ b/src/peft/utils/transformers_weight_conversion.py @@ -368,8 +368,15 @@ def _convert_peft_config_moe(peft_config, model_type: str) -> None: if not fused_targets: return - peft_config.target_parameters = set(peft_config.target_parameters or []) - peft_config.target_modules = set(peft_config.target_modules or []) + def _normalize_to_set(value): + if value is None: + return set() + if isinstance(value, str): + return {value} + return set(value) + + peft_config.target_parameters = _normalize_to_set(peft_config.target_parameters) + peft_config.target_modules = _normalize_to_set(peft_config.target_modules) if not hasattr(peft_config, "rank_pattern") or peft_config.rank_pattern is None: peft_config.rank_pattern = {} if not hasattr(peft_config, "alpha_pattern") or peft_config.alpha_pattern is None: