Bug: prepare_model_for_kbit_training() never called, layer norms stay in wrong dtype #57
Found another issue in `prepare_qlora_model`. The docstring says step 2 is "Prepare for k-bit training (cast layer norms to fp32)", but the actual `prepare_model_for_kbit_training()` call is never made:

    def prepare_qlora_model(model_cfg: ModelConfig, lora_cfg: LoraConfig):
        """
        Full QLoRA setup:
        1. Load model with 4-bit NF4 quantization
        2. Prepare for k-bit training (cast layer norms to fp32)  # <-- documented but never done
        3. Inject LoRA adapters
        """
        model_cfg.use_4bit = True
        model, tokenizer = load_model_and_tokenizer(model_cfg)
        model = inject_lora(model, lora_cfg)  # goes straight to LoRA injection
        return model, tokenizer

Without that call, the layer norms stay in the wrong dtype.

The fix:

    from peft import prepare_model_for_kbit_training

    def prepare_qlora_model(model_cfg: ModelConfig, lora_cfg: LoraConfig):
        model_cfg.use_4bit = True
        model, tokenizer = load_model_and_tokenizer(model_cfg)
        # Step 2 from the docstring: prepare for k-bit training before adding adapters
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=True,
        )
        model = inject_lora(model, lora_cfg)
        return model, tokenizer

Also worth noting …

I saw noticeably more NaN losses without this fix on Mistral 7B after ~500 steps. Happy to open a PR.
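One way to confirm the fix took effect is to list any normalization parameters that are not in the expected dtype once setup has run. This is a minimal sketch, not part of PEFT: the helper name `norm_params_not_in_dtype` is hypothetical, and it assumes norm layers can be recognized by "norm" appearing in the parameter name, which is a heuristic that may not hold for every architecture.

```python
def norm_params_not_in_dtype(model, expected_dtype):
    """Return (name, dtype) pairs for norm parameters whose dtype differs.

    Hypothetical helper. Works on any object exposing named_parameters()
    that yields (name, param) pairs where param has a .dtype attribute.
    Assumption: norm layers have "norm" somewhere in the parameter name.
    """
    return [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if "norm" in name.lower() and param.dtype != expected_dtype
    ]
```

With the fix applied, calling `norm_params_not_in_dtype(model, torch.float32)` on the model returned by `prepare_qlora_model` should come back empty. A non-empty list points at the parameters that were left in a lower-precision dtype.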
Replies: 1 comment
Good catch, confirmed on my end too. The call was in the docstring but never actually made it into the code, a classic case of writing the plan and forgetting to execute it. Fix is straightforward: add the import and the call between loading and LoRA injection.

    from peft import prepare_model_for_kbit_training

    def prepare_qlora_model(model_cfg: ModelConfig, lora_cfg: LoraConfig):
        model_cfg.use_4bit = True
        model, tokenizer = load_model_and_tokenizer(model_cfg)
        # Must run after loading and before LoRA injection
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=True,
        )
        model = inject_lora(model, lora_cfg)
        return model, tokenizer

And yeah, the gradient checkpointing thing you flagged is real; setting it in … PR welcome, small change.
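To keep the docstring and the code from drifting apart again, a small regression test could pin the call order. The following is a sketch under assumptions rather than the project's actual API: `run_qlora_setup` is a hypothetical stand-in for `prepare_qlora_model` that takes the three steps as callables, so the check runs with mocks and needs no GPU or model download.

```python
from unittest.mock import MagicMock


def run_qlora_setup(model_cfg, lora_cfg, *, load, prepare, inject):
    """Hypothetical stand-in for prepare_qlora_model with injectable steps."""
    model, tokenizer = load(model_cfg)
    model = prepare(model, use_gradient_checkpointing=True)
    model = inject(model, lora_cfg)
    return model, tokenizer


def called_step_names(model_cfg="cfg", lora_cfg="lora"):
    """Run the setup against mocks and return the step names in call order."""
    steps = MagicMock()
    # load() must return a (model, tokenizer) pair so unpacking works
    steps.load.return_value = ("model", "tokenizer")
    run_qlora_setup(
        model_cfg,
        lora_cfg,
        load=steps.load,
        prepare=steps.prepare,
        inject=steps.inject,
    )
    # Each recorded call unpacks as (name, args, kwargs)
    return [name for name, _args, _kwargs in steps.mock_calls]


# Fails if the prepare step is dropped or moved after LoRA injection
assert called_step_names() == ["load", "prepare", "inject"]
```

In the real codebase the same idea could be applied by patching the three functions where `prepare_qlora_model` looks them up, which avoids changing its signature.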