diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 70e443cb3..9b8f4ba92 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -7,11 +7,13 @@ from __future__ import annotations
 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Callable, Literal, Optional
+from typing import Any, Callable, Literal, Optional
+from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
 
 import datasets
 import torch
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 26638a975..7003c73ae 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -17,7 +17,6 @@ from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
-from torch.distributed.tensor.experimental import context_parallel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
@@ -34,7 +33,7 @@ from axolotl.loaders import (
     load_processor,
     load_tokenizer,
 )
-from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
+from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
@@ -205,32 +204,24 @@ def execute_training(
     )
 
     if cfg.sequence_parallel_degree > 1:
-        if cfg.sdp_attention:
-            world_size = dist.get_world_size()
-            mesh_shape = (
-                world_size // cfg.sequence_parallel_degree,
-                cfg.sequence_parallel_degree,
-            )
-            mesh = dist.DeviceMesh(
-                "cuda",
-                torch.tensor(list(range(world_size))).reshape(mesh_shape),
-                mesh_dim_names=("dp", "cp"),
-            )
-            stack.enter_context(context_parallel(mesh=mesh))
-        else:  # flash_attention
-            models = [trainer.model]
-            if hasattr(trainer, "ref_model") and trainer.ref_model:
-                models.append(trainer.ref_model)
+        # Models to enter context parallel manager for
+        models = [trainer.model]
+        if hasattr(trainer, "ref_model") and trainer.ref_model:
+            models.append(trainer.ref_model)
 
-            stack.enter_context(
-                SequenceParallelContextManager(
-                    models=models,
-                    sequence_parallel_degree=cfg.sequence_parallel_degree,
-                    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
-                    ring_attn_func=cfg.ring_attn_func,
-                    heads_k_stride=cfg.heads_k_stride,
-                )
+        # Attention backend
+        backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
+
+        stack.enter_context(
+            ContextParallelContextManager(
+                models=models,
+                backend=backend,
+                context_parallel_degree=cfg.sequence_parallel_degree,
+                gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+                ring_attn_func=cfg.ring_attn_func,
+                heads_k_stride=cfg.heads_k_stride,
             )
+        )
 
     LOG.info("Starting trainer...")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
diff --git a/src/axolotl/utils/ctx_managers/__init__.py b/src/axolotl/utils/ctx_managers/__init__.py
index e544621b5..b92bfdf94 100644
--- a/src/axolotl/utils/ctx_managers/__init__.py
+++ b/src/axolotl/utils/ctx_managers/__init__.py
@@ -3,4 +3,4 @@
 # pylint: disable=unused-import
 # flake8: noqa
 
-from .sequence_parallel import SequenceParallelContextManager
+from .context_parallel import 
ContextParallelContextManager diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/context_parallel.py similarity index 51% rename from src/axolotl/utils/ctx_managers/sequence_parallel.py rename to src/axolotl/utils/ctx_managers/context_parallel.py index 491cb9877..01c724f8d 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/context_parallel.py @@ -2,11 +2,13 @@ import functools import inspect +from typing import Literal import torch import torch.distributed as dist -from torch import nn +from torch.distributed.tensor.experimental import context_parallel from torch.utils.hooks import RemovableHandle +from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput @@ -15,177 +17,41 @@ from axolotl.monkeypatch.ring_attn import ( patch_prepare_data_loader, patch_prepare_device_mesh, register_ring_attn, - update_ring_attn_params, ) from axolotl.utils.schemas.enums import RingAttnFunc +from axolotl.utils.ctx_managers.utils import get_context_parallel_manager -# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this -# module. Currently, we just focus on batch ring and varlen llama3 for simplicity. -def apply_sequence_parallelism( - batch: dict[str, torch.Tensor], - local_rank: int, - local_world_size: int, - gradient_accumulation_steps: int, - ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument -) -> tuple[dict[str, torch.Tensor], int, int]: - """ - Apply sequence parallelism slicing to a batch. +class ContextParallelContextManager: + """Context manager for context parallelism operations. - Special handling is implemented for integer logits_to_keep, which indicates - to only keep the last N tokens in the sequence during generation. - - Args: - batch: Batch dictionary (e.g., input_ids, attention_mask, etc.). - local_rank: Local rank in the sequence parallel group. - local_world_size: World size of the sequence parallel group. - gradient_accumulation_steps: Number of steps to accumulate gradients over. - ring_attn_func: Which ring attention function to use. Currently unused, but - related to above TODO. - - Returns: - tuple of: - - Batch dictionary with sliced tensors. - - The original sequence length before padding. - - The number of padding tokens added. 
- """ - original_seq_len = batch["input_ids"].size(1) - - # Update ring attention params if needed - if batch.get("position_ids") is not None: - update_ring_attn_params(position_ids=batch["position_ids"]) - else: - # If position_ids aren't already in the batch, create them - batch["position_ids"] = torch.arange( - 0, - original_seq_len, - dtype=torch.long, - device=batch["input_ids"].device, - ).expand(batch["input_ids"].size(0), -1) - - if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int): - logits_to_keep = batch["logits_to_keep"] - - # Calculate which positions in the full sequence contain the last N tokens - start_position = max(0, original_seq_len - logits_to_keep) - chunk_size = original_seq_len // local_world_size - rank_start = local_rank * chunk_size - rank_end = rank_start + chunk_size - - # Create a boolean mask tensor for this rank's chunk - mask = torch.zeros( - chunk_size, - dtype=torch.bool, - device=batch["input_ids"].device, - ) - - if rank_end > start_position: - # Calculate how many of the last N tokens fall within this rank's range - tokens_in_rank = min(rank_end, original_seq_len) - max( - rank_start, start_position - ) - - # Calculate where these tokens start in the local chunk - local_start_idx = max(0, start_position - rank_start) - - # Set the appropriate positions in the mask to True - mask[local_start_idx : local_start_idx + tokens_in_rank] = True - - # Replace the integer with the boolean mask - batch["logits_to_keep"] = mask - - # Add padding to make sequence length divisible by local_world_size - total_seq_len = original_seq_len - pad_len = 0 - divisor = min(local_world_size, 64) - if total_seq_len % divisor != 0: - pad_len = divisor - (total_seq_len % divisor) - - # Apply padding to all relevant tensors - for key in batch: - if ( - isinstance(batch[key], torch.Tensor) - and batch[key].dim() > 1 - and batch[key].size(1) == total_seq_len - ): - # Create padding tensor - pad_value = -100 if key == "labels" else 0 - padding = torch.full( - (batch[key].size(0), pad_len, *batch[key].shape[2:]), - pad_value, - dtype=batch[key].dtype, - device=batch[key].device, - ) - - # Concatenate padding to the right side of the tensor - batch[key] = torch.cat([batch[key], padding], dim=1) - if key == "logits_to_keep": - # Create padding tensor - padding = torch.ones( - 1, - dtype=batch[key].dtype, - device=batch[key].device, - ) - - # Concatenate padding to the right side of the tensor - batch[key] = torch.cat([batch[key], padding], dim=0) - - # Update the total sequence length after padding - total_seq_len = batch["input_ids"].size(1) - - # Slice batch for sequence parallel - for key in batch: - if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1: - continue - - # Split in sequential fashion and grab this rank's chunk - if batch[key].size(1) == total_seq_len: - batch[key] = ( - batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() - ) - elif key == "logits_to_keep": - batch[key] = ( - batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous() - ) - - # Handle num_items_in_batch - if "num_items_in_batch" in batch: - # Approximation; this needed since num_items_in_batch may be counted across - # all samples in a gradient accumulated batch, not on a per-step basis. - batch["num_items_in_batch"] = ( - batch["labels"] != -100 - ).sum() * gradient_accumulation_steps - - return batch, original_seq_len, pad_len - - -class SequenceParallelContextManager: - """Context manager for sequence parallelism operations. 
-
-    This class provides a context that will automatically apply sequence parallelism
+    This class provides a context that will automatically apply context parallelism
     during model forward passes using a pre-forward hook, and gather outputs from
-    across the sequence parallelism group using a post-forward hook.
+    across the context parallelism group using a post-forward hook.
 
     Args:
-        models: List of models to apply sequence parallelism to pre- and post- forward
-            hooks.
-        sequence_parallel_degree: Number of processes to split sequences over.
+        models: List of models to attach context parallelism pre- and post-forward
+            hooks to.
+        backend: Which attention backend to use ("sdp_attention" or "flash_attention").
+        context_parallel_degree: Number of processes to split sequences over.
         gradient_accumulation_steps: Number of steps to accumulate gradients over.
         ring_attn_func: Which ring attention function to use. Currently unused.
-        heads_k_stride: Sequence parallelism K head stride size. Passed through to
+        heads_k_stride: Context parallelism K head stride size. Passed through to
            `varlen_llama3` `ring_flash_attn` implementation.
    """

    def __init__(
        self,
-        models: list[nn.Module],
-        sequence_parallel_degree: int,
+        models: list[PreTrainedModel],
+        backend: Literal["sdp_attention", "flash_attention"],
+        context_parallel_degree: int,
        gradient_accumulation_steps: int,
        ring_attn_func: RingAttnFunc,
        heads_k_stride: int | None,
    ):
        self.models = models
-        self.sequence_parallel_degree = sequence_parallel_degree
+        self.backend = backend
+        self.context_parallel_degree = context_parallel_degree
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.ring_attn_func = ring_attn_func
        self.heads_k_stride = heads_k_stride
@@ -204,14 +70,34 @@ class SequenceParallelContextManager:
         self.pad_len = 0
 
         # Create a partially applied version of the apply_sequence_parallelism function
-        self.apply_sequence_parallelism = functools.partial(
-            apply_sequence_parallelism,
+        self.apply_context_parallelism = functools.partial(
+            apply_context_parallelism,
             local_rank=self.local_rank,
             local_world_size=self.local_world_size,
             gradient_accumulation_steps=self.gradient_accumulation_steps,
             ring_attn_func=self.ring_attn_func,
         )
 
+        # SDPA CP initialization: build the device mesh and one context parallel
+        # manager per model
+        world_size = dist.get_world_size()
+        mesh_shape = (
+            world_size // self.context_parallel_degree,
+            self.context_parallel_degree,
+        )
+        world_mesh = dist.DeviceMesh(
+            "cuda",
+            torch.tensor(list(range(world_size))).reshape(mesh_shape),
+            mesh_dim_names=("dp", "cp"),
+        )
+        self.context_parallel_managers = []
+        for model in models:
+            ctx_manager = get_context_parallel_manager(
+                enabled=self.context_parallel_degree > 1,
+                world_mesh=world_mesh,
+                model=model,
+            )
+            self.context_parallel_managers.append(ctx_manager)
+
     def __enter__(self):
         self._register_model_hooks()
@@ -226,22 +112,25 @@ class SequenceParallelContextManager:
         # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
 
     def _register_ring_attn(self):
-        # Initialize ring attn for sequence parallelism
-        register_ring_attn(
-            sequence_parallel_degree=self.sequence_parallel_degree,
-            heads_k_stride=self.heads_k_stride,
-            ring_attn_func=self.ring_attn_func,
-        )
+        if self.backend == "flash_attention":
+            # Initialize ring attn for context parallelism
+            register_ring_attn(
+                sequence_parallel_degree=self.context_parallel_degree,
+                heads_k_stride=self.heads_k_stride,
+                ring_attn_func=self.ring_attn_func,
+            )
+        # For the SDPA backend, no ring attention setup is needed here: sharding is
+        # handled by the per-model context parallel managers built in __init__.
 
         # Patches for accelerate functionality
         patch_prepare_data_loader()
        patch_prepare_device_mesh(
-            sequence_parallel_degree=self.sequence_parallel_degree
+            sequence_parallel_degree=self.context_parallel_degree
         )
 
     def _register_model_hooks(self):
         # Forward pre-hook to apply sequence parallelism
-        def sequence_parallel_pre_hook(_, args, kwargs):
+        def cp_flash_pre_hook(_, args, kwargs):
             # Get parameter names from the model's forward function
             forward_params = list(
                 inspect.signature(self.models[0].forward).parameters.keys()
@@ -257,13 +146,13 @@ class SequenceParallelContextManager:
 
             # Apply sequence parallelism to updated kwargs
             updated_kwargs, self.original_seq_len, self.pad_len = (
-                self.apply_sequence_parallelism(updated_kwargs)
+                self.apply_context_parallelism(updated_kwargs)
             )
 
             return remaining_args, updated_kwargs
 
         # Forward post-hook to gather outputs
-        def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
+        def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
             # Gather the sharded outputs
             output = self._gather_outputs(output)
 
@@ -277,15 +166,19 @@ class SequenceParallelContextManager:
 
             return output
 
+        # Hooks for the SDPA backend: enter this model's context parallel manager
+        # before the forward pass and exit it again in the post-forward hook.
+        self.active_cp_contexts = {}
+
+        def cp_sdpa_pre_hook(module, args, kwargs):
+            model_index = self.models.index(module)
+            tensor_inputs = [
+                value for value in kwargs.values() if isinstance(value, torch.Tensor)
+            ]
+            cp_context = self.context_parallel_managers[model_index](tensor_inputs)
+            cp_context.__enter__()
+            self.active_cp_contexts[model_index] = cp_context
+            return args, kwargs
+
+        def cp_sdpa_post_hook(module, _, output: ModelOutput) -> ModelOutput:
+            # Close the context parallel context entered in cp_sdpa_pre_hook
+            model_index = self.models.index(module)
+            self.active_cp_contexts.pop(model_index).__exit__(None, None, None)
+            return output
+
         # Register both hooks
+        if self.backend == "flash_attention":
+            pre_hook, post_hook = cp_flash_pre_hook, cp_flash_post_hook
+        else:
+            pre_hook, post_hook = cp_sdpa_pre_hook, cp_sdpa_post_hook
         for model in self.models:
             self.hook_handles.append(
                 model.register_forward_pre_hook(
-                    sequence_parallel_pre_hook, with_kwargs=True
+                    pre_hook, with_kwargs=True
                 )
             )
             self.hook_handles.append(
-                model.register_forward_hook(sequence_parallel_post_hook)
+                model.register_forward_hook(post_hook)
             )
 
     def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/__init__.py b/src/axolotl/utils/ctx_managers/context_parallel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/utils.py b/src/axolotl/utils/ctx_managers/context_parallel/utils.py
new file mode 100644
index 000000000..3652250f0
--- /dev/null
+++ b/src/axolotl/utils/ctx_managers/context_parallel/utils.py
@@ -0,0 +1,145 @@
+"""Utils for context parallel context manager."""
+
+import torch
+
+from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
+from axolotl.utils.schemas.enums import RingAttnFunc
+
+
+# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
+# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
+def apply_context_parallelism(
+    batch: dict[str, torch.Tensor],
+    local_rank: int,
+    local_world_size: int,
+    gradient_accumulation_steps: int,
+    ring_attn_func: RingAttnFunc,  # pylint: disable=unused-argument
+) -> tuple[dict[str, torch.Tensor], int, int]:
+    """
+    Apply context parallelism slicing to a batch.
+
+    Special handling is implemented for integer logits_to_keep, which indicates
+    to only keep the last N tokens in the input sequence during generation.
+
+    Args:
+        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
+        local_rank: Local rank in the context parallel group.
+        local_world_size: World size of the context parallel group.
+        gradient_accumulation_steps: Number of steps to accumulate gradients over.
+        ring_attn_func: Which ring attention function to use. Currently unused, but
+            related to above TODO.
+
+    Returns:
+        tuple of:
+            - Batch dictionary with sliced tensors.
+            - The original sequence length before padding.
+            - The number of padding tokens added.
+ """ + original_seq_len = batch["input_ids"].size(1) + + # Update ring attention params if needed + if batch.get("position_ids") is not None: + update_ring_attn_params(position_ids=batch["position_ids"]) + else: + # If position_ids aren't already in the batch, create them + batch["position_ids"] = torch.arange( + 0, + original_seq_len, + dtype=torch.long, + device=batch["input_ids"].device, + ).expand(batch["input_ids"].size(0), -1) + + if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int): + logits_to_keep = batch["logits_to_keep"] + + # Calculate which positions in the full sequence contain the last N tokens + start_position = max(0, original_seq_len - logits_to_keep) + chunk_size = original_seq_len // local_world_size + rank_start = local_rank * chunk_size + rank_end = rank_start + chunk_size + + # Create a boolean mask tensor for this rank's chunk + mask = torch.zeros( + chunk_size, + dtype=torch.bool, + device=batch["input_ids"].device, + ) + + if rank_end > start_position: + # Calculate how many of the last N tokens fall within this rank's range + tokens_in_rank = min(rank_end, original_seq_len) - max( + rank_start, start_position + ) + + # Calculate where these tokens start in the local chunk + local_start_idx = max(0, start_position - rank_start) + + # Set the appropriate positions in the mask to True + mask[local_start_idx : local_start_idx + tokens_in_rank] = True + + # Replace the integer with the boolean mask + batch["logits_to_keep"] = mask + + # Add padding to make sequence length divisible by local_world_size + total_seq_len = original_seq_len + pad_len = 0 + divisor = min(local_world_size, 64) + if total_seq_len % divisor != 0: + pad_len = divisor - (total_seq_len % divisor) + + # Apply padding to all relevant tensors + for key in batch: + if ( + isinstance(batch[key], torch.Tensor) + and batch[key].dim() > 1 + and batch[key].size(1) == total_seq_len + ): + # Create padding tensor + pad_value = -100 if key == "labels" else 0 + padding = torch.full( + (batch[key].size(0), pad_len, *batch[key].shape[2:]), + pad_value, + dtype=batch[key].dtype, + device=batch[key].device, + ) + + # Concatenate padding to the right side of the tensor + batch[key] = torch.cat([batch[key], padding], dim=1) + if key == "logits_to_keep": + # Create padding tensor + padding = torch.ones( + 1, + dtype=batch[key].dtype, + device=batch[key].device, + ) + + # Concatenate padding to the right side of the tensor + batch[key] = torch.cat([batch[key], padding], dim=0) + + # Update the total sequence length after padding + total_seq_len = batch["input_ids"].size(1) + + # Slice batch for context parallel + for key in batch: + if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1: + continue + + # Split in sequential fashion and grab this rank's chunk + if batch[key].size(1) == total_seq_len: + batch[key] = ( + batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() + ) + elif key == "logits_to_keep": + batch[key] = ( + batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous() + ) + + # Handle num_items_in_batch + if "num_items_in_batch" in batch: + # Approximation; this needed since num_items_in_batch may be counted across + # all samples in a gradient accumulated batch, not on a per-step basis. 
+ batch["num_items_in_batch"] = ( + batch["labels"] != -100 + ).sum() * gradient_accumulation_steps + + return batch, original_seq_len, pad_len \ No newline at end of file diff --git a/src/axolotl/utils/ctx_managers/utils.py b/src/axolotl/utils/ctx_managers/utils.py new file mode 100644 index 000000000..a2a9ba725 --- /dev/null +++ b/src/axolotl/utils/ctx_managers/utils.py @@ -0,0 +1,103 @@ +import contextlib +from typing import Callable, Generator, Optional, Union + +import torch + +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import set_rotate_method +from torch.nn.attention import sdpa_kernel, SDPBackend +from torch.nn.attention.flex_attention import BlockMask +from transformers import PreTrainedModel + + +def _get_sdpa_context() -> ( + Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]] +): + """ + Creates a context manager to confine to flash/efficient/cuDNN attention backends. + + Returns: + A context manager function that takes an optional context parallel context. + """ + + @contextlib.contextmanager + def context(cp_context: Union[Generator[None, None, None], None] = None): + with contextlib.ExitStack() as stack: + if cp_context is not None: + stack.enter_context( + sdpa_kernel( + [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + ] + ) + ) + stack.enter_context(cp_context) + + yield + + return context + + +def get_context_parallel_manager( + *, + enabled: bool = False, + world_mesh: torch.distributed.DeviceMesh, + model: PreTrainedModel, +) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]: + """ + Context manager for applying context parallelism to a model. In addition to applying the + standard context manager to patch SDPA and shard model inputs and buffers along the sequence + dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends. + + Args: + enabled: Whether context parallel is enabled. Default: False + world_mesh: Global device mesh. + model: Model to apply context parallelism to. + + Returns: + A context manager applying context parallelism if enabled is True. Otherwise a context manager + disabling the math SDPA backend. + + Raises: + ValueError: if enabled is True but world_mesh does not contain a "cp" dimension + """ + + if enabled and "cp" not in world_mesh.mesh_dim_names: + raise ValueError( + "Context parallel is enabled but no context parallel device mesh is provided." 
+ ) + # TODO: context parallel for multimodal models requires extra work + if enabled and not isinstance(model, TransformerDecoder): + raise ValueError("Context parallel is only supported for text models") + # TODO: this is a hacky proxy for whether we use flex for chunked attention + # remove this once flex is supported + if enabled and any([layer.mask_mod is not None for layer in model.layers]): + raise ValueError("Context parallel with flex attention is not yet supported") + model_buffers = list(model.buffers()) + + @contextlib.contextmanager + def context(model_inputs: list[torch.Tensor]): + # Create context parallel context if enabled + cp_context = None + if enabled and any([isinstance(input, BlockMask) for input in model_inputs]): + raise ValueError( + "Context parallel with flex attention is not yet supported" + ) + if enabled: + set_rotate_method("allgather") + cp_context = context_parallel( + world_mesh["cp"], + buffers=model_inputs + model_buffers, + buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers), + no_restore_buffers=set(model_inputs), + ) + + # Create and enter the train context with the optional cp_context + sdpa_context = _get_sdpa_context() + + with sdpa_context(cp_context): + yield + + return context \ No newline at end of file diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 2b4d11b30..28146aca7 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import ( register_ring_attn, set_ring_attn_group, ) -from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism +from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.trl import TRLConfig @@ -328,7 +328,7 @@ class TestApplySequenceParallelism: """Test that function returns original batch when world size is 1.""" mock_get_ring_attn_group.return_value = 0 - result, _, _ = apply_sequence_parallelism( + result, _, _ = apply_context_parallelism( batch=sequence_parallel_batch, local_rank=0, local_world_size=1, @@ -347,7 +347,7 @@ class TestApplySequenceParallelism: batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) - result, _, _ = apply_sequence_parallelism( + result, _, _ = apply_context_parallelism( batch=batch, local_rank=0, local_world_size=2, @@ -374,7 +374,7 @@ class TestApplySequenceParallelism: seq_len = batch["input_ids"].size(1) original_input_ids = batch["input_ids"].clone() - result, _, _ = apply_sequence_parallelism( + result, _, _ = apply_context_parallelism( batch=batch, local_rank=1, local_world_size=2, @@ -440,7 +440,7 @@ class TestApplySequenceParallelism: # Create a partially applied function rank0_ring_parallel = functools.partial( - apply_sequence_parallelism, + apply_context_parallelism, local_rank=0, local_world_size=2, gradient_accumulation_steps=1, @@ -466,7 +466,7 @@ class TestApplySequenceParallelism: original_input_ids = batch["input_ids"].clone() # This should run without error even though position_ids is missing - result, _, _ = apply_sequence_parallelism( + result, _, _ = apply_context_parallelism( batch=batch, local_rank=0, local_world_size=2,