Good catch, confirmed on my end too.

The call was in the docstring but never actually made it into the code. Classic case of writing the plan and forgetting to execute it.

The fix is straightforward. Add the import, then make the call between model loading and LoRA injection:

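from peft import prepare_model_for_kbit_training

def prepare_qlora_model(model_cfg: ModelConfig, lora_cfg: LoraConfig):
    # QLoRA needs the base model loaded in 4-bit.
    model_cfg.use_4bit = True
    model, tokenizer = load_model_and_tokenizer(model_cfg)
    # Previously missing: prepare the quantized model for training
    # before any adapters are attached.
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
    )
    # Inject LoRA only after the k-bit preparation step.
    model = inject_lora(model, lora_cfg)
    return model, tokenizer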

And yeah the gradient checkpointing thing you flagged…

Answer selected by newtscammander