From bf289123e9dbccd908db492096fba6a76c4514f4 Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 7 Dec 2023 19:48:40 +0100 Subject: [PATCH] Replace with flash modules --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index d347387cf..b0e053242 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -25,8 +25,11 @@ from transformers.models.llama.modeling_llama import ( from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP -from axolotl.monkeypatch.flash_modules import flashattn_forward - +from axolotl.monkeypatch.flash_modules import ( + flashattn_forward, + replace_cross_entropy, + replace_rms_norm +) LOG = logging.getLogger("axolotl") @@ -69,6 +72,10 @@ def replace_llama_attn_with_flash_attn( transformers.models.llama.modeling_llama.LlamaModel.forward = ( llama_model_forward ) + if cross_entropy: + replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss") + if rms_norm: + replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm") # Disable the transformation of the attention mask in LlamaModel as the flash attention