From d88e071120244a05d884536ebdba7e6546fad1ec Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 23 Feb 2025 12:19:34 -0500
Subject: [PATCH] sequence parallelism

---
 .../attention/sequence_parallel/__init__.py  | 18 ++++++++++
 .../attention/sequence_parallel/usp.py       | 36 +++++++++++++++++++
 2 files changed, 54 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
 create mode 100644 src/axolotl/monkeypatch/attention/sequence_parallel/usp.py

diff --git a/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py b/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
new file mode 100644
index 000000000..3ff99c775
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
@@ -0,0 +1,18 @@
+from enum import Enum
+
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+class USPRingAttnType(Enum):
+    BASIC = "basic"
+    ZIGZAG = "zigzag"
+    STRIDE = "stride"
+
+def patch_seq_parallel():
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
+    pass
+
+def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
+    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
+
+    fa_forward = build_usp_fa_forward(ring_impl_type)
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
diff --git a/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py b/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
new file mode 100644
index 000000000..34ee792c5
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
@@ -0,0 +1,36 @@
+from enum import Enum
+from typing import Optional, Tuple, Callable
+
+import torch
+from yunchang import LongContextAttention, set_seq_parallel_pg
+
+from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
+
+
+def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
+    usp_attn = LongContextAttention(ring_impl_type.value)
+
+    def flash_attention_forward(
+        module: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        dropout: float = 0.0,
+        scaling: Optional[float] = None,
+        sliding_window: Optional[int] = None,
+        softcap: Optional[float] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, None]:
+        attn_output = usp_attn(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            softmax_scale=scaling,
+            causal=True,
+            softcap=softcap,
+        )
+        return attn_output, None
+
+    return flash_attention_forward
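
The sketch below (not part of the patch) shows how these helpers might be wired into a trainer before the model runs its first forward pass: torch.distributed is assumed to already be initialized, the sequence-parallel process groups are created with yunchang's set_seq_parallel_pg (imported in usp.py but not called by the patch itself), and the registered flash_attention_2 implementation is then swapped via apply_usp_attn_patch. The setup_usp helper name, the sequence_parallel_degree/ring_degree parameters, and the exact set_seq_parallel_pg signature are assumptions for illustration, not part of the patch.

import torch.distributed as dist
from yunchang import set_seq_parallel_pg

from axolotl.monkeypatch.attention.sequence_parallel import (
    USPRingAttnType,
    apply_usp_attn_patch,
)


def setup_usp(sequence_parallel_degree: int, ring_degree: int = 1) -> None:
    # Hypothetical helper (not in the patch): build the USP process groups and
    # swap in the ring-attention flash_attention_2 forward.
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Split the overall sequence-parallel degree between Ulysses (head
    # parallelism) and ring attention; the signature is assumed to be
    # set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size).
    ulysses_degree = sequence_parallel_degree // ring_degree
    set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)

    # Replace the implementation registered under "flash_attention_2" in
    # transformers' ALL_ATTENTION_FUNCTIONS with the USP forward.
    apply_usp_attn_patch(USPRingAttnType.ZIGZAG)

Because the patch swaps the registry entry rather than any single module, every model configured with attn_implementation="flash_attention_2" should pick up the USP forward once setup_usp has run on each rank.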