diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0d8c812f3..acc6f41fa 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -412,15 +412,22 @@ def load_model(
             module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
+    skip_prepare_model_for_kbit_training = False
+
+    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
+        # Qwen doesn't play nicely with LoRA if this is enabled
+        skip_prepare_model_for_kbit_training = True
+
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
    ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        model = prepare_model_for_kbit_training(
-            model, use_gradient_checkpointing=cfg.gradient_checkpointing
-        )
+        if not skip_prepare_model_for_kbit_training:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
         needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
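
For readers who don't have the surrounding load_model code in front of them, here is a minimal standalone sketch of the control flow this hunk produces. It uses only the config fields and helpers that appear in the diff; the wrapper function name `maybe_prepare_for_kbit` and the bare `cfg` object are illustrative stand-ins, not part of the patch.

```python
from peft import prepare_model_for_kbit_training


def maybe_prepare_for_kbit(model, cfg, load_in_8bit):
    """Illustrative stand-in for the patched section of load_model."""
    # Qwen doesn't play nicely with LoRA when prepare_model_for_kbit_training
    # runs, so the patch short-circuits that single call and keeps the rest.
    skip_prepare = cfg.model_config_type == "qwen" and cfg.adapter == "lora"

    if (cfg.adapter == "lora" and load_in_8bit) or (
        cfg.adapter == "qlora" and cfg.load_in_4bit
    ):
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        if not skip_prepare:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=cfg.gradient_checkpointing
            )
    return model
```

Note the scope of the change: for Qwen + LoRA, gradient checkpointing is still enabled on the model as before; only the `prepare_model_for_kbit_training` call itself is skipped, and the branch still sets `needs_fa2_dtype = True` in the real code.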