diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 39cfb5c17..ef048082c 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): try: from flash_attn.ops.rms_norm import RMSNorm - LOG.info("patching with flash_attn.ops.rms_norm") - class LlamaRMSNorm(RMSNorm): """Patched LLamaRMSNorm""" def __init__(self, hidden_size, eps=1e-6): super().__init__(hidden_size, eps=eps) + LOG.info("patching with flash_attn.ops.rms_norm") transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm except ImportError: LOG.info(