From 4ae8df16a968ffe5148158818ea15a3fa664a738 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 11 Apr 2025 05:08:08 +0000 Subject: [PATCH] adding all batch ring-flash-attn methods via single adapter --- .../ring_attn/adapters/batch_ring.py | 174 ------------------ .../monkeypatch/attention/ring_attn/patch.py | 9 + src/axolotl/utils/models.py | 16 +- 3 files changed, 10 insertions(+), 189 deletions(-) delete mode 100644 src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py deleted file mode 100644 index 91f1d0c67..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -HuggingFace flash attention adapter for basic ring attention (batch API). - -Inspired by -https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py. -Our implementation closely follows the structure of that module, but we've minified it -somewhat to support only the latest versions of transformers. -""" - -# pylint: disable=protected-access - -import os -from typing import Callable - -import torch -import torch.distributed as dist -import transformers -import transformers.modeling_flash_attention_utils -from ring_flash_attn import ring_flash_attn_func -from ring_flash_attn.adapters.hf_adapter import check_params -from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size, - is_flash_attn_greater_or_equal, -) - -try: - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -except ImportError: - ALL_ATTENTION_FUNCTIONS = None - - -def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Callable: - """ - Create a ring flash attention forward function compatible with HuggingFace's interface. - - Args: - process_group: A PyTorch distributed process group. - - Returns: - A function that implements the ring flash attention forward pass with the - signature expected by HuggingFace Transformers. - """ - - # transformers 4.48+ - # pylint: disable=unused-argument - def _flash_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - is_causal: bool, - dropout: float = 0.0, - position_ids: torch.Tensor | None = None, - softmax_scale: float | None = None, - sliding_window: int | None = None, - use_top_left_mask: bool = False, - softcap: float | None = None, - deterministic: bool = None, - cu_seq_lens_q: torch.LongTensor | None = None, - cu_seq_lens_k: torch.LongTensor | None = None, - max_length_q: int | None = None, - max_length_k: int | None = None, - target_dtype: torch.dtype | None = None, - **kwargs, - ): - """ - Calls the forward method of Ring Flash Attention. - - Args: - query_states: Tensor containing the query vectors. - key_states: Tensor containing the key vectors. - value_states: Tensor containing the value vectors. - attention_mask: Not used in this implementation. - query_length: Integer representing the length of the query sequence. - is_causal: Boolean indicating whether to apply a causal mask to the attention. - dropout: Float representing the dropout probability. Default is 0.0. - position_ids: Not used in this implementation. - softmax_scale: Optional float value for the softmax scaling factor. Default is None. 
- sliding_window: Optional integer defining the size of the sliding attention window. - Default is None. - use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention. - Default is False. - softcap: Not used in this implementation. - deterministic: Optional boolean to enforce deterministic computation. Default is None. - cu_seq_lens_q: Not used in this implementation. - cu_seq_lens_k: Not used in this implementation. - max_length_q: Not used in this implementation. - max_length_k: Not used in this implementation. - target_dtype: Not used in this implementation. - **kwargs: Additional keyword arguments. Not used in this implementation. - - Returns: - torch.Tensor: The output of the attention mechanism, with shape - `[batch_size, query_length, num_heads, head_dim]`. - """ - if not use_top_left_mask: - causal = is_causal - else: - causal = is_causal and query_length != 1 - - # Handle sliding window - use_sliding_windows = ( - _flash_supports_window_size - and sliding_window is not None - and key_states.shape[1] > sliding_window - ) - window_size = ( - (sliding_window, sliding_window) if use_sliding_windows else (-1, -1) - ) - - # Handle deterministic mode - if is_flash_attn_greater_or_equal("2.4.1"): - if deterministic is None: - deterministic = ( - os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - ) - - # Call ring flash attention function - attn_output = ring_flash_attn_func( - query_states, - key_states, - value_states, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=None, - deterministic=deterministic, - return_attn_probs=False, - group=process_group, - ) - - return attn_output - - return _flash_attention_forward - - -# pylint: disable=unused-argument -def substitute_hf_flash_attn(process_group: dist.ProcessGroup): - """ - Substitute HuggingFace's flash attention implementation with ring-based implementation. - - Args: - process_group: PyTorch distributed process group for communication. - """ - try: - # Substitute flash attention - old_flash_attention_forward = ( - transformers.modeling_flash_attention_utils._flash_attention_forward - ) - new_flash_attention_forward = create_ring_flash_attention_forward(process_group) - - if check_params(old_flash_attention_forward, new_flash_attention_forward): - transformers.modeling_flash_attention_utils._flash_attention_forward = ( - new_flash_attention_forward - ) - else: - raise ValueError( - "The signature of the new flash attention forward function does not match the old one." - ) - except Exception as exception: - raise ValueError( - f"The current transformer version {transformers.__version__} is not supported. " - "Please use pip install -U transformers to upgrade to the latest version. " - "If the code failed with the latest version, " - f"please file an issue." 
- ) from exception - - # Register with ALL_ATTENTION_FUNCTIONS if available - if ALL_ATTENTION_FUNCTIONS is not None: - from ring_flash_attn.adapters.hf_adapter import flash_attention_forward - - ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py index 160e57052..36d58c510 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py @@ -57,7 +57,12 @@ class RingAttnFunc(str, Enum): def register_ring_attn( sequence_parallel_degree: int, heads_k_stride: int | None, +<<<<<<< HEAD ring_attn_func: RingAttnFunc | None, +======= + sample_packing: bool, + ring_attn_func: str | None, +>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter) ): """ Create ring attention group and substitute flash attn with ring flash attn. @@ -120,6 +125,10 @@ def register_ring_attn( substitute_hf_flash_attn( process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 ) +<<<<<<< HEAD +======= + # TODO(djsaunde): handle other ring attn funcs in this branch +>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter) elif ring_attn_func in [ RingAttnFunc.BATCH_RING, RingAttnFunc.BATCH_ZIGZAG, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9f4893974..d7105daba 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -647,21 +647,7 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import ( - RingAttnFunc, - register_ring_attn, - ) - - # Set the ring attention function if passed in config - ring_attn_func = None - if self.cfg.ring_attn_func: - valid_funcs = [enum.value for enum in RingAttnFunc] - if self.cfg.ring_attn_func in valid_funcs: - ring_attn_func = RingAttnFunc(self.cfg.ring_attn_func) - else: - LOG.warning( - f"ring_attn_func: {self.cfg.ring_attn_func} must be one of {valid_funcs}" - ) + from axolotl.monkeypatch.attention.ring_attn import register_ring_attn # Initialize ring attn for sequence parallelism. This must be done after # model init but before the first forward pass, since it modifies flash
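
The consolidated adapter named in the subject line is not included in this diff, which only removes the old batch-API adapter module. As a rough sketch of what a single adapter covering the batch ring-flash-attn variants could look like, the code below maps each batch variant to the corresponding function exported by the ring_flash_attn package (ring_flash_attn_func, zigzag_ring_flash_attn_func, stripe_flash_attn_func). The factory name, the batch_stripe entry, and the dispatch logic are assumptions for illustration, not code taken from this commit.

    # Hedged sketch only -- not the adapter added by this commit.
    # Assumes the batch-API functions exported by ring_flash_attn and the
    # RingAttnFunc values referenced in patch.py (the batch_stripe mapping
    # is an assumption).
    from typing import Callable

    import torch.distributed as dist
    from ring_flash_attn import (
        ring_flash_attn_func,
        stripe_flash_attn_func,
        zigzag_ring_flash_attn_func,
    )

    # One lookup table instead of one adapter module per variant.
    BATCH_RING_ATTN_FUNCS = {
        "batch_ring": ring_flash_attn_func,
        "batch_zigzag": zigzag_ring_flash_attn_func,
        "batch_stripe": stripe_flash_attn_func,
    }


    def create_ring_flash_attention_forward(
        process_group: dist.ProcessGroup, ring_attn_func: str
    ) -> Callable:
        """Build an HF-style _flash_attention_forward bound to one batch variant."""
        attn_func = BATCH_RING_ATTN_FUNCS[ring_attn_func]

        def _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal,
            dropout=0.0,
            softmax_scale=None,
            **kwargs,
        ):
            # Same calling convention as the deleted adapter; only the underlying
            # ring-flash-attn call changes with the configured variant.
            return attn_func(
                query_states,
                key_states,
                value_states,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=is_causal,
                group=process_group,
            )

        return _flash_attention_forward

Routing every batch variant through one factory keeps the HuggingFace monkeypatch in a single place and lets register_ring_attn select the variant from config, rather than maintaining a separate adapter module per ring-flash-attn function.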