diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index b0163a655..39cfb5c17 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -58,7 +58,24 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         )
     except ImportError:
         LOG.info(
-            "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)"
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+        )
+
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
         )
 
 
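Note on usage: the class swap above only affects models constructed after the patch runs, because transformers' `LlamaDecoderLayer.__init__` resolves `LlamaRMSNorm` through the `modeling_llama` module namespace at build time; models instantiated earlier keep the stock norm. Below is a minimal sketch of the intended call order, assuming the compiled `dropout_layer_norm` extension is installed (otherwise the ImportError branch leaves the stock norm in place); the checkpoint name is purely illustrative, everything else comes from the patched modules.

    import transformers
    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        replace_llama_attn_with_flash_attn,
    )

    # Apply the monkeypatch *before* building the model so that
    # LlamaDecoderLayer picks up the fused RMSNorm class.
    replace_llama_attn_with_flash_attn(packed=False)

    model = transformers.LlamaForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-hf"  # illustrative checkpoint
    )

    # Sanity check: the final norm layer should now be an instance of the
    # patched (flash-attn fused) LlamaRMSNorm.
    patched_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
    assert isinstance(model.model.norm, patched_cls)

The shim subclass in the diff exists only to pin the Hugging Face constructor signature `(hidden_size, eps=1e-6)` onto `flash_attn.ops.rms_norm.RMSNorm`, so call sites inside `modeling_llama` need no changes.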