sequence parallelism

Wing Lian
2025-02-23 12:19:34 -05:00
parent a4170030ab
commit d88e071120
2 changed files with 54 additions and 0 deletions

src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py

@@ -0,0 +1,18 @@
from enum import Enum

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class USPRingAttnType(Enum):
    BASIC = "basic"
    ZIGZAG = "zigzag"
    STRIDE = "stride"


def patch_seq_parallel():
    # Stub retained from the original diff; the working patch is
    # apply_usp_attn_patch below.
    pass


def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
    # Deferred import so yunchang is only required when the patch is applied.
    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward

    # Route HF's "flash_attention_2" dispatch through the USP-aware forward.
    fa_forward = build_usp_fa_forward(ring_impl_type)
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward

src/axolotl/monkeypatch/attention/sequence_parallel/usp.py

@@ -0,0 +1,36 @@
from typing import Callable, Optional, Tuple

import torch
from yunchang import LongContextAttention, set_seq_parallel_pg  # noqa: F401

from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType


def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
    # Assumes set_seq_parallel_pg(...) has already initialized the Ulysses/ring
    # process groups that LongContextAttention relies on. Pass the ring type by
    # keyword: LongContextAttention's first positional parameter is scatter_idx,
    # not the ring implementation.
    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type.value)

    def flash_attention_forward(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        dropout: float = 0.0,
        scaling: Optional[float] = None,
        sliding_window: Optional[int] = None,
        softcap: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # Transformers passes (batch, num_heads, seq_len, head_dim); flash-attn
        # style kernels (and yunchang) expect (batch, seq_len, num_heads, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # attention_mask is unused: this path is hard-wired to causal attention.
        attn_output = usp_attn(
            query,
            key,
            value,
            dropout_p=dropout,
            softmax_scale=scaling,
            causal=True,
            # flash-attn expects a float; map transformers' None to "no softcap".
            softcap=softcap if softcap is not None else 0.0,
        )
        return attn_output, None

    return flash_attention_forward
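
For reference, a shape-level sketch of the contract the wrapper honors (sizes are hypothetical, and this only actually executes once set_seq_parallel_pg has initialized the process groups on flash-attn-capable GPUs):

import torch

from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward

fa_forward = build_usp_fa_forward(USPRingAttnType.BASIC)

# Transformers hands the wrapper (batch, num_heads, local_seq_len, head_dim);
# each rank holds only its local shard of the full sequence.
q = k = v = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.bfloat16)

out, _ = fa_forward(None, q, k, v, attention_mask=None)
# Output comes back as (batch, local_seq_len, num_heads, head_dim).
assert out.shape == (1, 512, 8, 64)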