sequence parallelism
This commit is contained in:
@@ -0,0 +1,18 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
class USPRingAttnType(Enum):
    """Names of the ring-attention variants selectable for USP.

    Each member's ``value`` is the plain string form of the variant name.
    """

    # Members are listed in declaration order; iteration yields them in
    # this same order.
    BASIC = "basic"
    ZIGZAG = "zigzag"
    STRIDE = "stride"
|
||||||
|
|
||||||
|
def patch_seq_parallel():
    """Placeholder hook for sequence-parallel attention patching.

    At present this only looks up the ``"flash_attention_2"`` entry of
    ``ALL_ATTENTION_FUNCTIONS`` and discards the result; nothing is
    replaced or registered.
    """
    # The looked-up value is intentionally unused. The subscript is kept so
    # that a missing entry still surfaces the same way it did before.
    _ = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||||
|
|
||||||
|
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
    """Install the USP forward as the ``"flash_attention_2"`` attention function.

    Args:
        ring_impl_type: Ring-attention variant handed to
            ``build_usp_fa_forward`` when constructing the replacement.
    """
    # Imported at call time rather than at module import time; presumably
    # this avoids a circular import with the ``usp`` module -- TODO confirm.
    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward

    # Overwrite the registry entry with the freshly built forward function.
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = build_usp_fa_forward(ring_impl_type)
|
||||||
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Tuple, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from yunchang import LongContextAttention, set_seq_parallel_pg
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
|
||||||
|
|
||||||
|
|
||||||
|
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
    """Build an attention forward function backed by ``LongContextAttention``.

    A single ``LongContextAttention`` instance is created here and captured
    by the returned closure, so every call of the returned function reuses
    the same attention module.

    Args:
        ring_impl_type: Ring-attention variant; its string ``value`` is
            passed to ``LongContextAttention`` as ``ring_impl_type``.

    Returns:
        A callable with the ``flash_attention_forward`` signature defined
        below, returning ``(attn_output, None)``.
    """
    # The leading positional parameters of LongContextAttention are the
    # all-to-all scatter/gather indices, so the ring implementation name has
    # to be supplied by keyword to reach the intended parameter.
    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type.value)

    def flash_attention_forward(
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        dropout: float = 0.0,
        scaling: Optional[float] = None,
        sliding_window: Optional[int] = None,
        softcap: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        """Run USP attention on ``query``/``key``/``value``.

        ``module``, ``attention_mask``, ``sliding_window`` and ``**kwargs``
        are accepted for signature compatibility but are not used; attention
        is always computed with ``causal=True``.

        Args:
            module: Calling attention module (unused).
            query: Query tensor, forwarded unchanged.
            key: Key tensor, forwarded unchanged.
            value: Value tensor, forwarded unchanged.
            attention_mask: Ignored.
            dropout: Forwarded as ``dropout_p``.
            scaling: Forwarded as ``softmax_scale``.
            sliding_window: Ignored.
            softcap: Soft-capping value; ``None`` is forwarded as ``0.0``.
            **kwargs: Ignored.

        Returns:
            ``(attn_output, None)``; the second element is always ``None``.
        """
        # NOTE(review): the tensors are passed through without any transpose
        # or reshape -- confirm that the layout supplied by the caller is the
        # one LongContextAttention expects.
        # NOTE(review): the attention mask and sliding window are dropped and
        # causal masking is hard-coded -- confirm this is intended.
        attn_output = usp_attn(
            query,
            key,
            value,
            dropout_p=dropout,
            softmax_scale=scaling,
            causal=True,
            # The underlying kernels take a float here, with 0.0 meaning
            # "no soft-capping"; never hand them None.
            softcap=0.0 if softcap is None else softcap,
        )
        return attn_output, None

    return flash_attention_forward
|
||||||
Reference in New Issue
Block a user