diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 261acd934..c95e346e1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,6 +355,7 @@ def load_model(
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
 
+    fix_dtype = False
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -363,16 +364,19 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
+        fix_dtype = True
 
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
+    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if (fix_dtype or cfg.adapter == "" or cfg.adapter is None) and (
+        cfg.flash_attention and cfg.is_llama_derived_model
+    ):
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(cfg.torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, cfg.adapter)