Reworked Flash Attention ops: xentropy, rmsnorm

This commit is contained in:
Casper
2023-12-07 17:26:06 +01:00
parent a6fefa8885
commit 5302d2d534
3 changed files with 45 additions and 37 deletions

View File

@@ -2,11 +2,11 @@ import torch
import logging
import warnings
from einops import rearrange
from functools import partial
import torch.nn.functional as F
from typing import Optional, Tuple
from flash_attn.bert_padding import pad_input, unpad_input
from axolotl.monkeypatch.fused_module import FusedAttention
from axolotl.monkeypatch.fused_modules import FusedAttention
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -383,3 +383,44 @@ def generate_qkv(
v,
output_pad_fn,
)
def replace_cross_entropy(modeling_class, module_name):
"""
modeling_class: transformers.models.llama.modeling_<class>
module_name: CrossEntropyLoss
"""
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
cross_entropy_loss = partial(
CrossEntropyLoss, inplace_backward=True
)
setattr(modeling_class, module_name, cross_entropy_loss)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
def replace_rms_norm(modeling_class, module_name):
"""
modeling_class: transformers.models.llama.modeling_<class>
module_name: RMSNorm
"""
try:
from flash_attn.ops.rms_norm import RMSNorm
class FlashRMSNorm(RMSNorm):
"""A faster RMS Norm."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
setattr(modeling_class, module_name, FlashRMSNorm)
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)

View File

@@ -24,8 +24,8 @@ from transformers.models.llama.modeling_llama import (
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.monkeypatch.fused_module import FusedAttention, FusedMLP
from axolotl.monkeypatch.flash_module import flashattn_forward
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
from axolotl.monkeypatch.flash_modules import flashattn_forward
LOG = logging.getLogger("axolotl")
@@ -70,39 +70,6 @@ def replace_llama_attn_with_flash_attn(
llama_model_forward
)
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask