Fix for Qwen with LoRA (#906)

This commit is contained in:
Wing Lian
2023-11-30 12:45:50 -05:00
committed by GitHub
parent 1d21aa6b0a
commit 3e3229e2d9

View File

@@ -412,15 +412,22 @@ def load_model(
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA when prepare_model_for_kbit_training is applied, so skip that call
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or ( if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit cfg.adapter == "qlora" and cfg.load_in_4bit
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( if not skip_prepare_model_for_kbit_training:
model, use_gradient_checkpointing=cfg.gradient_checkpointing model = prepare_model_for_kbit_training(
) model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to