From 72a6fe1c1f1916651cc34b80cf0a94984775b442 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Mon, 4 Sep 2023 19:44:51 -0400 Subject: [PATCH] use flash_attn rmsnorm when available (#526) * use flash_attn xentropy when available * use flash_attn.ops.rms_norm when available * log when xentropy is not found * log how to install RMSNorm * add quotes so pip install works --- .../monkeypatch/llama_attn_hijack_flash.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index b0163a655..39cfb5c17 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -58,7 +58,24 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False): ) except ImportError: LOG.info( - "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)" + "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)" + ) + + try: + from flash_attn.ops.rms_norm import RMSNorm + + LOG.info("patching with flash_attn.ops.rms_norm") + + class LlamaRMSNorm(RMSNorm): + """Patched LLamaRMSNorm""" + + def __init__(self, hidden_size, eps=1e-6): + super().__init__(hidden_size, eps=eps) + + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + except ImportError: + LOG.info( + "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" )