diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py
new file mode 100644
index 000000000..ea4f9c57a
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py
@@ -0,0 +1,175 @@
+"""
+HuggingFace flash attention adapter for basic ring attention (batch API).
+
+Inspired by
+https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
+Our implementation closely follows the structure of that module, but we've minified it
+somewhat to support only the latest versions of transformers.
+"""
+
+# pylint: disable=protected-access
+
+import os
+from typing import Callable
+
+import torch
+import torch.distributed as dist
+import transformers
+import transformers.modeling_flash_attention_utils
+from ring_flash_attn import ring_flash_attn_func
+from ring_flash_attn.adapters.hf_adapter import check_params
+from transformers.modeling_flash_attention_utils import (
+    _flash_supports_window_size,
+    is_flash_attn_greater_or_equal,
+)
+
+try:
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+except ImportError:
+    ALL_ATTENTION_FUNCTIONS = None
+
+
+def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Callable:
+    """
+    Create a ring flash attention forward function compatible with HuggingFace's interface.
+
+    Args:
+        process_group: A PyTorch distributed process group that defines the communication
+            topology for the ring attention pattern.
+
+    Returns:
+        A function that implements the ring flash attention forward pass with the
+        signature expected by HuggingFace Transformers.
+    """
+
+    # transformers 4.48+
+    # pylint: disable=unused-argument
+    def _flash_attention_forward(
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_length: int,
+        is_causal: bool,
+        dropout: float = 0.0,
+        position_ids: torch.Tensor | None = None,
+        softmax_scale: float | None = None,
+        sliding_window: int | None = None,
+        use_top_left_mask: bool = False,
+        softcap: float | None = None,
+        deterministic: bool | None = None,
+        cu_seq_lens_q: torch.LongTensor | None = None,
+        cu_seq_lens_k: torch.LongTensor | None = None,
+        max_length_q: int | None = None,
+        max_length_k: int | None = None,
+        target_dtype: torch.dtype | None = None,
+        **kwargs,
+    ):
+        """
+        Calls the forward method of Ring Flash Attention.
+
+        Args:
+            query_states: Tensor containing the query vectors.
+            key_states: Tensor containing the key vectors.
+            value_states: Tensor containing the value vectors.
+            attention_mask: Not used in this implementation.
+            query_length: Integer representing the length of the query sequence.
+            is_causal: Boolean indicating whether to apply a causal mask to the attention.
+            dropout: Float representing the dropout probability. Default is 0.0.
+            position_ids: Not used in this implementation.
+            softmax_scale: Optional float value for the softmax scaling factor. Default is None.
+            sliding_window: Optional integer defining the size of the sliding attention window.
+                Default is None.
+            use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
+                Default is False.
+            softcap: Not used in this implementation.
+            deterministic: Optional boolean to enforce deterministic computation. Default is None.
+            cu_seq_lens_q: Not used in this implementation.
+            cu_seq_lens_k: Not used in this implementation.
+            max_length_q: Not used in this implementation.
+            max_length_k: Not used in this implementation.
+            target_dtype: Not used in this implementation.
+            **kwargs: Additional keyword arguments. Not used in this implementation.
+
+        Returns:
+            torch.Tensor: The output of the attention mechanism, with shape
+                `[batch_size, query_length, num_heads, head_dim]`.
+        """
+        if not use_top_left_mask:
+            causal = is_causal
+        else:
+            causal = is_causal and query_length != 1
+
+        # Handle sliding window
+        use_sliding_windows = (
+            _flash_supports_window_size
+            and sliding_window is not None
+            and key_states.shape[1] > sliding_window
+        )
+        window_size = (
+            (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
+        )
+
+        # Handle deterministic mode
+        if is_flash_attn_greater_or_equal("2.4.1"):
+            if deterministic is None:
+                deterministic = (
+                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+                )
+
+        # Call ring flash attention function
+        attn_output = ring_flash_attn_func(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=None,
+            deterministic=deterministic,
+            return_attn_probs=False,
+            group=process_group,
+        )
+
+        return attn_output
+
+    return _flash_attention_forward
+
+
+# pylint: disable=unused-argument
+def substitute_hf_flash_attn(process_group: dist.ProcessGroup):
+    """
+    Substitute HuggingFace's flash attention implementation with a ring-based implementation.
+
+    Args:
+        process_group: PyTorch distributed process group for communication.
+    """
+    try:
+        # Substitute flash attention
+        old_flash_attention_forward = (
+            transformers.modeling_flash_attention_utils._flash_attention_forward
+        )
+        new_flash_attention_forward = create_ring_flash_attention_forward(process_group)
+
+        if check_params(old_flash_attention_forward, new_flash_attention_forward):
+            transformers.modeling_flash_attention_utils._flash_attention_forward = (
+                new_flash_attention_forward
+            )
+        else:
+            raise ValueError(
+                "The signature of the new flash attention forward function does not match the old one."
+            )
+    except Exception as exception:
+        raise ValueError(
+            f"The current transformers version {transformers.__version__} is not supported. "
+            "Please use pip install -U transformers to upgrade to the latest version. "
+            "If the code failed with the latest version, "
+            "please file an issue."
+        ) from exception
+
+    # Register with ALL_ATTENTION_FUNCTIONS if available
+    if ALL_ATTENTION_FUNCTIONS is not None:
+        from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
+
+        ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py
index b5587ddca..160e57052 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py
@@ -76,7 +76,8 @@ def register_ring_attn(

     LOG.info(
         "Enabling ring attention sequence parallelism: "
-        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
+        f"each sequence will be processed across {sequence_parallel_degree} GPUs "
+        f"using the {ring_attn_func.value} ring-flash-attn implementation"
     )

     rank = dist.get_rank()
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d7105daba..9f4893974 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -647,7 +647,21 @@ class ModelLoader:
         patch_self_attn_lora(self.cfg)

         if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
-            from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
+            from axolotl.monkeypatch.attention.ring_attn import (
+                RingAttnFunc,
+                register_ring_attn,
+            )
+
+            # Set the ring attention function if passed in config
+            ring_attn_func = None
+            if self.cfg.ring_attn_func:
+                valid_funcs = [enum.value for enum in RingAttnFunc]
+                if self.cfg.ring_attn_func in valid_funcs:
+                    ring_attn_func = RingAttnFunc(self.cfg.ring_attn_func)
+                else:
+                    LOG.warning(
+                        f"ring_attn_func: {self.cfg.ring_attn_func} must be one of {valid_funcs}"
+                    )

             # Initialize ring attn for sequence parallelism. This must be done after
             # model init but before the first forward pass, since it modifies flash
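Note (not part of the diff): a minimal sketch of how the new adapter might be wired up at runtime. The setup_ring_attention helper, the flat contiguous ring-group layout, and the call site below are illustrative assumptions; in axolotl the equivalent work is done by register_ring_attn in patch.py, driven by the sequence_parallel_degree and optional ring_attn_func config keys that ModelLoader validates above.

# Illustrative sketch only -- assumes torch.distributed is already initialized
# (e.g. via torchrun) and that world_size is divisible by sequence_parallel_degree.
import torch.distributed as dist

from axolotl.monkeypatch.attention.ring_attn.adapters.batch_ring import (
    substitute_hf_flash_attn,
)


def setup_ring_attention(sequence_parallel_degree: int) -> None:
    """Hypothetical helper: build ring process groups and patch HF flash attention."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sequence_parallel_degree == 0

    # Partition ranks into contiguous groups of size sequence_parallel_degree;
    # each group forms one attention ring. Every rank calls new_group for every
    # group (a collective requirement) but only keeps the group it belongs to.
    ring_group = None
    for start in range(0, world_size, sequence_parallel_degree):
        ranks = list(range(start, start + sequence_parallel_degree))
        group = dist.new_group(ranks)
        if rank in ranks:
            ring_group = group

    # Replace transformers' _flash_attention_forward so that attention for
    # flash-attention-2 models is computed ring-wise across the group.
    substitute_hf_flash_attn(process_group=ring_group)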