progress (messy :O)

This commit is contained in:
Dan Saunders
2025-06-12 18:54:41 +00:00
parent ae73123eae
commit aced809989
8 changed files with 333 additions and 199 deletions

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os import os
from collections import defaultdict from collections import defaultdict
from functools import partial, wraps from functools import partial, wraps
from typing import Callable, Literal, Optional from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
import datasets import datasets
import torch import torch
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
DataLoader, DataLoader,

View File

@@ -17,7 +17,6 @@ from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from torch.distributed.tensor.experimental import context_parallel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -34,7 +33,7 @@ from axolotl.loaders import (
load_processor, load_processor,
load_tokenizer, load_tokenizer,
) )
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.ctx_managers.sequence_parallel import ContextParallelContextManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -205,32 +204,24 @@ def execute_training(
) )
if cfg.sequence_parallel_degree > 1: if cfg.sequence_parallel_degree > 1:
if cfg.sdp_attention: # Models to enter context parallel manager for
world_size = dist.get_world_size() models = [trainer.model]
mesh_shape = ( if hasattr(trainer, "ref_model") and trainer.ref_model:
world_size // cfg.sequence_parallel_degree, models.append(trainer.ref_model)
cfg.sequence_parallel_degree,
)
mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
stack.enter_context(context_parallel(mesh=mesh))
else: # flash_attention
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
stack.enter_context( # Attention backend
SequenceParallelContextManager( backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree, stack.enter_context(
gradient_accumulation_steps=cfg.gradient_accumulation_steps, ContextParallelContextManager(
ring_attn_func=cfg.ring_attn_func, models=models,
heads_k_stride=cfg.heads_k_stride, backend=backend,
) context_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
) )
)
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -3,4 +3,4 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .sequence_parallel import SequenceParallelContextManager from .sequence_parallel import ContextParallelContextManager

View File

@@ -2,11 +2,13 @@
import functools import functools
import inspect import inspect
from typing import Literal
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn from torch.distributed.tensor.experimental import context_parallel
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput from transformers.utils import ModelOutput
@@ -15,177 +17,41 @@ from axolotl.monkeypatch.ring_attn import (
patch_prepare_data_loader, patch_prepare_data_loader,
patch_prepare_device_mesh, patch_prepare_device_mesh,
register_ring_attn, register_ring_attn,
update_ring_attn_params,
) )
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this class ContextParallelContextManager:
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity. """Context manager for context parallelism operations.
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply sequence parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates This class provides a context that will automatically apply context parallelism
to only keep the last N tokens in the sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
Returns:
tuple of:
- Batch dictionary with sliced tensors.
- The original sequence length before padding.
- The number of padding tokens added.
"""
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them
batch["position_ids"] = torch.arange(
0,
original_seq_len,
dtype=torch.long,
device=batch["input_ids"].device,
).expand(batch["input_ids"].size(0), -1)
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
logits_to_keep = batch["logits_to_keep"]
# Calculate which positions in the full sequence contain the last N tokens
start_position = max(0, original_seq_len - logits_to_keep)
chunk_size = original_seq_len // local_world_size
rank_start = local_rank * chunk_size
rank_end = rank_start + chunk_size
# Create a boolean mask tensor for this rank's chunk
mask = torch.zeros(
chunk_size,
dtype=torch.bool,
device=batch["input_ids"].device,
)
if rank_end > start_position:
# Calculate how many of the last N tokens fall within this rank's range
tokens_in_rank = min(rank_end, original_seq_len) - max(
rank_start, start_position
)
# Calculate where these tokens start in the local chunk
local_start_idx = max(0, start_position - rank_start)
# Set the appropriate positions in the mask to True
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
# Replace the integer with the boolean mask
batch["logits_to_keep"] = mask
# Add padding to make sequence length divisible by local_world_size
total_seq_len = original_seq_len
pad_len = 0
divisor = min(local_world_size, 64)
if total_seq_len % divisor != 0:
pad_len = divisor - (total_seq_len % divisor)
# Apply padding to all relevant tensors
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
# Create padding tensor
pad_value = -100 if key == "labels" else 0
padding = torch.full(
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
pad_value,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=1)
if key == "logits_to_keep":
# Create padding tensor
padding = torch.ones(
1,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=0)
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for sequence parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
# Split in sequential fashion and grab this rank's chunk
if batch[key].size(1) == total_seq_len:
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif key == "logits_to_keep":
batch[key] = (
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
)
# Handle num_items_in_batch
if "num_items_in_batch" in batch:
# Approximation; this needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
return batch, original_seq_len, pad_len
class SequenceParallelContextManager:
"""Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook. across the context parallelism group using a post-forward hook.
Args: Args:
models: List of models to apply sequence parallelism to pre- and post- forward models: List of models to apply context parallelism to pre- and post- forward
hooks. hooks.
sequence_parallel_degree: Number of processes to split sequences over. backend: Which attention backend to use.
context_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over. gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused. ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation. `varlen_llama3` `ring_flash_attn` implementation.
""" """
def __init__( def __init__(
self, self,
models: list[nn.Module], models: list[PreTrainedModel],
sequence_parallel_degree: int, backend: Literal["sdp_attention", "flash_attention"],
context_parallel_degree: int,
gradient_accumulation_steps: int, gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, ring_attn_func: RingAttnFunc,
heads_k_stride: int | None, heads_k_stride: int | None,
): ):
self.models = models self.models = models
self.sequence_parallel_degree = sequence_parallel_degree self.backend = backend
self.context_parallel_degree = context_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride self.heads_k_stride = heads_k_stride
@@ -204,14 +70,34 @@ class SequenceParallelContextManager:
self.pad_len = 0 self.pad_len = 0
# Create a partially applied version of the apply_sequence_parallelism function # Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial( self.apply_context_parallelism = functools.partial(
apply_sequence_parallelism, apply_context_parallelism,
local_rank=self.local_rank, local_rank=self.local_rank,
local_world_size=self.local_world_size, local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps, gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
) )
# SPDA CP initialization
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
enabled=self.context_parallel_degree > 1,
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
def __enter__(self): def __enter__(self):
self._register_model_hooks() self._register_model_hooks()
@@ -226,22 +112,25 @@ class SequenceParallelContextManager:
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority) # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self): def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism if self.backend == "flash_attention":
register_ring_attn( # Initialize ring attn for context parallelism
sequence_parallel_degree=self.sequence_parallel_degree, register_ring_attn(
heads_k_stride=self.heads_k_stride, sequence_parallel_degree=self.context_parallel_degree,
ring_attn_func=self.ring_attn_func, heads_k_stride=self.heads_k_stride,
) ring_attn_func=self.ring_attn_func,
)
else:
stack.enter_context(context_parallel(mesh=mesh))
# Patches for accelerate functionality # Patches for accelerate functionality
patch_prepare_data_loader() patch_prepare_data_loader()
patch_prepare_device_mesh( patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree sequence_parallel_degree=self.context_parallel_degree
) )
def _register_model_hooks(self): def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism # Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs): def cp_flash_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function # Get parameter names from the model's forward function
forward_params = list( forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys() inspect.signature(self.models[0].forward).parameters.keys()
@@ -257,13 +146,13 @@ class SequenceParallelContextManager:
# Apply sequence parallelism to updated kwargs # Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = ( updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs) self.apply_context_parallelism(updated_kwargs)
) )
return remaining_args, updated_kwargs return remaining_args, updated_kwargs
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs # Gather the sharded outputs
output = self._gather_outputs(output) output = self._gather_outputs(output)
@@ -277,15 +166,19 @@ class SequenceParallelContextManager:
return output return output
def cp_sdpa_pre_hook(_, args, kwargs):
with self.context_parallel_managers[?](list(inputs.values())):
# Register both hooks # Register both hooks
for model in self.models: for model in self.models:
self.hook_handles.append( self.hook_handles.append(
model.register_forward_pre_hook( model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True cp_flash_pre_hook, with_kwargs=True
) )
) )
self.hook_handles.append( self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook) model.register_forward_hook(cp_flash_post_hook)
) )
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:

View File

@@ -0,0 +1,145 @@
"""Utils for context parallel context manager."""
import torch
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
from axolotl.utils.schemas.enums import RingAttnFunc
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_context_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, iwnt]:
"""
Apply context parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates
to only keep the last N tokens in the input sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the context parallel group.
local_world_size: World size of the context parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
Returns:
tuple of:
- Batch dictionary with sliced tensors.
- The original sequence length before padding.
- The number of padding tokens added.
"""
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them
batch["position_ids"] = torch.arange(
0,
original_seq_len,
dtype=torch.long,
device=batch["input_ids"].device,
).expand(batch["input_ids"].size(0), -1)
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
logits_to_keep = batch["logits_to_keep"]
# Calculate which positions in the full sequence contain the last N tokens
start_position = max(0, original_seq_len - logits_to_keep)
chunk_size = original_seq_len // local_world_size
rank_start = local_rank * chunk_size
rank_end = rank_start + chunk_size
# Create a boolean mask tensor for this rank's chunk
mask = torch.zeros(
chunk_size,
dtype=torch.bool,
device=batch["input_ids"].device,
)
if rank_end > start_position:
# Calculate how many of the last N tokens fall within this rank's range
tokens_in_rank = min(rank_end, original_seq_len) - max(
rank_start, start_position
)
# Calculate where these tokens start in the local chunk
local_start_idx = max(0, start_position - rank_start)
# Set the appropriate positions in the mask to True
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
# Replace the integer with the boolean mask
batch["logits_to_keep"] = mask
# Add padding to make sequence length divisible by local_world_size
total_seq_len = original_seq_len
pad_len = 0
divisor = min(local_world_size, 64)
if total_seq_len % divisor != 0:
pad_len = divisor - (total_seq_len % divisor)
# Apply padding to all relevant tensors
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
# Create padding tensor
pad_value = -100 if key == "labels" else 0
padding = torch.full(
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
pad_value,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=1)
if key == "logits_to_keep":
# Create padding tensor
padding = torch.ones(
1,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=0)
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for context parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
# Split in sequential fashion and grab this rank's chunk
if batch[key].size(1) == total_seq_len:
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif key == "logits_to_keep":
batch[key] = (
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
)
# Handle num_items_in_batch
if "num_items_in_batch" in batch:
# Approximation; this needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
return batch, original_seq_len, pad_len

View File

@@ -0,0 +1,103 @@
import contextlib
from typing import Callable, Generator, Optional, Union
import torch
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel
def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
):
"""
Creates a context manager to confine to flash/efficient/cuDNN attention backends.
Returns:
A context manager function that takes an optional context parallel context.
"""
@contextlib.contextmanager
def context(cp_context: Union[Generator[None, None, None], None] = None):
with contextlib.ExitStack() as stack:
if cp_context is not None:
stack.enter_context(
sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
]
)
)
stack.enter_context(cp_context)
yield
return context
def get_context_parallel_manager(
*,
enabled: bool = False,
world_mesh: torch.distributed.DeviceMesh,
model: PreTrainedModel,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
"""
Context manager for applying context parallelism to a model. In addition to applying the
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
Args:
enabled: Whether context parallel is enabled. Default: False
world_mesh: Global device mesh.
model: Model to apply context parallelism to.
Returns:
A context manager applying context parallelism if enabled is True. Otherwise a context manager
disabling the math SDPA backend.
Raises:
ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
"""
if enabled and "cp" not in world_mesh.mesh_dim_names:
raise ValueError(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work
if enabled and not isinstance(model, TransformerDecoder):
raise ValueError("Context parallel is only supported for text models")
# TODO: this is a hacky proxy for whether we use flex for chunked attention
# remove this once flex is supported
if enabled and any([layer.mask_mod is not None for layer in model.layers]):
raise ValueError("Context parallel with flex attention is not yet supported")
model_buffers = list(model.buffers())
@contextlib.contextmanager
def context(model_inputs: list[torch.Tensor]):
# Create context parallel context if enabled
cp_context = None
if enabled and any([isinstance(input, BlockMask) for input in model_inputs]):
raise ValueError(
"Context parallel with flex attention is not yet supported"
)
if enabled:
set_rotate_method("allgather")
cp_context = context_parallel(
world_mesh["cp"],
buffers=model_inputs + model_buffers,
buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
no_restore_buffers=set(model_inputs),
)
# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()
with sdpa_context(cp_context):
yield
return context

View File

@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
register_ring_attn, register_ring_attn,
set_ring_attn_group, set_ring_attn_group,
) )
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
@@ -328,7 +328,7 @@ class TestApplySequenceParallelism:
"""Test that function returns original batch when world size is 1.""" """Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0 mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism( result, _, _ = apply_context_parallelism(
batch=sequence_parallel_batch, batch=sequence_parallel_batch,
local_rank=0, local_rank=0,
local_world_size=1, local_world_size=1,
@@ -347,7 +347,7 @@ class TestApplySequenceParallelism:
batch = sequence_parallel_batch batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1) seq_len = batch["input_ids"].size(1)
result, _, _ = apply_sequence_parallelism( result, _, _ = apply_context_parallelism(
batch=batch, batch=batch,
local_rank=0, local_rank=0,
local_world_size=2, local_world_size=2,
@@ -374,7 +374,7 @@ class TestApplySequenceParallelism:
seq_len = batch["input_ids"].size(1) seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone() original_input_ids = batch["input_ids"].clone()
result, _, _ = apply_sequence_parallelism( result, _, _ = apply_context_parallelism(
batch=batch, batch=batch,
local_rank=1, local_rank=1,
local_world_size=2, local_world_size=2,
@@ -440,7 +440,7 @@ class TestApplySequenceParallelism:
# Create a partially applied function # Create a partially applied function
rank0_ring_parallel = functools.partial( rank0_ring_parallel = functools.partial(
apply_sequence_parallelism, apply_context_parallelism,
local_rank=0, local_rank=0,
local_world_size=2, local_world_size=2,
gradient_accumulation_steps=1, gradient_accumulation_steps=1,
@@ -466,7 +466,7 @@ class TestApplySequenceParallelism:
original_input_ids = batch["input_ids"].clone() original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing # This should run without error even though position_ids is missing
result, _, _ = apply_sequence_parallelism( result, _, _ = apply_context_parallelism(
batch=batch, batch=batch,
local_rank=0, local_rank=0,
local_world_size=2, local_world_size=2,