adding all batch ring-flash-attn methods via single adapter
This commit is contained in:
@@ -1,174 +0,0 @@
|
|||||||
"""
|
|
||||||
HuggingFace flash attention adapter for basic ring attention (batch API).
|
|
||||||
|
|
||||||
Inspired by
|
|
||||||
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
|
|
||||||
Our implementation closely follows the structure of that module, but we've minified it
|
|
||||||
somewhat to support only the latest versions of transformers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import transformers
|
|
||||||
import transformers.modeling_flash_attention_utils
|
|
||||||
from ring_flash_attn import ring_flash_attn_func
|
|
||||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
|
||||||
from transformers.modeling_flash_attention_utils import (
|
|
||||||
_flash_supports_window_size,
|
|
||||||
is_flash_attn_greater_or_equal,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
|
||||||
except ImportError:
|
|
||||||
ALL_ATTENTION_FUNCTIONS = None
|
|
||||||
|
|
||||||
|
|
||||||
def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Callable:
|
|
||||||
"""
|
|
||||||
Create a ring flash attention forward function compatible with HuggingFace's interface.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
process_group: A PyTorch distributed process group.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A function that implements the ring flash attention forward pass with the
|
|
||||||
signature expected by HuggingFace Transformers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# transformers 4.48+
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def _flash_attention_forward(
|
|
||||||
query_states: torch.Tensor,
|
|
||||||
key_states: torch.Tensor,
|
|
||||||
value_states: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
query_length: int,
|
|
||||||
is_causal: bool,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
position_ids: torch.Tensor | None = None,
|
|
||||||
softmax_scale: float | None = None,
|
|
||||||
sliding_window: int | None = None,
|
|
||||||
use_top_left_mask: bool = False,
|
|
||||||
softcap: float | None = None,
|
|
||||||
deterministic: bool = None,
|
|
||||||
cu_seq_lens_q: torch.LongTensor | None = None,
|
|
||||||
cu_seq_lens_k: torch.LongTensor | None = None,
|
|
||||||
max_length_q: int | None = None,
|
|
||||||
max_length_k: int | None = None,
|
|
||||||
target_dtype: torch.dtype | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Calls the forward method of Ring Flash Attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_states: Tensor containing the query vectors.
|
|
||||||
key_states: Tensor containing the key vectors.
|
|
||||||
value_states: Tensor containing the value vectors.
|
|
||||||
attention_mask: Not used in this implementation.
|
|
||||||
query_length: Integer representing the length of the query sequence.
|
|
||||||
is_causal: Boolean indicating whether to apply a causal mask to the attention.
|
|
||||||
dropout: Float representing the dropout probability. Default is 0.0.
|
|
||||||
position_ids: Not used in this implementation.
|
|
||||||
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
|
|
||||||
sliding_window: Optional integer defining the size of the sliding attention window.
|
|
||||||
Default is None.
|
|
||||||
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
|
|
||||||
Default is False.
|
|
||||||
softcap: Not used in this implementation.
|
|
||||||
deterministic: Optional boolean to enforce deterministic computation. Default is None.
|
|
||||||
cu_seq_lens_q: Not used in this implementation.
|
|
||||||
cu_seq_lens_k: Not used in this implementation.
|
|
||||||
max_length_q: Not used in this implementation.
|
|
||||||
max_length_k: Not used in this implementation.
|
|
||||||
target_dtype: Not used in this implementation.
|
|
||||||
**kwargs: Additional keyword arguments. Not used in this implementation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The output of the attention mechanism, with shape
|
|
||||||
`[batch_size, query_length, num_heads, head_dim]`.
|
|
||||||
"""
|
|
||||||
if not use_top_left_mask:
|
|
||||||
causal = is_causal
|
|
||||||
else:
|
|
||||||
causal = is_causal and query_length != 1
|
|
||||||
|
|
||||||
# Handle sliding window
|
|
||||||
use_sliding_windows = (
|
|
||||||
_flash_supports_window_size
|
|
||||||
and sliding_window is not None
|
|
||||||
and key_states.shape[1] > sliding_window
|
|
||||||
)
|
|
||||||
window_size = (
|
|
||||||
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle deterministic mode
|
|
||||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
|
||||||
if deterministic is None:
|
|
||||||
deterministic = (
|
|
||||||
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call ring flash attention function
|
|
||||||
attn_output = ring_flash_attn_func(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
dropout_p=dropout,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=causal,
|
|
||||||
window_size=window_size,
|
|
||||||
alibi_slopes=None,
|
|
||||||
deterministic=deterministic,
|
|
||||||
return_attn_probs=False,
|
|
||||||
group=process_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
return _flash_attention_forward
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def substitute_hf_flash_attn(process_group: dist.ProcessGroup):
    """
    Swap HuggingFace's flash attention forward for the ring attention variant.

    The module-level `_flash_attention_forward` in
    `transformers.modeling_flash_attention_utils` is replaced by the closure built
    by `create_ring_flash_attention_forward`, provided `check_params` accepts the
    pair of signatures. When `ALL_ATTENTION_FUNCTIONS` could be imported, its
    `"flash_attention_2"` entry is additionally pointed at ring_flash_attn's
    `flash_attention_forward`.

    Args:
        process_group: PyTorch distributed process group for communication.

    Raises:
        ValueError: If anything goes wrong while patching (including a signature
            mismatch); the underlying exception is chained as the cause.
    """
    try:
        # Substitute flash attention
        hf_flash_utils = transformers.modeling_flash_attention_utils
        original_forward = hf_flash_utils._flash_attention_forward
        ring_forward = create_ring_flash_attention_forward(process_group)

        # Guard clause: refuse to patch when the two signatures disagree.
        if not check_params(original_forward, ring_forward):
            raise ValueError(
                "The signature of the new flash attention forward function does not match the old one."
            )

        hf_flash_utils._flash_attention_forward = ring_forward
    except Exception as exception:
        message = (
            f"The current transformer version {transformers.__version__} is not supported. "
            "Please use pip install -U transformers to upgrade to the latest version. "
            "If the code failed with the latest version, "
            "please file an issue."
        )
        raise ValueError(message) from exception

    # Register with ALL_ATTENTION_FUNCTIONS if available
    if ALL_ATTENTION_FUNCTIONS is None:
        return

    from ring_flash_attn.adapters.hf_adapter import flash_attention_forward

    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
|
||||||
@@ -57,7 +57,12 @@ class RingAttnFunc(str, Enum):
|
|||||||
def register_ring_attn(
|
def register_ring_attn(
|
||||||
sequence_parallel_degree: int,
|
sequence_parallel_degree: int,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
|
<<<<<<< HEAD
|
||||||
ring_attn_func: RingAttnFunc | None,
|
ring_attn_func: RingAttnFunc | None,
|
||||||
|
=======
|
||||||
|
sample_packing: bool,
|
||||||
|
ring_attn_func: str | None,
|
||||||
|
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
@@ -120,6 +125,10 @@ def register_ring_attn(
|
|||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||||
)
|
)
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
# TODO(djsaunde): handle other ring attn funcs in this branch
|
||||||
|
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
||||||
elif ring_attn_func in [
|
elif ring_attn_func in [
|
||||||
RingAttnFunc.BATCH_RING,
|
RingAttnFunc.BATCH_RING,
|
||||||
RingAttnFunc.BATCH_ZIGZAG,
|
RingAttnFunc.BATCH_ZIGZAG,
|
||||||
|
|||||||
@@ -647,21 +647,7 @@ class ModelLoader:
|
|||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
RingAttnFunc,
|
|
||||||
register_ring_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the ring attention function if passed in config
|
|
||||||
ring_attn_func = None
|
|
||||||
if self.cfg.ring_attn_func:
|
|
||||||
valid_funcs = [enum.value for enum in RingAttnFunc]
|
|
||||||
if self.cfg.ring_attn_func in valid_funcs:
|
|
||||||
ring_attn_func = RingAttnFunc(self.cfg.ring_attn_func)
|
|
||||||
else:
|
|
||||||
LOG.warning(
|
|
||||||
f"ring_attn_func: {self.cfg.ring_attn_func} must be one of {valid_funcs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize ring attn for sequence parallelism. This must be done after
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
# model init but before the first forward pass, since it modifies flash
|
# model init but before the first forward pass, since it modifies flash
|
||||||
|
|||||||
Reference in New Issue
Block a user