diff --git a/src/axolotl/flash_attn.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
similarity index 100%
rename from src/axolotl/flash_attn.py
rename to src/axolotl/monkeypatch/llama_attn_hijack_flash.py
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c199a16e..253bdcbd8 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -92,7 +92,9 @@ def load_model(
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
@@ -331,6 +333,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -407,14 +419,6 @@ def load_llama_adapter(model, cfg):
     else:
         model = get_peft_model(model, peft_config)
 
-    if cfg.flash_attention:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch.float16)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch.float16)
-
     model.print_trainable_parameters()
 
     return model, peft_config
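
For context, here is a minimal, self-contained sketch of the dtype fix the second hunk applies: `prepare_model_for_kbit_training` upcasts the norm layers to fp32, while flash-attn expects those weights to match the model's fp16/bf16 compute dtype, so `load_model` now re-casts the norm, `lm_head`, and `embed_tokens` modules to the configured `torch_dtype` after k-bit preparation (and hardcoded `torch.float16` is dropped from `load_llama_adapter`). The helper name `cast_norms_and_embeddings` and the toy `nn.ModuleDict` model below are illustrative only, not part of the axolotl codebase.

```python
# Sketch of the dtype re-cast pattern this diff moves into load_model.
# The helper and the toy model here are illustrative, not axolotl code.
import torch
import torch.nn as nn


def cast_norms_and_embeddings(model: nn.Module, torch_dtype: torch.dtype) -> None:
    """Re-cast norm, lm_head, and embed_tokens modules to the training dtype."""
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            # Only cast modules that actually hold weights (skip containers).
            if hasattr(module, "weight"):
                module.to(torch_dtype)


if __name__ == "__main__":
    # Toy stand-in for a Llama-style model: an embedding, a norm, and an lm_head,
    # all in fp32 as prepare_model_for_kbit_training would leave the norms.
    model = nn.ModuleDict(
        {
            "embed_tokens": nn.Embedding(10, 8),
            "input_layernorm": nn.LayerNorm(8),
            "lm_head": nn.Linear(8, 10),
        }
    ).float()

    cast_norms_and_embeddings(model, torch.bfloat16)
    print(model["input_layernorm"].weight.dtype)  # torch.bfloat16
```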