batch api HF adapter for ring-flash-attn; cleanup and improvements

This commit is contained in:
Dan Saunders
2025-04-11 03:45:34 +00:00
parent 3f1873cc62
commit 2bb5c1fe7e
3 changed files with 192 additions and 2 deletions

View File

@@ -0,0 +1,175 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""
# pylint: disable=protected-access
import os
from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
ALL_ATTENTION_FUNCTIONS = None
def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Callable:
"""
Create a ring flash attention forward function compatible with HuggingFace's interface.
Args:
process_group: A PyTorch distributed process group that defines the communication
topology for the ring attention pattern.
Returns:
A function that implements the ring flash attention forward pass with the
signature expected by HuggingFace Transformers.
"""
# transformers 4.48+
# pylint: disable=unused-argument
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: torch.Tensor | None = None,
softmax_scale: float | None = None,
sliding_window: int | None = None,
use_top_left_mask: bool = False,
softcap: float | None = None,
deterministic: bool = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
**kwargs,
):
"""
Calls the forward method of Ring Flash Attention.
Args:
query_states: Tensor containing the query vectors.
key_states: Tensor containing the key vectors.
value_states: Tensor containing the value vectors.
attention_mask: Not used in this implementation.
query_length: Integer representing the length of the query sequence.
is_causal: Boolean indicating whether to apply a causal mask to the attention.
dropout: Float representing the dropout probability. Default is 0.0.
position_ids: Not used in this implementation.
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
sliding_window: Optional integer defining the size of the sliding attention window.
Default is None.
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
Default is False.
softcap: Not used in this implementation.
deterministic: Optional boolean to enforce deterministic computation. Default is None.
cu_seq_lens_q: Not used in this implementation.
cu_seq_lens_k: Not used in this implementation.
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
torch.Tensor: The output of the attention mechanism, with shape
`[batch_size, query_length, num_heads, head_dim]`.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
window_size = (
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
)
# Handle deterministic mode
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = (
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
# Call ring flash attention function
attn_output = ring_flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
group=process_group,
)
return attn_output
return _flash_attention_forward
# pylint: disable=unused-argument
def substitute_hf_flash_attn(process_group: dist.ProcessGroup):
"""
Substitute HuggingFace's flash attention implementation with ring-based implementation.
Args:
process_group: PyTorch distributed process group for communication.
"""
try:
# Substitute flash attention
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_ring_flash_attention_forward(process_group)
if check_params(old_flash_attention_forward, new_flash_attention_forward):
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
else:
raise ValueError(
"The signature of the new flash attention forward function does not match the old one."
)
except Exception as exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"Please use pip install -U transformers to upgrade to the latest version. "
"If the code failed with the latest version, "
f"please file an issue."
) from exception
# Register with ALL_ATTENTION_FUNCTIONS if available
if ALL_ATTENTION_FUNCTIONS is not None:
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward

View File

@@ -76,7 +76,8 @@ def register_ring_attn(
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
f"each sequence will be processed across {sequence_parallel_degree} GPUs "
f"using the {ring_attn_func.value} ring-flash-attn implementation"
)
rank = dist.get_rank()

View File

@@ -647,7 +647,21 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Set the ring attention function if passed in config
ring_attn_func = None
if self.cfg.ring_attn_func:
valid_funcs = [enum.value for enum in RingAttnFunc]
if self.cfg.ring_attn_func in valid_funcs:
ring_attn_func = RingAttnFunc(self.cfg.ring_attn_func)
else:
LOG.warning(
f"ring_attn_func: {self.cfg.ring_attn_func} must be one of {valid_funcs}"
)
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash