batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements * update * adding all batch ring-flash-attn methods via single adapter * removing pad_to_sequence_len=False for now * fix * updating docs to include batch SP * review comments * fixes for batch API funcs, simplify * fixes * fix * updates * add batch_zigzag smoke test
This commit is contained in:
@@ -693,6 +693,9 @@ sequence_parallel_degree:
|
|||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
# Must evenly divide the number of KV heads in your model.
|
# Must evenly divide the number of KV heads in your model.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
|
||||||
|
# in the sample packing case, and "batch_ring" in the non-sample packing case.
|
||||||
|
ring_attn_func:
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
|
|||||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||||
|
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||||
|
ring_attn_func:
|
||||||
```
|
```
|
||||||
|
|
||||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|||||||
@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
self.cfg.sequence_parallel_degree
|
self.cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
kwargs["ring_attn_func"] = training_args.ring_attn_func
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from PIL.Image import Resampling
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
)
|
)
|
||||||
|
ring_attn_func: Optional[RingAttnFunc] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The ring-flash-attn function to use in sequence parallelism"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# multi-modal section
|
# multi-modal section
|
||||||
|
|
||||||
|
|||||||
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Init for ring attention monkeypatch module"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .patch import (
|
||||||
|
RingAttnFunc,
|
||||||
|
get_ring_attn_group,
|
||||||
|
register_ring_attn,
|
||||||
|
set_ring_attn_group,
|
||||||
|
update_ring_attn_params,
|
||||||
|
)
|
||||||
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
HuggingFace flash attention adapter for basic ring attention (batch API).
|
||||||
|
|
||||||
|
Inspired by
|
||||||
|
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
|
||||||
|
Our implementation closely follows the structure of that module, but we've minified it
|
||||||
|
somewhat to support only the latest versions of transformers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access,cyclic-import
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import transformers
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
|
from ring_flash_attn import (
|
||||||
|
ring_flash_attn_func,
|
||||||
|
stripe_flash_attn_func,
|
||||||
|
zigzag_ring_flash_attn_func,
|
||||||
|
)
|
||||||
|
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||||
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
_flash_supports_window_size,
|
||||||
|
is_flash_attn_greater_or_equal,
|
||||||
|
)
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||||
|
|
||||||
|
# Dispatch table from the batch-API `RingAttnFunc` members to the matching
# `ring_flash_attn` functions. Only the three batch variants are listed;
# `RingAttnFunc.VARLEN_LLAMA3` has no entry here.
RING_ATTN_FUNC_MAPPING = {
    RingAttnFunc.BATCH_RING: ring_flash_attn_func,
    RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
    RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_flash_attn_forward(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
    """
    Build a replacement for HuggingFace's `_flash_attention_forward` that dispatches
    to a batch-API `ring_flash_attn` function.

    Args:
        process_group: PyTorch distributed process group, forwarded to the ring
            attention call as `group`.
        ring_attn_func: `RingAttnFunc` member used as the key into
            `RING_ATTN_FUNC_MAPPING` to pick the ring attention implementation.

    Returns:
        A closure accepting the parameters of HuggingFace's flash attention forward
        and returning the output of the selected ring attention function.
    """

    # transformers 4.48+
    # pylint: disable=unused-argument
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: torch.Tensor | None = None,
        softmax_scale: float | None = None,
        sliding_window: int | None = None,
        use_top_left_mask: bool = False,
        softcap: float | None = None,
        deterministic: bool = None,
        cu_seq_lens_q: torch.LongTensor | None = None,
        cu_seq_lens_k: torch.LongTensor | None = None,
        max_length_q: int | None = None,
        max_length_k: int | None = None,
        target_dtype: torch.dtype | None = None,
        **kwargs,
    ):
        """
        Run the selected ring flash attention function on the given states.

        Args:
            query_states: Query tensor, passed through unchanged.
            key_states: Key tensor; its dimension 1 is compared against
                `sliding_window`.
            value_states: Value tensor, passed through unchanged.
            attention_mask: Ignored.
            query_length: Query sequence length; only consulted when
                `use_top_left_mask` is set.
            is_causal: Whether causal masking is requested.
            dropout: Dropout probability, forwarded as `dropout_p`.
            position_ids: Ignored.
            softmax_scale: Optional softmax scaling factor.
            sliding_window: Optional sliding window size. Applied only when flash
                attention supports window sizes and dimension 1 of `key_states`
                exceeds it.
            use_top_left_mask: If set, causal masking is disabled whenever
                `query_length` is 1.
            softcap: Ignored.
            deterministic: Optional deterministic flag. When `None` and flash
                attention is at least 2.4.1, it is read from the
                `FLASH_ATTENTION_DETERMINISTIC` environment variable.
            cu_seq_lens_q: Ignored.
            cu_seq_lens_k: Ignored.
            max_length_q: Ignored.
            max_length_k: Ignored.
            target_dtype: Ignored.
            **kwargs: Ignored.

        Returns:
            The value returned by the ring attention function; expected to be a
            tensor of shape `[batch_size, query_length, num_heads, head_dim]`
            (not checked here).
        """
        # Only the top-left mask variant needs the single-token special case.
        causal = (is_causal and query_length != 1) if use_top_left_mask else is_causal

        # (-1, -1) leaves the attention window unrestricted.
        window_size = (-1, -1)
        if (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        ):
            window_size = (sliding_window, sliding_window)

        # Fall back to the environment variable only when the caller gave no value.
        if is_flash_attn_greater_or_equal("2.4.1") and deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"

        ring_attn_impl = RING_ATTN_FUNC_MAPPING[ring_attn_func]
        return ring_attn_impl(
            query_states,
            key_states,
            value_states,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=None,
            deterministic=deterministic,
            return_attn_probs=False,
            group=process_group,
        )

    return _flash_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
def substitute_hf_flash_attn(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
    """
    Monkeypatch HuggingFace's flash attention forward with a batch-API ring
    attention forward.

    Args:
        process_group: PyTorch distributed process group for communication.
        ring_attn_func: `RingAttnFunc` member selecting which `ring_flash_attn`
            batch function replaces HF flash attention.

    Raises:
        ValueError: If anything fails while building or installing the new forward
            function, including a failed `check_params` comparison. The
            underlying exception is chained.
    """
    try:
        hf_flash_utils = transformers.modeling_flash_attention_utils
        original_forward = hf_flash_utils._flash_attention_forward
        ring_forward = create_flash_attn_forward(
            process_group=process_group, ring_attn_func=ring_attn_func
        )

        if not check_params(original_forward, ring_forward):
            raise ValueError(
                "The signature of the new flash attention forward function does not match the old one."
            )

        # Substitute flash attention
        hf_flash_utils._flash_attention_forward = ring_forward
    except Exception as exception:
        message = (
            f"The current transformer version {transformers.__version__} is not supported. "
            "Please use pip install -U transformers to upgrade to the latest version. "
            "If the code failed with the latest version, "
            "please file an issue."
        )
        raise ValueError(message) from exception

    # Register with ALL_ATTENTION_FUNCTIONS if available
    if ALL_ATTENTION_FUNCTIONS is None:
        return

    from ring_flash_attn.adapters.hf_adapter import flash_attention_forward

    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
||||||
@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
|||||||
their sequence parallel version of Flash Attention 2.
|
their sequence parallel version of Flash Attention 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
|||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
RING_ATTN_GROUP = None
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
|||||||
RING_ATTN_GROUP = ring_attn_group
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
class RingAttnFunc(str, Enum):
    """
    Enum class for supported `ring-flash-attn` implementations.

    Members also subclass `str`, so each compares equal to its plain string value.
    """

    # The commented-out members below are placeholders; presumably these varlen
    # variants are not supported yet — TODO confirm.
    # VARLEN_RING = "varlen_ring"
    # VARLEN_ZIGZAG = "varlen_zigzag"
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"
    BATCH_STRIPE = "batch_stripe"
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(
|
||||||
|
sequence_parallel_degree: int,
|
||||||
|
heads_k_stride: int | None,
|
||||||
|
ring_attn_func: RingAttnFunc | None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
sequence_parallel_degree: Sequence parallelism factor.
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||||
|
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
|
||||||
|
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||||
|
`batch` function.
|
||||||
"""
|
"""
|
||||||
if get_ring_attn_group() is not None:
|
if get_ring_attn_group() is not None:
|
||||||
LOG.info("Ring attention already registered, exiting early...")
|
LOG.info("Ring attention already registered, exiting early...")
|
||||||
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
assert sequence_parallel_degree <= world_size, (
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Detailed logging of group formation
|
# Assign ranks to sequence parallel groups
|
||||||
rank = dist.get_rank()
|
|
||||||
group_assignments = {}
|
group_assignments = {}
|
||||||
|
|
||||||
for i in range(world_size // sequence_parallel_degree):
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
range(
|
range(
|
||||||
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
if heads_k_stride is None:
|
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||||
heads_k_stride = 1
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
substitute_hf_flash_attn(
|
||||||
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||||
|
)
|
||||||
|
elif ring_attn_func in [
|
||||||
|
RingAttnFunc.BATCH_RING,
|
||||||
|
RingAttnFunc.BATCH_ZIGZAG,
|
||||||
|
RingAttnFunc.BATCH_STRIPE,
|
||||||
|
]:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
||||||
|
substitute_hf_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
process_group=get_ring_attn_group(),
|
||||||
)
|
ring_attn_func=ring_attn_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||||
"""
|
"""
|
||||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||||
value to the substituted `ring_flash_attn`.
|
value to the substituted `ring_flash_attn`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
position_ids: Optional tensor of position IDs (for sample packed data).
|
||||||
data; if not, we compute it.
|
|
||||||
"""
|
"""
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
|
||||||
input_ids = batch["input_ids"]
|
|
||||||
position_ids = batch.get("position_ids")
|
|
||||||
if position_ids is None:
|
|
||||||
seq_len = input_ids.shape[1]
|
|
||||||
position_ids = torch.arange(
|
|
||||||
0, seq_len, dtype=torch.long, device=input_ids.device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
|
|||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
model: Optional[Any] = None
|
model: Any | None = None
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: bool | str | PaddingStrategy = True
|
||||||
max_length: Optional[int] = None
|
max_length: int | None = None
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: int | None = None
|
||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
sequence_parallel_degree: int = 1
|
sequence_parallel_degree: int = 1
|
||||||
|
ring_attn_func: RingAttnFunc | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.sequence_parallel_degree > 1:
|
if self.sequence_parallel_degree > 1:
|
||||||
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
|
|||||||
Sliced batch dictionary.
|
Sliced batch dictionary.
|
||||||
"""
|
"""
|
||||||
# Get local (start, end) for sequence parallelism slicing
|
# Get local (start, end) for sequence parallelism slicing
|
||||||
total_seq_len = batch["input_ids"].shape[1]
|
total_seq_len = batch["input_ids"].size(1)
|
||||||
slice_size = total_seq_len // self.local_world_size
|
|
||||||
start = self.local_rank * slice_size
|
|
||||||
end = start + slice_size
|
|
||||||
|
|
||||||
# Update params for ring attention calculation
|
# Update params for varlen ring attention calculation
|
||||||
update_ring_attn_params(batch=batch)
|
if batch.get("position_ids") is not None:
|
||||||
|
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||||
|
|
||||||
# Slice batch for sequence parallel processing
|
# Slice batch for sequence parallel processing
|
||||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
for key in batch:
|
||||||
for key in keys_to_slice:
|
if batch[key].size(1) == total_seq_len:
|
||||||
if key in batch:
|
if self.ring_attn_func in [
|
||||||
batch[key] = batch[key][:, start:end]
|
RingAttnFunc.VARLEN_LLAMA3,
|
||||||
|
RingAttnFunc.BATCH_RING,
|
||||||
|
]:
|
||||||
|
batch[key] = (
|
||||||
|
batch[key]
|
||||||
|
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||||
|
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||||
|
|
||||||
|
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||||
|
selected_chunks = [
|
||||||
|
chunks[self.local_rank],
|
||||||
|
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||||
|
]
|
||||||
|
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||||
|
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||||
|
# TODO(djsaunde): This doesn't seem to work as expected
|
||||||
|
# Split into striped data and stack
|
||||||
|
tensor = torch.stack(
|
||||||
|
batch[key].split(self.local_world_size, dim=1),
|
||||||
|
dim=1,
|
||||||
|
).transpose(1, 2)
|
||||||
|
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|||||||
@@ -655,6 +655,7 @@ class ModelLoader:
|
|||||||
register_ring_attn(
|
register_ring_attn(
|
||||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||||
heads_k_stride=self.cfg.heads_k_stride,
|
heads_k_stride=self.cfg.heads_k_stride,
|
||||||
|
ring_attn_func=self.cfg.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
@@ -1119,7 +1120,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
return skip_move_to_device
|
return skip_move_to_device
|
||||||
|
|
||||||
def ajust_model_config(self) -> None:
|
def adjust_model_config(self) -> None:
|
||||||
if (
|
if (
|
||||||
hasattr(self.model, "config")
|
hasattr(self.model, "config")
|
||||||
and hasattr(self.model.config, "max_position_embeddings")
|
and hasattr(self.model.config, "max_position_embeddings")
|
||||||
@@ -1279,7 +1280,7 @@ class ModelLoader:
|
|||||||
else:
|
else:
|
||||||
self.model.tie_weights()
|
self.model.tie_weights()
|
||||||
|
|
||||||
self.ajust_model_config()
|
self.adjust_model_config()
|
||||||
|
|
||||||
# log device memory usage
|
# log device memory usage
|
||||||
if hasattr(self.model, "device") and self.model.device.type in (
|
if hasattr(self.model, "device") and self.model.device.type in (
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
sequence_parallel_degree: int | None = None
|
sequence_parallel_degree: int | None = None
|
||||||
heads_k_stride: int | None = None
|
heads_k_stride: int | None = None
|
||||||
|
ring_attn_func: str | None = None
|
||||||
|
|
||||||
special_tokens: SpecialTokensConfig | None = None
|
special_tokens: SpecialTokensConfig | None = None
|
||||||
tokens: list[str] | None = None
|
tokens: list[str] | None = None
|
||||||
@@ -1147,7 +1148,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@field_validator("sequence_parallel_degree", mode="before")
|
@field_validator("sequence_parallel_degree", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sequence_parallel_degree(cls, value, info):
|
def check_sequence_parallel_degree(cls, value, info):
|
||||||
if not value:
|
if not value:
|
||||||
@@ -1159,9 +1160,12 @@ class AxolotlInputConfig(
|
|||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not info.data["micro_batch_size"] == 1:
|
if (
|
||||||
|
info.data.get("sample_packing")
|
||||||
|
and not info.data["micro_batch_size"] == 1
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"micro_batch_size must be set to 1 "
|
"micro_batch_size must be set to 1 when sample_packing is enabled "
|
||||||
"due to a `ring-flash-attn` requirement"
|
"due to a `ring-flash-attn` requirement"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1188,6 +1192,34 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
    """
    Validate `ring_attn_func`, or pick a default, when sequence parallelism is on.

    Args:
        value: The configured `ring_attn_func` value, or `None` if unset.
        info: Pydantic validation info; `info.data` is read for
            `sequence_parallel_degree` and `sample_packing`.

    Returns:
        `value` unchanged when `sequence_parallel_degree` is missing, `None`, or
        not greater than 1. Otherwise a `RingAttnFunc` member: the configured one,
        or `VARLEN_LLAMA3` with sample packing and `BATCH_RING` without.

    Raises:
        ValueError: If `value` is set but is not a valid `RingAttnFunc` value.
    """
    # `sequence_parallel_degree` is declared with a default of `None`, so the key
    # can be present with a `None` value. Treat that like 1 rather than
    # evaluating `None > 1`, which raises a `TypeError`.
    sequence_parallel_degree = info.data.get("sequence_parallel_degree") or 1
    if sequence_parallel_degree <= 1:
        return value

    # Imported lazily, as in the original validator.
    from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

    if value is None:
        # Default ring attention function selection
        if info.data.get("sample_packing"):
            return RingAttnFunc.VARLEN_LLAMA3
        return RingAttnFunc.BATCH_RING

    # Set the ring attention function if passed in config
    valid_funcs = list(RingAttnFunc)
    if value not in valid_funcs:
        raise ValueError(f"ring_attn_func: {value} must be one of {valid_funcs}")

    return RingAttnFunc(value)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_muon_deepspeed_fsdp(cls, data):
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
|
|||||||
class TestSequenceParallelism:
|
class TestSequenceParallelism:
|
||||||
"""Test case for training with sequence parallelism enabled"""
|
"""Test case for training with sequence parallelism enabled"""
|
||||||
|
|
||||||
def test_sequence_parallel_training(self, temp_dir):
|
def _run_sequence_parallel_test(
|
||||||
# pylint: disable=duplicate-code
|
self,
|
||||||
|
temp_dir,
|
||||||
|
sample_packing=True,
|
||||||
|
micro_batch_size=1,
|
||||||
|
pad_to_sequence_len=True,
|
||||||
|
ring_attn_func=None,
|
||||||
|
):
|
||||||
|
"""Helper method to run sequence parallel tests with different configurations"""
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
|
|||||||
"strict": False,
|
"strict": False,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"sample_packing": True,
|
"sample_packing": sample_packing,
|
||||||
"eval_sample_packing": True,
|
"eval_sample_packing": sample_packing,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": pad_to_sequence_len,
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
|
|||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 8,
|
"max_steps": 8,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": micro_batch_size,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"sequence_parallel_degree": 2,
|
"sequence_parallel_degree": 2,
|
||||||
|
"ring_attn_func": ring_attn_func,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
    [
        (True, 1, True, None),  # defaults to varlen_llama3 ring_attn_func
        (False, 2, True, None),  # defaults to batch_ring ring_attn_func
        (False, 2, True, "batch_zigzag"),
        # (False, 2, False, None),  # not yet working
    ],
    # Each id must describe the parameter tuple at the same position above;
    # every active case runs with pad_to_sequence_len=True.
    ids=[
        "sample_packing, varlen_llama3 ring_attn_func",
        "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
        "no sample_packing, pad_to_sequence_len, batch_zigzag ring_attn_func",
        # "no sample_packing, no pad_to_sequence_len",  # not yet working
    ],
)
def test_sequence_parallel_training(
    self,
    temp_dir,
    sample_packing,
    micro_batch_size,
    pad_to_sequence_len,
    ring_attn_func,
):
    """
    Test sequence parallel training with different configurations.

    Each parametrized case is forwarded unchanged to
    `_run_sequence_parallel_test`.
    """
    self._run_sequence_parallel_test(
        temp_dir,
        sample_packing=sample_packing,
        micro_batch_size=micro_batch_size,
        pad_to_sequence_len=pad_to_sequence_len,
        ring_attn_func=ring_attn_func,
    )
|
||||||
|
|||||||
@@ -73,7 +73,10 @@ class TestRingAttention:
|
|||||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||||
):
|
):
|
||||||
"""Test that ring attention groups are created correctly."""
|
"""Test that ring attention groups are created correctly."""
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
|
RingAttnFunc,
|
||||||
|
register_ring_attn,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
mock_world_size.return_value = 8 # 8 GPUs total
|
mock_world_size.return_value = 8 # 8 GPUs total
|
||||||
@@ -82,7 +85,11 @@ class TestRingAttention:
|
|||||||
mock_new_group.return_value = mock_group
|
mock_new_group.return_value = mock_group
|
||||||
|
|
||||||
# Call register_ring_attn with size 4
|
# Call register_ring_attn with size 4
|
||||||
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
|
register_ring_attn(
|
||||||
|
sequence_parallel_degree=4,
|
||||||
|
heads_k_stride=1,
|
||||||
|
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the number of calls without examining the arguments
|
# Verify the number of calls without examining the arguments
|
||||||
assert mock_new_group.call_count == 2
|
assert mock_new_group.call_count == 2
|
||||||
|
|||||||
Reference in New Issue
Block a user