fix for qwen w lora (#906)

This commit is contained in:
Wing Lian
2023-11-30 12:45:50 -05:00
committed by GitHub
parent 1d21aa6b0a
commit 3e3229e2d9

View File

@@ -412,15 +412,22 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
if not skip_prepare_model_for_kbit_training:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to