Reworked Flash Attention ops: xentropy, rmsnorm
This commit is contained in:
@@ -2,11 +2,11 @@ import torch
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from functools import partial
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from axolotl.monkeypatch.fused_modules import FusedAttention
|
||||||
from axolotl.monkeypatch.fused_module import FusedAttention
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
@@ -383,3 +383,44 @@ def generate_qkv(
|
|||||||
v,
|
v,
|
||||||
output_pad_fn,
|
output_pad_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def replace_cross_entropy(modeling_class, module_name):
|
||||||
|
"""
|
||||||
|
modeling_class: transformers.models.llama.modeling_<class>
|
||||||
|
module_name: CrossEntropyLoss
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
|
||||||
|
cross_entropy_loss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(modeling_class, module_name, cross_entropy_loss)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def replace_rms_norm(modeling_class, module_name):
|
||||||
|
"""
|
||||||
|
modeling_class: transformers.models.llama.modeling_<class>
|
||||||
|
module_name: RMSNorm
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class FlashRMSNorm(RMSNorm):
|
||||||
|
"""A faster RMS Norm."""
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
|
setattr(modeling_class, module_name, FlashRMSNorm)
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
|
)
|
||||||
@@ -24,8 +24,8 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
from axolotl.monkeypatch.fused_module import FusedAttention, FusedMLP
|
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
|
||||||
from axolotl.monkeypatch.flash_module import flashattn_forward
|
from axolotl.monkeypatch.flash_modules import flashattn_forward
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -70,39 +70,6 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
|
||||||
if cross_entropy:
|
|
||||||
try:
|
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
|
||||||
if rms_norm:
|
|
||||||
try:
|
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
|
||||||
"""Patched LLamaRMSNorm"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__(hidden_size, eps=eps)
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
|
|||||||
Reference in New Issue
Block a user