sequence parallelism
18
src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
Normal file
@@ -0,0 +1,18 @@
from enum import Enum

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class USPRingAttnType(Enum):
    BASIC = "basic"
    ZIGZAG = "zigzag"
    STRIDE = "stride"


def patch_seq_parallel():
    # Stub; the actual patch is applied by apply_usp_attn_patch below.
    pass


def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
    # Imported lazily to avoid a circular import between this package and usp.py.
    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward

    # Route HF transformers' "flash_attention_2" dispatch through the USP forward.
    fa_forward = build_usp_fa_forward(ring_impl_type)
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
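For context, a minimal usage sketch (not part of the commit; only USPRingAttnType and apply_usp_attn_patch are real symbols from the module above): the patch has to run before the model is constructed so that transformers resolves "flash_attention_2" to the USP forward.

# Usage sketch, not part of the commit.
from axolotl.monkeypatch.attention.sequence_parallel import (
    USPRingAttnType,
    apply_usp_attn_patch,
)

apply_usp_attn_patch(USPRingAttnType.ZIGZAG)
# Any model subsequently loaded with attn_implementation="flash_attention_2"
# now runs its attention through yunchang's LongContextAttention.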
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
@@ -0,0 +1,36 @@
from typing import Optional, Tuple, Callable

import torch
# set_seq_parallel_pg is not called in this module; the Ulysses/ring process
# groups must be created by setup code before LongContextAttention is built.
from yunchang import LongContextAttention, set_seq_parallel_pg

from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType


def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
    # LongContextAttention's leading positional parameters are scatter_idx and
    # gather_idx, so the ring variant must be passed by keyword.
    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type.value)

    def flash_attention_forward(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        dropout: float = 0.0,
        scaling: Optional[float] = None,
        sliding_window: Optional[int] = None,
        softcap: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # transformers hands q/k/v over as (batch, heads, seq, head_dim);
        # yunchang expects the flash-attn layout (batch, seq, heads, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # attention_mask and sliding_window are ignored here: the ring variants
        # assume a causal, unpadded sequence sharded across ranks.
        attn_output = usp_attn(
            query,
            key,
            value,
            dropout_p=dropout,
            softmax_scale=scaling,
            causal=True,
            softcap=softcap,
        )
        # The output is already (batch, seq, heads, head_dim), the shape the
        # transformers caller reshapes back into hidden states.
        return attn_output, None

    return flash_attention_forward
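usp.py imports set_seq_parallel_pg but never calls it, so the Ulysses x ring process groups presumably get created elsewhere at startup. A minimal sketch of that setup, assuming yunchang's set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size) signature and an already-initialized torch.distributed process group; init_usp_process_groups is a hypothetical helper, not part of the commit:

# Setup sketch, not part of the commit.
import torch.distributed as dist
from yunchang import set_seq_parallel_pg


def init_usp_process_groups(ulysses_degree: int, ring_degree: int) -> None:
    # The two degrees factor the world size: each rank belongs to one Ulysses
    # (head-parallel) group and one ring (sequence-parallel) group.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    assert ulysses_degree * ring_degree == world_size, (
        "ulysses_degree * ring_degree must equal world_size"
    )
    set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)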