batch api HF adapter for ring-flash-attn; cleanup and improvements
@@ -0,0 +1,175 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).

Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""

# pylint: disable=protected-access

import os
from typing import Callable

import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
    _flash_supports_window_size,
    is_flash_attn_greater_or_equal,
)

try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
    ALL_ATTENTION_FUNCTIONS = None


def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Callable:
    """
    Create a ring flash attention forward function compatible with HuggingFace's interface.

    Args:
        process_group: A PyTorch distributed process group that defines the communication
            topology for the ring attention pattern.

    Returns:
        A function that implements the ring flash attention forward pass with the
        signature expected by HuggingFace Transformers.
    """

    # transformers 4.48+
    # pylint: disable=unused-argument
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: torch.Tensor | None = None,
        softmax_scale: float | None = None,
        sliding_window: int | None = None,
        use_top_left_mask: bool = False,
        softcap: float | None = None,
        deterministic: bool | None = None,
        cu_seq_lens_q: torch.LongTensor | None = None,
        cu_seq_lens_k: torch.LongTensor | None = None,
        max_length_q: int | None = None,
        max_length_k: int | None = None,
        target_dtype: torch.dtype | None = None,
        **kwargs,
    ):
        """
        Calls the forward method of Ring Flash Attention.

        Args:
            query_states: Tensor containing the query vectors.
            key_states: Tensor containing the key vectors.
            value_states: Tensor containing the value vectors.
            attention_mask: Not used in this implementation.
            query_length: Integer representing the length of the query sequence.
            is_causal: Boolean indicating whether to apply a causal mask to the attention.
            dropout: Float representing the dropout probability. Default is 0.0.
            position_ids: Not used in this implementation.
            softmax_scale: Optional float value for the softmax scaling factor. Default is None.
            sliding_window: Optional integer defining the size of the sliding attention window.
                Default is None.
            use_top_left_mask: Boolean indicating whether to use a top-left mask for the
                attention. Default is False.
            softcap: Not used in this implementation.
            deterministic: Optional boolean to enforce deterministic computation. Default is None.
            cu_seq_lens_q: Not used in this implementation.
            cu_seq_lens_k: Not used in this implementation.
            max_length_q: Not used in this implementation.
            max_length_k: Not used in this implementation.
            target_dtype: Not used in this implementation.
            **kwargs: Additional keyword arguments. Not used in this implementation.

        Returns:
            torch.Tensor: The output of the attention mechanism, with shape
                `[batch_size, query_length, num_heads, head_dim]`.
        """
        if not use_top_left_mask:
            causal = is_causal
        else:
            # With a top-left-aligned causal mask, a single-token (decoding) query
            # must not be causally masked.
            causal = is_causal and query_length != 1

        # Handle sliding window; (-1, -1) means an unbounded attention window
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        window_size = (
            (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
        )

        # Handle deterministic mode; flash-attn >= 2.4.1 supports deterministic
        # backward passes, toggled via the FLASH_ATTENTION_DETERMINISTIC env var
        if is_flash_attn_greater_or_equal("2.4.1"):
            if deterministic is None:
                deterministic = (
                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
                )

        # Call ring flash attention function
        attn_output = ring_flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=None,
            deterministic=deterministic,
            return_attn_probs=False,
            group=process_group,
        )

        return attn_output

    return _flash_attention_forward
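
# Note: the factory above binds `process_group` in a closure, so callers can
# create distinct forwards for distinct sequence-parallel groups without
# threading the group through HuggingFace's call sites. A minimal sketch,
# where `sp_group_a` / `sp_group_b` are hypothetical groups from dist.new_group:
#
#   forward_a = create_ring_flash_attention_forward(sp_group_a)
#   forward_b = create_ring_flash_attention_forward(sp_group_b)
#
# Both share the signature transformers expects of _flash_attention_forward,
# while the ring KV exchange stays inside each bound group.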


# pylint: disable=unused-argument
def substitute_hf_flash_attn(process_group: dist.ProcessGroup):
    """
    Substitute HuggingFace's flash attention implementation with a ring-based implementation.

    Args:
        process_group: PyTorch distributed process group for communication.
    """
    try:
        # Substitute flash attention
        old_flash_attention_forward = (
            transformers.modeling_flash_attention_utils._flash_attention_forward
        )
        new_flash_attention_forward = create_ring_flash_attention_forward(process_group)

        if check_params(old_flash_attention_forward, new_flash_attention_forward):
            transformers.modeling_flash_attention_utils._flash_attention_forward = (
                new_flash_attention_forward
            )
        else:
            raise ValueError(
                "The signature of the new flash attention forward function does not "
                "match the old one."
            )
    except Exception as exception:
        raise ValueError(
            f"The current transformers version {transformers.__version__} is not supported. "
            "Please run `pip install -U transformers` to upgrade to the latest version. "
            "If the code fails with the latest version, please file an issue."
        ) from exception

    # Register with ALL_ATTENTION_FUNCTIONS if available
    if ALL_ATTENTION_FUNCTIONS is not None:
        from ring_flash_attn.adapters.hf_adapter import flash_attention_forward

        ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
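
A usage sketch, not part of this commit: assuming torch.distributed has been
initialized under a launcher such as torchrun, the patch is applied once per
process before the first forward pass. The group construction and the model id
below are illustrative assumptions, not this module's API.

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

# Illustrative: carve the world into ring groups of 4 ranks each. In axolotl
# this wiring is handled by register_ring_attn (see below).
degree = 4
ring_group = None
for start in range(0, world_size, degree):
    ranks = list(range(start, start + degree))
    group = dist.new_group(ranks)  # every rank must call new_group for every group
    if rank in ranks:
        ring_group = group

substitute_hf_flash_attn(ring_group)

# Load the model with flash attention so the patched path is exercised.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative model id
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)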

@@ -76,7 +76,8 @@ def register_ring_attn(

     LOG.info(
         "Enabling ring attention sequence parallelism: "
-        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
+        f"each sequence will be processed across {sequence_parallel_degree} GPUs "
+        f"using the {ring_attn_func.value} ring-flash-attn implementation"
     )

     rank = dist.get_rank()
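
The added trailing space matters because adjacent f-string literals concatenate
with no separator; before this fix the two fragments ran together. For example,
with illustrative values standing in for sequence_parallel_degree and
ring_attn_func.value:

degree, func = 4, "batch_ring"  # "batch_ring" is a hypothetical enum value
msg = (
    f"each sequence will be processed across {degree} GPUs"
    f"using the {func} ring-flash-attn implementation"
)
print(msg)  # '...across 4 GPUsusing the batch_ring...'; note the missing space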

@@ -647,7 +647,21 @@ class ModelLoader:
         patch_self_attn_lora(self.cfg)

         if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
-            from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
+            from axolotl.monkeypatch.attention.ring_attn import (
+                RingAttnFunc,
+                register_ring_attn,
+            )
+
+            # Set the ring attention function if passed in config
+            ring_attn_func = None
+            if self.cfg.ring_attn_func:
+                valid_funcs = [enum.value for enum in RingAttnFunc]
+                if self.cfg.ring_attn_func in valid_funcs:
+                    ring_attn_func = RingAttnFunc(self.cfg.ring_attn_func)
+                else:
+                    LOG.warning(
+                        f"ring_attn_func: {self.cfg.ring_attn_func} must be one of {valid_funcs}"
+                    )

             # Initialize ring attn for sequence parallelism. This must be done after
             # model init but before the first forward pass, since it modifies flash
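
For context, a minimal sketch of the string-to-enum validation used above. The
RingAttnFunc member names here are hypothetical stand-ins, since the enum's
definition is outside this diff, and print stands in for LOG.warning:

from enum import Enum

class RingAttnFunc(str, Enum):
    # Hypothetical members; the real enum lives in
    # axolotl.monkeypatch.attention.ring_attn
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"

def coerce_ring_attn_func(raw: str | None) -> RingAttnFunc | None:
    """Validate a raw config string against RingAttnFunc, as the loader does."""
    if not raw:
        return None
    valid_funcs = [member.value for member in RingAttnFunc]
    if raw in valid_funcs:
        return RingAttnFunc(raw)
    print(f"ring_attn_func: {raw} must be one of {valid_funcs}")
    return None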