Reworked Flash Attention ops: xentropy, rmsnorm
This commit is contained in:
@@ -2,11 +2,11 @@ import torch
|
||||
import logging
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
from axolotl.monkeypatch.fused_module import FusedAttention
|
||||
from axolotl.monkeypatch.fused_modules import FusedAttention
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
@@ -383,3 +383,44 @@ def generate_qkv(
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
def replace_cross_entropy(modeling_class, module_name):
|
||||
"""
|
||||
modeling_class: transformers.models.llama.modeling_<class>
|
||||
module_name: CrossEntropyLoss
|
||||
"""
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
|
||||
cross_entropy_loss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
|
||||
setattr(modeling_class, module_name, cross_entropy_loss)
|
||||
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
def replace_rms_norm(modeling_class, module_name):
|
||||
"""
|
||||
modeling_class: transformers.models.llama.modeling_<class>
|
||||
module_name: RMSNorm
|
||||
"""
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class FlashRMSNorm(RMSNorm):
|
||||
"""A faster RMS Norm."""
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
setattr(modeling_class, module_name, FlashRMSNorm)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
@@ -24,8 +24,8 @@ from transformers.models.llama.modeling_llama import (
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from axolotl.monkeypatch.fused_module import FusedAttention, FusedMLP
|
||||
from axolotl.monkeypatch.flash_module import flashattn_forward
|
||||
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
|
||||
from axolotl.monkeypatch.flash_modules import flashattn_forward
|
||||
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -70,39 +70,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
|
||||
Reference in New Issue
Block a user