batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)

* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* removing pad_to_sequence_len=False for now

* fix

* updating docs to include batch SP

* review comments

* fixes for batch API funcs, simplify

* fixes

* fix

* updates

* add batch_zigzag smoke test
This commit is contained in:
Dan Saunders
2025-04-16 13:50:48 -04:00
committed by GitHub
parent 682a9cf79b
commit b8c633aa97
13 changed files with 397 additions and 49 deletions

View File

@@ -693,6 +693,9 @@ sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model. # Must evenly divide the number of KV heads in your model.
heads_k_stride: 1 heads_k_stride: 1
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
# in the sample packing case, and "batch_ring" in the non-sample packing case.
ring_attn_func:
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:

View File

@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
sequence_parallel_degree: 4 # Split sequences across 4 GPUs sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1 heads_k_stride: 1
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
``` ```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

View File

@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sequence_parallel_degree"] = ( training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree self.cfg.sequence_parallel_degree
) )
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
kwargs["return_tensors"] = "pt" kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq): if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator( return collator(
*collator_args, *collator_args,

View File

@@ -9,6 +9,8 @@ from PIL.Image import Resampling
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass @dataclass
class AxolotlTrainingMixins: class AxolotlTrainingMixins:
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
default=1, default=1,
metadata={"help": "The number of workers to use in sequence parallelism"}, metadata={"help": "The number of workers to use in sequence parallelism"},
) )
ring_attn_func: Optional[RingAttnFunc] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
},
)
# multi-modal section # multi-modal section

View File

@@ -0,0 +1,12 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -0,0 +1,192 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""
# pylint: disable=protected-access,cyclic-import
import os
from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
}
def create_flash_attn_forward(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
"""
Create a ring flash attention forward function compatible with HuggingFace's
interface.
Args:
process_group: A PyTorch distributed process group.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
Returns:
A function that implements the ring flash attention forward pass with the
signature expected by HuggingFace Transformers.
"""
# transformers 4.48+
# pylint: disable=unused-argument
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: torch.Tensor | None = None,
softmax_scale: float | None = None,
sliding_window: int | None = None,
use_top_left_mask: bool = False,
softcap: float | None = None,
deterministic: bool = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
**kwargs,
):
"""
Calls the forward method of Ring Flash Attention.
Args:
query_states: Tensor containing the query vectors.
key_states: Tensor containing the key vectors.
value_states: Tensor containing the value vectors.
attention_mask: Not used in this implementation.
query_length: Integer representing the length of the query sequence.
is_causal: Boolean indicating whether to apply a causal mask to the attention.
dropout: Float representing the dropout probability. Default is 0.0.
position_ids: Not used in this implementation.
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
sliding_window: Optional integer defining the size of the sliding attention window.
Default is None.
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
Default is False.
softcap: Not used in this implementation.
deterministic: Optional boolean to enforce deterministic computation. Default is None.
cu_seq_lens_q: Not used in this implementation.
cu_seq_lens_k: Not used in this implementation.
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
torch.Tensor: The output of the attention mechanism, with shape
`[batch_size, query_length, num_heads, head_dim]`.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
window_size = (
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
)
# Handle deterministic mode
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = (
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
# Call ring flash attention function
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
query_states,
key_states,
value_states,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
group=process_group,
)
return attn_output
return _flash_attention_forward
def substitute_hf_flash_attn(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
"""
Substitute HuggingFace's flash attention implementation with ring-based implementation.
Args:
process_group: PyTorch distributed process group for communication.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
"""
try:
# Substitute flash attention
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_flash_attn_forward(
process_group=process_group, ring_attn_func=ring_attn_func
)
if check_params(old_flash_attention_forward, new_flash_attention_forward):
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
else:
raise ValueError(
"The signature of the new flash attention forward function does not match the old one."
)
except Exception as exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"Please use pip install -U transformers to upgrade to the latest version. "
"If the code failed with the latest version, "
f"please file an issue."
) from exception
# Register with ALL_ATTENTION_FUNCTIONS if available
if ALL_ATTENTION_FUNCTIONS is not None:
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward

View File

@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2. their sequence parallel version of Flash Attention 2.
""" """
from enum import Enum
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
RING_ATTN_GROUP = None RING_ATTN_GROUP = None
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None): class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
""" """
Create ring attention group and substitute flash attn with ring flash attn. Create ring attention group and substitute flash attn with ring flash attn.
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
sequence_parallel_degree: Sequence parallelism factor. sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`. through to `ring_flash_attn.substitute_hf_flash_attn`.
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
`batch` function.
""" """
if get_ring_attn_group() is not None: if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...") LOG.info("Ring attention already registered, exiting early...")
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"each sequence will be processed across {sequence_parallel_degree} GPUs"
) )
rank = dist.get_rank()
world_size = dist.get_world_size() world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, ( assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) " f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})" f"must be less than or equal to world_size ({world_size})"
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"must evenly divide world_size ({world_size})" f"must evenly divide world_size ({world_size})"
) )
# Detailed logging of group formation # Assign ranks to sequence parallel groups
rank = dist.get_rank()
group_assignments = {} group_assignments = {}
for i in range(world_size // sequence_parallel_degree): for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list( ring_attn_ranks = list(
range( range(
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
if rank == 0: if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}") LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None: if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
heads_k_stride = 1 from ring_flash_attn import substitute_hf_flash_attn
from ring_flash_attn import substitute_hf_flash_attn substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func in [
RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG,
RingAttnFunc.BATCH_STRIPE,
]:
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)
substitute_hf_flash_attn( substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride process_group=get_ring_attn_group(),
) ring_attn_func=ring_attn_func,
)
def update_ring_attn_params(batch: dict[str, torch.Tensor]): def update_ring_attn_params(position_ids: torch.Tensor | None):
""" """
Calculate the cumulative sequence lengths for the current forward pass and pass the Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`. value to the substituted `ring_flash_attn`.
Args: Args:
batch: A dictionary with a batch of data. May or may not contain `position_ids` position_ids: Optional tensor of position IDs (for sample packed data).
data; if not, we compute it.
""" """
from ring_flash_attn import update_ring_flash_attn_params from ring_flash_attn import update_ring_flash_attn_params
input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(
0, seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any
import numpy as np import numpy as np
import torch import torch
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass @dataclass
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
""" """
tokenizer: PreTrainedTokenizerBase tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None model: Any | None = None
padding: Union[bool, str, PaddingStrategy] = True padding: bool | str | PaddingStrategy = True
max_length: Optional[int] = None max_length: int | None = None
pad_to_multiple_of: Optional[int] = None pad_to_multiple_of: int | None = None
label_pad_token_id: int = -100 label_pad_token_id: int = -100
position_pad_token_id: int = 0 position_pad_token_id: int = 0
return_tensors: str = "pt" return_tensors: str = "pt"
sequence_parallel_degree: int = 1 sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self): def __post_init__(self):
if self.sequence_parallel_degree > 1: if self.sequence_parallel_degree > 1:
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
Sliced batch dictionary. Sliced batch dictionary.
""" """
# Get local (start, end) for sequence parallelism slicing # Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].shape[1] total_seq_len = batch["input_ids"].size(1)
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Update params for ring attention calculation # Update params for varlen ring attention calculation
update_ring_attn_params(batch=batch) if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing # Slice batch for sequence parallel processing
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] for key in batch:
for key in keys_to_slice: if batch[key].size(1) == total_seq_len:
if key in batch: if self.ring_attn_func in [
batch[key] = batch[key][:, start:end] RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch return batch

View File

@@ -655,6 +655,7 @@ class ModelLoader:
register_ring_attn( register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree, sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride, heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
) )
def patch_attention(self) -> None: def patch_attention(self) -> None:
@@ -1119,7 +1120,7 @@ class ModelLoader:
return skip_move_to_device return skip_move_to_device
def ajust_model_config(self) -> None: def adjust_model_config(self) -> None:
if ( if (
hasattr(self.model, "config") hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings") and hasattr(self.model.config, "max_position_embeddings")
@@ -1279,7 +1280,7 @@ class ModelLoader:
else: else:
self.model.tie_weights() self.model.tie_weights()
self.ajust_model_config() self.adjust_model_config()
# log device memory usage # log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ( if hasattr(self.model, "device") and self.model.device.type in (

View File

@@ -259,6 +259,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None heads_k_stride: int | None = None
ring_attn_func: str | None = None
special_tokens: SpecialTokensConfig | None = None special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None tokens: list[str] | None = None
@@ -1147,7 +1148,7 @@ class AxolotlInputConfig(
return data return data
@field_validator("sequence_parallel_degree", mode="before") @field_validator("sequence_parallel_degree", mode="after")
@classmethod @classmethod
def check_sequence_parallel_degree(cls, value, info): def check_sequence_parallel_degree(cls, value, info):
if not value: if not value:
@@ -1159,9 +1160,12 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with sequence_parallel_degree > 1"
) )
if not info.data["micro_batch_size"] == 1: if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
raise ValueError( raise ValueError(
"micro_batch_size must be set to 1 " "micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement" "due to a `ring-flash-attn` requirement"
) )
@@ -1188,6 +1192,34 @@ class AxolotlInputConfig(
return value return value
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_muon_deepspeed_fsdp(cls, data): def check_muon_deepspeed_fsdp(cls, data):

View File

@@ -3,6 +3,7 @@
import os import os
from pathlib import Path from pathlib import Path
import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism: class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled""" """Test case for training with sequence parallelism enabled"""
def test_sequence_parallel_training(self, temp_dir): def _run_sequence_parallel_test(
# pylint: disable=duplicate-code self,
temp_dir,
sample_packing=True,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",
"sample_packing": True, "sample_packing": sample_packing,
"eval_sample_packing": True, "eval_sample_packing": sample_packing,
"pad_to_sequence_len": True, "pad_to_sequence_len": pad_to_sequence_len,
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 8, "max_steps": 8,
"micro_batch_size": 1, "micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True, "use_tensorboard": True,
"sequence_parallel_degree": 2, "sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
} }
) )
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
) )
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
self,
temp_dir,
sample_packing,
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
)

View File

@@ -73,7 +73,10 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state self, mock_world_size, mock_rank, mock_new_group, partial_state
): ):
"""Test that ring attention groups are created correctly.""" """Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks # Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total mock_world_size.return_value = 8 # 8 GPUs total
@@ -82,7 +85,11 @@ class TestRingAttention:
mock_new_group.return_value = mock_group mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4 # Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1) register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments # Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2 assert mock_new_group.call_count == 2