From 5302d2d5348776ffd667b4828226e80beafc4e98 Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 7 Dec 2023 17:26:06 +0100 Subject: [PATCH] Reworked Flash Attention ops: xentropy, rmsnorm --- .../{flash_module.py => flash_modules.py} | 45 ++++++++++++++++++- .../{fused_module.py => fused_modules.py} | 0 .../monkeypatch/llama_attn_hijack_flash.py | 37 +-------------- 3 files changed, 45 insertions(+), 37 deletions(-) rename src/axolotl/monkeypatch/{flash_module.py => flash_modules.py} (89%) rename src/axolotl/monkeypatch/{fused_module.py => fused_modules.py} (100%) diff --git a/src/axolotl/monkeypatch/flash_module.py b/src/axolotl/monkeypatch/flash_modules.py similarity index 89% rename from src/axolotl/monkeypatch/flash_module.py rename to src/axolotl/monkeypatch/flash_modules.py index 443523836..42e5345c0 100644 --- a/src/axolotl/monkeypatch/flash_module.py +++ b/src/axolotl/monkeypatch/flash_modules.py @@ -2,11 +2,11 @@ import torch import logging import warnings from einops import rearrange +from functools import partial import torch.nn.functional as F from typing import Optional, Tuple from flash_attn.bert_padding import pad_input, unpad_input - -from axolotl.monkeypatch.fused_module import FusedAttention +from axolotl.monkeypatch.fused_modules import FusedAttention try: from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports @@ -383,3 +383,44 @@ def generate_qkv( v, output_pad_fn, ) + +def replace_cross_entropy(modeling_class, module_name): + """ + modeling_class: transformers.models.llama.modeling_ + module_name: CrossEntropyLoss + """ + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + + cross_entropy_loss = partial( + CrossEntropyLoss, inplace_backward=True + ) + + setattr(modeling_class, module_name, cross_entropy_loss) + + except ImportError: + LOG.info( + "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)" + ) + +def replace_rms_norm(modeling_class, module_name): + """ + modeling_class: transformers.models.llama.modeling_ + module_name: RMSNorm + """ + try: + from flash_attn.ops.rms_norm import RMSNorm + + class FlashRMSNorm(RMSNorm): + """A faster RMS Norm.""" + def __init__(self, hidden_size, eps=1e-6): + super().__init__(hidden_size, eps=eps) + + LOG.info("patching with flash_attn.ops.rms_norm") + setattr(modeling_class, module_name, FlashRMSNorm) + except ImportError: + LOG.info( + "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" + ) \ No newline at end of file diff --git a/src/axolotl/monkeypatch/fused_module.py b/src/axolotl/monkeypatch/fused_modules.py similarity index 100% rename from src/axolotl/monkeypatch/fused_module.py rename to src/axolotl/monkeypatch/fused_modules.py diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 92a383b1b..d347387cf 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -24,8 +24,8 @@ from transformers.models.llama.modeling_llama import ( ) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name -from axolotl.monkeypatch.fused_module import FusedAttention, FusedMLP -from axolotl.monkeypatch.flash_module import flashattn_forward +from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP +from axolotl.monkeypatch.flash_modules import flashattn_forward LOG = logging.getLogger("axolotl") @@ -70,39 +70,6 @@ def replace_llama_attn_with_flash_attn( llama_model_forward ) - # skip only if explicitly disabled - if cross_entropy: - try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss - - LOG.info("patching with flash_attn.losses.cross_entropy") - transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( - CrossEntropyLoss, inplace_backward=True - ) - except ImportError: - LOG.info( - "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)" - ) - - # skip only if explicitly disabled - if rms_norm: - try: - from flash_attn.ops.rms_norm import RMSNorm - - class LlamaRMSNorm(RMSNorm): - """Patched LLamaRMSNorm""" - - def __init__(self, hidden_size, eps=1e-6): - super().__init__(hidden_size, eps=eps) - - LOG.info("patching with flash_attn.ops.rms_norm") - transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm - except ImportError: - LOG.info( - "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" - ) - - # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask