batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements * update * adding all batch ring-flash-attn methods via single adapter * removing pad_to_sequence_len=False for now * fix * updating docs to include batch SP * review comments * fixes for batch API funcs, simplify * fixes * fix * updates * add batch_zigzag smoke test
This commit is contained in:
@@ -693,6 +693,9 @@ sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
|
||||
# in the sample packing case, and "batch_ring" in the non-sample packing case.
|
||||
ring_attn_func:
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
kwargs["return_tensors"] = "pt"
|
||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||
kwargs["ring_attn_func"] = training_args.ring_attn_func
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
|
||||
@@ -9,6 +9,8 @@ from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
ring_attn_func: Optional[RingAttnFunc] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The ring-flash-attn function to use in sequence parallelism"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
|
||||
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Init for ring attention monkeypatch module"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .patch import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
HuggingFace flash attention adapter for basic ring attention (batch API).
|
||||
|
||||
Inspired by
|
||||
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
|
||||
Our implementation closely follows the structure of that module, but we've minified it
|
||||
somewhat to support only the latest versions of transformers.
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access,cyclic-import
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from ring_flash_attn import (
|
||||
ring_flash_attn_func,
|
||||
stripe_flash_attn_func,
|
||||
zigzag_ring_flash_attn_func,
|
||||
)
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size,
|
||||
is_flash_attn_greater_or_equal,
|
||||
)
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
RING_ATTN_FUNC_MAPPING = {
|
||||
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
|
||||
}
|
||||
|
||||
|
||||
def create_flash_attn_forward(
|
||||
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
|
||||
) -> Callable:
|
||||
"""
|
||||
Create a ring flash attention forward function compatible with HuggingFace's
|
||||
interface.
|
||||
|
||||
Args:
|
||||
process_group: A PyTorch distributed process group.
|
||||
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
|
||||
attention with.
|
||||
|
||||
Returns:
|
||||
A function that implements the ring flash attention forward pass with the
|
||||
signature expected by HuggingFace Transformers.
|
||||
"""
|
||||
|
||||
# transformers 4.48+
|
||||
# pylint: disable=unused-argument
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: float | None = None,
|
||||
deterministic: bool = None,
|
||||
cu_seq_lens_q: torch.LongTensor | None = None,
|
||||
cu_seq_lens_k: torch.LongTensor | None = None,
|
||||
max_length_q: int | None = None,
|
||||
max_length_k: int | None = None,
|
||||
target_dtype: torch.dtype | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Ring Flash Attention.
|
||||
|
||||
Args:
|
||||
query_states: Tensor containing the query vectors.
|
||||
key_states: Tensor containing the key vectors.
|
||||
value_states: Tensor containing the value vectors.
|
||||
attention_mask: Not used in this implementation.
|
||||
query_length: Integer representing the length of the query sequence.
|
||||
is_causal: Boolean indicating whether to apply a causal mask to the attention.
|
||||
dropout: Float representing the dropout probability. Default is 0.0.
|
||||
position_ids: Not used in this implementation.
|
||||
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
|
||||
sliding_window: Optional integer defining the size of the sliding attention window.
|
||||
Default is None.
|
||||
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
|
||||
Default is False.
|
||||
softcap: Not used in this implementation.
|
||||
deterministic: Optional boolean to enforce deterministic computation. Default is None.
|
||||
cu_seq_lens_q: Not used in this implementation.
|
||||
cu_seq_lens_k: Not used in this implementation.
|
||||
max_length_q: Not used in this implementation.
|
||||
max_length_k: Not used in this implementation.
|
||||
target_dtype: Not used in this implementation.
|
||||
**kwargs: Additional keyword arguments. Not used in this implementation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output of the attention mechanism, with shape
|
||||
`[batch_size, query_length, num_heads, head_dim]`.
|
||||
"""
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
else:
|
||||
causal = is_causal and query_length != 1
|
||||
|
||||
# Handle sliding window
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size
|
||||
and sliding_window is not None
|
||||
and key_states.shape[1] > sliding_window
|
||||
)
|
||||
window_size = (
|
||||
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
|
||||
)
|
||||
|
||||
# Handle deterministic mode
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
if deterministic is None:
|
||||
deterministic = (
|
||||
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
)
|
||||
|
||||
# Call ring flash attention function
|
||||
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=None,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=False,
|
||||
group=process_group,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
return _flash_attention_forward
|
||||
|
||||
|
||||
def substitute_hf_flash_attn(
|
||||
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
|
||||
):
|
||||
"""
|
||||
Substitute HuggingFace's flash attention implementation with ring-based implementation.
|
||||
|
||||
Args:
|
||||
process_group: PyTorch distributed process group for communication.
|
||||
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
|
||||
attention with.
|
||||
"""
|
||||
try:
|
||||
# Substitute flash attention
|
||||
old_flash_attention_forward = (
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward
|
||||
)
|
||||
new_flash_attention_forward = create_flash_attn_forward(
|
||||
process_group=process_group, ring_attn_func=ring_attn_func
|
||||
)
|
||||
|
||||
if check_params(old_flash_attention_forward, new_flash_attention_forward):
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward = (
|
||||
new_flash_attention_forward
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The signature of the new flash attention forward function does not match the old one."
|
||||
)
|
||||
except Exception as exception:
|
||||
raise ValueError(
|
||||
f"The current transformer version {transformers.__version__} is not supported. "
|
||||
"Please use pip install -U transformers to upgrade to the latest version. "
|
||||
"If the code failed with the latest version, "
|
||||
f"please file an issue."
|
||||
) from exception
|
||||
|
||||
# Register with ALL_ATTENTION_FUNCTIONS if available
|
||||
if ALL_ATTENTION_FUNCTIONS is not None:
|
||||
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
||||
@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
BATCH_ZIGZAG = "batch_zigzag"
|
||||
BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
`batch` function.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Detailed logging of group formation
|
||||
rank = dist.get_rank()
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
elif ring_attn_func in [
|
||||
RingAttnFunc.BATCH_RING,
|
||||
RingAttnFunc.BATCH_ZIGZAG,
|
||||
RingAttnFunc.BATCH_STRIPE,
|
||||
]:
|
||||
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
||||
substitute_hf_flash_attn,
|
||||
)
|
||||
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(),
|
||||
ring_attn_func=ring_attn_func,
|
||||
)
|
||||
|
||||
|
||||
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
||||
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
"""
|
||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||
value to the substituted `ring_flash_attn`.
|
||||
|
||||
Args:
|
||||
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
||||
data; if not, we compute it.
|
||||
position_ids: Optional tensor of position IDs (for sample packed data).
|
||||
"""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
position_ids = batch.get("position_ids")
|
||||
if position_ids is None:
|
||||
seq_len = input_ids.shape[1]
|
||||
position_ids = torch.arange(
|
||||
0, seq_len, dtype=torch.long, device=input_ids.device
|
||||
).unsqueeze(0)
|
||||
|
||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
model: Any | None = None
|
||||
padding: bool | str | PaddingStrategy = True
|
||||
max_length: int | None = None
|
||||
pad_to_multiple_of: int | None = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_degree: int = 1
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
# Get local (start, end) for sequence parallelism slicing
|
||||
total_seq_len = batch["input_ids"].shape[1]
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Update params for ring attention calculation
|
||||
update_ring_attn_params(batch=batch)
|
||||
# Update params for varlen ring attention calculation
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
|
||||
# Slice batch for sequence parallel processing
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
batch[key] = batch[key][:, start:end]
|
||||
for key in batch:
|
||||
if batch[key].size(1) == total_seq_len:
|
||||
if self.ring_attn_func in [
|
||||
RingAttnFunc.VARLEN_LLAMA3,
|
||||
RingAttnFunc.BATCH_RING,
|
||||
]:
|
||||
batch[key] = (
|
||||
batch[key]
|
||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||
.contiguous()
|
||||
)
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||
|
||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||
selected_chunks = [
|
||||
chunks[self.local_rank],
|
||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||
]
|
||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||
# TODO(djsaunde): This doesn't seem to work as expected
|
||||
# Split into striped data and stack
|
||||
tensor = torch.stack(
|
||||
batch[key].split(self.local_world_size, dim=1),
|
||||
dim=1,
|
||||
).transpose(1, 2)
|
||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -655,6 +655,7 @@ class ModelLoader:
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
ring_attn_func=self.cfg.ring_attn_func,
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
@@ -1119,7 +1120,7 @@ class ModelLoader:
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def ajust_model_config(self) -> None:
|
||||
def adjust_model_config(self) -> None:
|
||||
if (
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "max_position_embeddings")
|
||||
@@ -1279,7 +1280,7 @@ class ModelLoader:
|
||||
else:
|
||||
self.model.tie_weights()
|
||||
|
||||
self.ajust_model_config()
|
||||
self.adjust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
|
||||
@@ -259,6 +259,7 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: str | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1147,7 +1148,7 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@field_validator("sequence_parallel_degree", mode="after")
|
||||
@classmethod
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
if not value:
|
||||
@@ -1159,9 +1160,12 @@ class AxolotlInputConfig(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if not info.data["micro_batch_size"] == 1:
|
||||
if (
|
||||
info.data.get("sample_packing")
|
||||
and not info.data["micro_batch_size"] == 1
|
||||
):
|
||||
raise ValueError(
|
||||
"micro_batch_size must be set to 1 "
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
@@ -1188,6 +1192,34 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("ring_attn_func", mode="after")
|
||||
@classmethod
|
||||
def check_ring_attn_func(cls, value, info):
|
||||
if not info.data.get("sequence_parallel_degree", 1) > 1:
|
||||
return value
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
if value is not None:
|
||||
# Set the ring attention function if passed in config
|
||||
valid_funcs = list(RingAttnFunc)
|
||||
if value in valid_funcs:
|
||||
value = RingAttnFunc(value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ring_attn_func: {value} must be one of {valid_funcs}"
|
||||
)
|
||||
else:
|
||||
# Default ring attention function selection
|
||||
sample_packing = info.data.get("sample_packing")
|
||||
value = (
|
||||
RingAttnFunc.VARLEN_LLAMA3
|
||||
if sample_packing
|
||||
else RingAttnFunc.BATCH_RING
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
|
||||
class TestSequenceParallelism:
|
||||
"""Test case for training with sequence parallelism enabled"""
|
||||
|
||||
def test_sequence_parallel_training(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
def _run_sequence_parallel_test(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing=True,
|
||||
micro_batch_size=1,
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sample_packing": sample_packing,
|
||||
"eval_sample_packing": sample_packing,
|
||||
"pad_to_sequence_len": pad_to_sequence_len,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 8,
|
||||
"micro_batch_size": 1,
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
|
||||
[
|
||||
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None), # defaults to batch_ring ring_attn_func
|
||||
(False, 2, True, "batch_zigzag"),
|
||||
# (False, 2, False), # not yet working
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
# "no sample_packing, pad_to_sequence_len", # not yet working
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing,
|
||||
micro_batch_size,
|
||||
pad_to_sequence_len,
|
||||
ring_attn_func,
|
||||
):
|
||||
"""Test sequence parallel training with different configurations"""
|
||||
self._run_sequence_parallel_test(
|
||||
temp_dir,
|
||||
sample_packing=sample_packing,
|
||||
micro_batch_size=micro_batch_size,
|
||||
pad_to_sequence_len=pad_to_sequence_len,
|
||||
ring_attn_func=ring_attn_func,
|
||||
)
|
||||
|
||||
@@ -73,7 +73,10 @@ class TestRingAttention:
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
RingAttnFunc,
|
||||
register_ring_attn,
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
@@ -82,7 +85,11 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
Reference in New Issue
Block a user