diff --git a/docs/config.qmd b/docs/config.qmd index 0f55c1077..a67734498 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -693,6 +693,9 @@ sequence_parallel_degree: # Optional; strides across the key dimension. Larger values use more memory but should make training faster. # Must evenly divide the number of KV heads in your model. heads_k_stride: 1 +# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3" +# in the sample packing case, and "batch_ring" in the non-sample packing case. +ring_attn_func: # Path to torch distx for optim 'adamw_anyprecision' torchdistx_path: diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index 98ca4d746..20739333a 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file: sequence_parallel_degree: 4 # Split sequences across 4 GPUs # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 +# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to +# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise. +ring_attn_func: ``` The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 40dedb456..7c2be4956 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["sequence_parallel_degree"] = ( self.cfg.sequence_parallel_degree ) + training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig @@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): kwargs["return_tensors"] = "pt" if issubclass(collator, DataCollatorForSeq2Seq): kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree + kwargs["ring_attn_func"] = training_args.ring_attn_func return collator( *collator_args, diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 18843abb4..3fe32f507 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,6 +9,8 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc + @dataclass class AxolotlTrainingMixins: @@ -218,6 +220,12 @@ class AxolotlTrainingMixins: default=1, metadata={"help": "The number of workers to use in sequence parallelism"}, ) + ring_attn_func: Optional[RingAttnFunc] = field( + default=None, + metadata={ + "help": "The ring-flash-attn function to use in sequence parallelism" + }, + ) # multi-modal section diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py new file mode 100644 index 000000000..055607e92 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py @@ -0,0 +1,12 @@ +"""Init for ring attention monkeypatch module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .patch import ( + RingAttnFunc, + get_ring_attn_group, + register_ring_attn, + set_ring_attn_group, + update_ring_attn_params, +) diff --git 
a/src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py new file mode 100644 index 000000000..a88c9f6f1 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py @@ -0,0 +1,192 @@ +""" +HuggingFace flash attention adapter for basic ring attention (batch API). + +Inspired by +https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py. +Our implementation closely follows the structure of that module, but we've minified it +somewhat to support only the latest versions of transformers. +""" + +# pylint: disable=protected-access,cyclic-import + +import os +from typing import Callable + +import torch +import torch.distributed as dist +import transformers +import transformers.modeling_flash_attention_utils +from ring_flash_attn import ( + ring_flash_attn_func, + stripe_flash_attn_func, + zigzag_ring_flash_attn_func, +) +from ring_flash_attn.adapters.hf_adapter import check_params +from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size, + is_flash_attn_greater_or_equal, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc + +RING_ATTN_FUNC_MAPPING = { + RingAttnFunc.BATCH_RING: ring_flash_attn_func, + RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func, + RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func, +} + + +def create_flash_attn_forward( + process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc +) -> Callable: + """ + Create a ring flash attention forward function compatible with HuggingFace's + interface. + + Args: + process_group: A PyTorch distributed process group. + ring_attn_func: Function from `ring_flash_attention` to replace HF flash + attention with. + + Returns: + A function that implements the ring flash attention forward pass with the + signature expected by HuggingFace Transformers. + """ + + # transformers 4.48+ + # pylint: disable=unused-argument + def _flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: torch.Tensor | None = None, + softmax_scale: float | None = None, + sliding_window: int | None = None, + use_top_left_mask: bool = False, + softcap: float | None = None, + deterministic: bool = None, + cu_seq_lens_q: torch.LongTensor | None = None, + cu_seq_lens_k: torch.LongTensor | None = None, + max_length_q: int | None = None, + max_length_k: int | None = None, + target_dtype: torch.dtype | None = None, + **kwargs, + ): + """ + Calls the forward method of Ring Flash Attention. + + Args: + query_states: Tensor containing the query vectors. + key_states: Tensor containing the key vectors. + value_states: Tensor containing the value vectors. + attention_mask: Not used in this implementation. + query_length: Integer representing the length of the query sequence. + is_causal: Boolean indicating whether to apply a causal mask to the attention. + dropout: Float representing the dropout probability. Default is 0.0. + position_ids: Not used in this implementation. 
+ softmax_scale: Optional float value for the softmax scaling factor. Default is None. + sliding_window: Optional integer defining the size of the sliding attention window. + Default is None. + use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention. + Default is False. + softcap: Not used in this implementation. + deterministic: Optional boolean to enforce deterministic computation. Default is None. + cu_seq_lens_q: Not used in this implementation. + cu_seq_lens_k: Not used in this implementation. + max_length_q: Not used in this implementation. + max_length_k: Not used in this implementation. + target_dtype: Not used in this implementation. + **kwargs: Additional keyword arguments. Not used in this implementation. + + Returns: + torch.Tensor: The output of the attention mechanism, with shape + `[batch_size, query_length, num_heads, head_dim]`. + """ + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + # Handle sliding window + use_sliding_windows = ( + _flash_supports_window_size + and sliding_window is not None + and key_states.shape[1] > sliding_window + ) + window_size = ( + (sliding_window, sliding_window) if use_sliding_windows else (-1, -1) + ) + + # Handle deterministic mode + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = ( + os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + ) + + # Call ring flash attention function + attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func]( + query_states, + key_states, + value_states, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=None, + deterministic=deterministic, + return_attn_probs=False, + group=process_group, + ) + + return attn_output + + return _flash_attention_forward + + +def substitute_hf_flash_attn( + process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc +): + """ + Substitute HuggingFace's flash attention implementation with ring-based implementation. + + Args: + process_group: PyTorch distributed process group for communication. + ring_attn_func: Function from `ring_flash_attention` to replace HF flash + attention with. + """ + try: + # Substitute flash attention + old_flash_attention_forward = ( + transformers.modeling_flash_attention_utils._flash_attention_forward + ) + new_flash_attention_forward = create_flash_attn_forward( + process_group=process_group, ring_attn_func=ring_attn_func + ) + + if check_params(old_flash_attention_forward, new_flash_attention_forward): + transformers.modeling_flash_attention_utils._flash_attention_forward = ( + new_flash_attention_forward + ) + else: + raise ValueError( + "The signature of the new flash attention forward function does not match the old one." + ) + except Exception as exception: + raise ValueError( + f"The current transformer version {transformers.__version__} is not supported. " + "Please use pip install -U transformers to upgrade to the latest version. " + "If the code failed with the latest version, " + f"please file an issue." 
+ ) from exception + + # Register with ALL_ATTENTION_FUNCTIONS if available + if ALL_ATTENTION_FUNCTIONS is not None: + from ring_flash_attn.adapters.hf_adapter import flash_attention_forward + + ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py similarity index 68% rename from src/axolotl/monkeypatch/attention/ring_attn.py rename to src/axolotl/monkeypatch/attention/ring_attn/patch.py index 30aa78f01..b5587ddca 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py @@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc their sequence parallel version of Flash Attention 2. """ +from enum import Enum + import torch import torch.distributed as dist from accelerate.logging import get_logger @@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids configure_logging() LOG = get_logger(__name__) + RING_ATTN_GROUP = None @@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): RING_ATTN_GROUP = ring_attn_group -def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None): +class RingAttnFunc(str, Enum): + """Enum class for supported `ring-flash-attn` implementations""" + + # VARLEN_RING = "varlen_ring" + # VARLEN_ZIGZAG = "varlen_zigzag" + VARLEN_LLAMA3 = "varlen_llama3" + BATCH_RING = "batch_ring" + BATCH_ZIGZAG = "batch_zigzag" + BATCH_STRIPE = "batch_stripe" + + +def register_ring_attn( + sequence_parallel_degree: int, + heads_k_stride: int | None, + ring_attn_func: RingAttnFunc | None, +): """ Create ring attention group and substitute flash attn with ring flash attn. @@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None sequence_parallel_degree: Sequence parallelism factor. heads_k_stride: Sequence parallelism K head stride size. Passed through to `ring_flash_attn.substitute_hf_flash_attn`. + ring_attn_func: `ring_flash_attn` ring attention implemention. If sample + packing is enabled, it must be a `varlen` function; otherwise, it must be a + `batch` function. 
""" if get_ring_attn_group() is not None: LOG.info("Ring attention already registered, exiting early...") @@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None f"each sequence will be processed across {sequence_parallel_degree} GPUs" ) + rank = dist.get_rank() world_size = dist.get_world_size() + assert sequence_parallel_degree <= world_size, ( f"sequence_parallel_degree ({sequence_parallel_degree}) " f"must be less than or equal to world_size ({world_size})" @@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None f"must evenly divide world_size ({world_size})" ) - # Detailed logging of group formation - rank = dist.get_rank() + # Assign ranks to sequence parallel groups group_assignments = {} - for i in range(world_size // sequence_parallel_degree): ring_attn_ranks = list( range( @@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None if rank == 0: LOG.info(f"Sequence parallel group assignments: {group_assignments}") - if heads_k_stride is None: - heads_k_stride = 1 + if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: + from ring_flash_attn import substitute_hf_flash_attn - from ring_flash_attn import substitute_hf_flash_attn + substitute_hf_flash_attn( + process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 + ) + elif ring_attn_func in [ + RingAttnFunc.BATCH_RING, + RingAttnFunc.BATCH_ZIGZAG, + RingAttnFunc.BATCH_STRIPE, + ]: + from axolotl.monkeypatch.attention.ring_attn.adapters.batch import ( + substitute_hf_flash_attn, + ) - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride - ) + substitute_hf_flash_attn( + process_group=get_ring_attn_group(), + ring_attn_func=ring_attn_func, + ) -def update_ring_attn_params(batch: dict[str, torch.Tensor]): +def update_ring_attn_params(position_ids: torch.Tensor | None): """ Calculate the cumulative sequence lengths for the current forward pass and pass the value to the substituted `ring_flash_attn`. Args: - batch: A dictionary with a batch of data. May or may not contain `position_ids` - data; if not, we compute it. + position_ids: Optional tensor of position IDs (for sample packed data). """ from ring_flash_attn import update_ring_flash_attn_params - input_ids = batch["input_ids"] - position_ids = batch.get("position_ids") - if position_ids is None: - seq_len = input_ids.shape[1] - position_ids = torch.arange( - 0, seq_len, dtype=torch.long, device=input_ids.device - ).unsqueeze(0) - cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index ed445ae56..738ef0dc5 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation. 
""" from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch @@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params +from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc @dataclass @@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq: """ tokenizer: PreTrainedTokenizerBase - model: Optional[Any] = None - padding: Union[bool, str, PaddingStrategy] = True - max_length: Optional[int] = None - pad_to_multiple_of: Optional[int] = None + model: Any | None = None + padding: bool | str | PaddingStrategy = True + max_length: int | None = None + pad_to_multiple_of: int | None = None label_pad_token_id: int = -100 position_pad_token_id: int = 0 return_tensors: str = "pt" sequence_parallel_degree: int = 1 + ring_attn_func: RingAttnFunc | None = None def __post_init__(self): if self.sequence_parallel_degree > 1: @@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq: Sliced batch dictionary. """ # Get local (start, end) for sequence parallelism slicing - total_seq_len = batch["input_ids"].shape[1] - slice_size = total_seq_len // self.local_world_size - start = self.local_rank * slice_size - end = start + slice_size + total_seq_len = batch["input_ids"].size(1) - # Update params for ring attention calculation - update_ring_attn_params(batch=batch) + # Update params for varlen ring attention calculation + if batch.get("position_ids") is not None: + update_ring_attn_params(position_ids=batch["position_ids"]) # Slice batch for sequence parallel processing - keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] - for key in keys_to_slice: - if key in batch: - batch[key] = batch[key][:, start:end] + for key in batch: + if batch[key].size(1) == total_seq_len: + if self.ring_attn_func in [ + RingAttnFunc.VARLEN_LLAMA3, + RingAttnFunc.BATCH_RING, + ]: + batch[key] = ( + batch[key] + .chunk(self.local_world_size, dim=1)[self.local_rank] + .contiguous() + ) + elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: + chunks = batch[key].chunk(2 * self.local_world_size, dim=1) + + # Take rank's chunk and opposing chunk for zigzag pattern + selected_chunks = [ + chunks[self.local_rank], + chunks[2 * self.local_world_size - self.local_rank - 1], + ] + batch[key] = torch.cat(selected_chunks, dim=1).contiguous() + elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE: + # TODO(djsaunde): This doesn't seem to work as expected + # Split into striped data and stack + tensor = torch.stack( + batch[key].split(self.local_world_size, dim=1), + dim=1, + ).transpose(1, 2) + batch[key] = tensor[:, self.local_rank].contiguous() return batch diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4d4366994..d7105daba 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -655,6 +655,7 @@ class ModelLoader: register_ring_attn( sequence_parallel_degree=self.cfg.sequence_parallel_degree, heads_k_stride=self.cfg.heads_k_stride, + ring_attn_func=self.cfg.ring_attn_func, ) def patch_attention(self) -> None: @@ -1119,7 +1120,7 @@ class ModelLoader: return skip_move_to_device - def ajust_model_config(self) -> None: + def adjust_model_config(self) -> None: if ( hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings") @@ -1279,7 +1280,7 @@ class ModelLoader: else: self.model.tie_weights() - self.ajust_model_config() + 
self.adjust_model_config() # log device memory usage if hasattr(self.model, "device") and self.model.device.type in ( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 8a76c4eb4..a0fc2c7d3 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -259,6 +259,7 @@ class AxolotlInputConfig( sequence_parallel_degree: int | None = None heads_k_stride: int | None = None + ring_attn_func: str | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None @@ -1147,7 +1148,7 @@ class AxolotlInputConfig( return data - @field_validator("sequence_parallel_degree", mode="before") + @field_validator("sequence_parallel_degree", mode="after") @classmethod def check_sequence_parallel_degree(cls, value, info): if not value: @@ -1159,9 +1160,12 @@ class AxolotlInputConfig( "flash_attention: true must be set with sequence_parallel_degree > 1" ) - if not info.data["micro_batch_size"] == 1: + if ( + info.data.get("sample_packing") + and not info.data["micro_batch_size"] == 1 + ): raise ValueError( - "micro_batch_size must be set to 1 " + "micro_batch_size must be set to 1 when sample_packing is enabled" "due to a `ring-flash-attn` requirement" ) @@ -1188,6 +1192,34 @@ class AxolotlInputConfig( return value + @field_validator("ring_attn_func", mode="after") + @classmethod + def check_ring_attn_func(cls, value, info): + if not info.data.get("sequence_parallel_degree", 1) > 1: + return value + + from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc + + if value is not None: + # Set the ring attention function if passed in config + valid_funcs = list(RingAttnFunc) + if value in valid_funcs: + value = RingAttnFunc(value) + else: + raise ValueError( + f"ring_attn_func: {value} must be one of {valid_funcs}" + ) + else: + # Default ring attention function selection + sample_packing = info.data.get("sample_packing") + value = ( + RingAttnFunc.VARLEN_LLAMA3 + if sample_packing + else RingAttnFunc.BATCH_RING + ) + + return value + @model_validator(mode="before") @classmethod def check_muon_deepspeed_fsdp(cls, data): diff --git a/tests/e2e/multigpu/test_sp.py b/tests/e2e/multigpu/test_sp.py index 288720eec..72e5cb88c 100644 --- a/tests/e2e/multigpu/test_sp.py +++ b/tests/e2e/multigpu/test_sp.py @@ -3,6 +3,7 @@ import os from pathlib import Path +import pytest import yaml from accelerate.test_utils import execute_subprocess_async from transformers.testing_utils import get_torch_dist_unique_port @@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true" class TestSequenceParallelism: """Test case for training with sequence parallelism enabled""" - def test_sequence_parallel_training(self, temp_dir): - # pylint: disable=duplicate-code + def _run_sequence_parallel_test( + self, + temp_dir, + sample_packing=True, + micro_batch_size=1, + pad_to_sequence_len=True, + ring_attn_func=None, + ): + """Helper method to run sequence parallel tests with different configurations""" cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -27,9 +35,9 @@ class TestSequenceParallelism: "strict": False, "sequence_len": 2048, "adapter": "qlora", - "sample_packing": True, - "eval_sample_packing": True, - "pad_to_sequence_len": True, + "sample_packing": sample_packing, + "eval_sample_packing": sample_packing, + "pad_to_sequence_len": pad_to_sequence_len, "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, @@ -45,7 +53,7 @@ class TestSequenceParallelism: ], "num_epochs": 1, "max_steps": 8, - 
"micro_batch_size": 1, + "micro_batch_size": micro_batch_size, "gradient_accumulation_steps": 2, "output_dir": temp_dir, "learning_rate": 0.00001, @@ -61,6 +69,7 @@ class TestSequenceParallelism: "weight_decay": 0.0, "use_tensorboard": True, "sequence_parallel_degree": 2, + "ring_attn_func": ring_attn_func, } ) @@ -86,3 +95,35 @@ class TestSequenceParallelism: check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high" ) + + @pytest.mark.parametrize( + "sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func", + [ + (True, 1, True, None), # defaults to varlen_llama3 ring_attn_func + (False, 2, True, None), # defaults to batch_ring ring_attn_func + (False, 2, True, "batch_zigzag"), + # (False, 2, False), # not yet working + ], + ids=[ + "sample_packing, varlen_llama3 ring_attn_func", + "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", + "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func", + # "no sample_packing, pad_to_sequence_len", # not yet working + ], + ) + def test_sequence_parallel_training( + self, + temp_dir, + sample_packing, + micro_batch_size, + pad_to_sequence_len, + ring_attn_func, + ): + """Test sequence parallel training with different configurations""" + self._run_sequence_parallel_test( + temp_dir, + sample_packing=sample_packing, + micro_batch_size=micro_batch_size, + pad_to_sequence_len=pad_to_sequence_len, + ring_attn_func=ring_attn_func, + ) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 1361a8522..70a601f63 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -73,7 +73,10 @@ class TestRingAttention: self, mock_world_size, mock_rank, mock_new_group, partial_state ): """Test that ring attention groups are created correctly.""" - from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + from axolotl.monkeypatch.attention.ring_attn import ( + RingAttnFunc, + register_ring_attn, + ) # Setup mocks mock_world_size.return_value = 8 # 8 GPUs total @@ -82,7 +85,11 @@ class TestRingAttention: mock_new_group.return_value = mock_group # Call register_ring_attn with size 4 - register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1) + register_ring_attn( + sequence_parallel_degree=4, + heads_k_stride=1, + ring_attn_func=RingAttnFunc.VARLEN_LLAMA3, + ) # Verify the number of calls without examining the arguments assert mock_new_group.call_count == 2