diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index 0167df67a..c412bfc72 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -8,7 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
-      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
+      - 'src/axolotl/core/trainers/mixins/context_parallel.py'
       - 'src/axolotl/utils/distributed.py'
   workflow_dispatch:
   schedule:
diff --git a/_quarto.yml b/_quarto.yml
index 9b97095ce..5b67fa0af 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -75,7 +75,7 @@ quartodoc:
     - title: Context Managers
       desc: Context managers for altering trainer behaviors
       contents:
-        - utils.ctx_managers.sequence_parallel
+        - utils.ctx_managers.context_parallel
     - title: Prompt Strategies
       desc: Prompt formatting strategies
       contents:
@@ -274,7 +274,7 @@ website:
             - docs/unsloth.qmd
             - docs/torchao.qmd
             - docs/custom_integrations.qmd
-            - docs/sequence_parallelism.qmd
+            - docs/context_parallelism.qmd
 
         - section: "Troubleshooting"
           contents:
diff --git a/docs/config.qmd b/docs/config.qmd
index 519065554..ff9fc26f0 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -764,13 +764,13 @@ ddp_timeout:
 ddp_bucket_cap_mb:
 ddp_broadcast_buffers:
 
-# Sequence parallelism
+# Context parallelism
 # Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
 # Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
 # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
 # subsequences, or set to 4 to split into four equal-sized subsequences.
-# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
-sequence_parallel_degree:
+# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
+context_parallel_degree:
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 # Must evenly divide the number of KV heads in your model.
 heads_k_stride: 1
diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd
index fee7d17e5..c23794d9c 100644
--- a/docs/multi-gpu.qmd
+++ b/docs/multi-gpu.qmd
@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
 
 - DeepSpeed (recommended)
 - FSDP (Fully Sharded Data Parallel)
-- Sequence parallelism
+- Context parallelism
 - FSDP + QLoRA
 
 ## DeepSpeed {#sec-deepspeed}
@@ -80,14 +80,14 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
-## Sequence parallelism {#sec-sequence-parallelism}
+## Context parallelism {#sec-sequence-parallelism}
 
-We support sequence parallelism (SP) via the
+We support context parallelism (CP) via the
 [ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project.
 This allows one to split up sequences across GPUs, which is useful in the event
 that a single sequence causes OOM errors during model training.
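+
+For example, with 4 GPUs, a minimal sketch of the relevant setting (alongside
+your usual model and dataset config) might be:
+
+```yaml
+context_parallel_degree: 2  # forms 2 context parallel groups of 2 GPUs each
+```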
 
-See our [dedicated guide](sequence_parallelism.qmd) for more information.
+See our [dedicated guide](context_parallelism.qmd) for more information.
 
 ### FSDP + QLoRA {#sec-fsdp-qlora}
 
diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd
index b98206135..5ca50f236 100644
--- a/docs/sequence_parallelism.qmd
+++ b/docs/sequence_parallelism.qmd
@@ -1,16 +1,16 @@
 ---
-title: Sequence Parallelism
+title: Context Parallelism
 description: Train with long sequences split across multiple GPUs.
 ---
 
-Sequence parallelism is a technique that splits sequences across multiple GPUs,
+Context parallelism is a technique that splits sequences across multiple GPUs,
 allowing you to train with very long sequences that wouldn't fit on a single
 GPU. Each GPU processes a different portion of the sequence, and the results are
 aggregated through a ring communication pattern.
 
-## When to Use Sequence Parallelism
+## When to Use Context Parallelism
 
-Use sequence parallelism when:
+Use context parallelism when:
 
 - You need to train with sequence lengths that don't fit into a single GPU's memory
 - You have multiple GPUs available
@@ -18,11 +18,11 @@ Use sequence parallelism when:
 ## Configuration
 
-To enable sequence parallelism, add the following to your configuration file:
+To enable context parallelism, add the following to your configuration file:
 
 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4 # Split sequences across 4 GPUs
+context_parallel_degree: 4 # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
 ring_attn_func:
 ```
 
-The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
+The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
 
 - With 8 GPUs, valid values would be 2, 4, or 8
 - With 4 GPUs, valid values would be 2 or 4
 
 ## Implementation Details
 
-When sequence parallelism is enabled:
+When context parallelism is enabled:
 
-1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
+1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
 3. Position IDs are adjusted to maintain proper relative positions
 4. The trainer uses special ring communication patterns for attention operations
 
 ## Requirements
 
-To use sequence parallelism, you need:
+To use context parallelism, you need:
 
 - Multiple GPUs (at least 2)
 - The `ring-flash-attn` package. Install with:
@@ -66,7 +66,7 @@ sequence_len: 8192
 
 ...
 
-sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -79,22 +79,22 @@ ring_attn_func:
 ```
 
-This will train the Llama 3 8B model with 8K context length, with each sequence split into 2 subsequences of length 4096 across 2 GPUs.
+This will train the Llama 3 8B model with 8K context length, with each sequence split into 4 subsequences of length 2048, one per GPU.
 
-## Sample Packing with Sequence Parallelism
+## Sample Packing with Context Parallelism
 
-Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
+Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
 
 1. Samples are first packed together
-2. The packed sequences are then divided across GPUs in the sequence parallel group
+2. The packed sequences are then divided across GPUs in the context parallel group
 3. 
Position IDs are automatically adjusted to maintain proper relative positions ## Effect on Batch Size -When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: +When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because: -- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) +- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence) - The number of batches processed per step decreases For example: -- With 8 GPUs and no sequence parallelism: 8 different batches processed per step -- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) +- With 8 GPUs and no context parallelism: 8 different batches processed per step +- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 36cfdec4e..fcbb38778 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -73,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: load_in_8bit=False, load_in_4bit=False, flash_attention=False, - sequence_parallel_degree=None, + context_parallel_degree=None, deepspeed=None, fsdp=None, fsdp_config=None, diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 14dbfa715..68487e3a9 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.sequence_parallel_degree > 1 + context_parallel=self.cfg.context_parallel_degree > 1 ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 2cdc9c195..8d42dec7b 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -5,7 +5,7 @@ from .base import AxolotlTrainer from .dpo.trainer import AxolotlDPOTrainer -from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer +from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer from .mamba import AxolotlMambaTrainer from .relora import ReLoRATrainer from .trl import ( diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index c0f10be23..5a6fcc3d1 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig from axolotl.core.trainers.grpo.trainer import ( - AxolotlGRPOSequenceParallelTrainer, + AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer, ) from axolotl.utils.dict import DictDefault @@ -23,10 +23,10 @@ class GRPOStrategy: @classmethod def get_trainer_class( - cls, sequence_parallel: bool - ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]: - if sequence_parallel: - return AxolotlGRPOSequenceParallelTrainer + cls, context_parallel: bool + ) -> type[AxolotlGRPOTrainer] | 
type[AxolotlGRPOContextParallelTrainer]:
+        if context_parallel:
+            return AxolotlGRPOContextParallelTrainer
         return AxolotlGRPOTrainer
 
     @classmethod
@@ -69,8 +69,8 @@ class GRPOStrategy:
             grpo_args_kwargs["log_completions"] = trl.log_completions
             grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
 
-        if cfg.sequence_parallel_degree > 1:
-            grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
+        if cfg.context_parallel_degree > 1:
+            grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
 
         if trl.reward_weights:
             grpo_args_kwargs["reward_weights"] = trl.reward_weights
diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py
index 5c8b1a33b..6d57fed11 100644
--- a/src/axolotl/core/trainers/grpo/args.py
+++ b/src/axolotl/core/trainers/grpo/args.py
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
 class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """Axolotl GRPO Config for GRPO training"""
 
-    sequence_parallel_degree: int | None = None
+    context_parallel_degree: int | None = None
diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py
index ebc6e19e2..fc90f3e41 100644
--- a/src/axolotl/core/trainers/grpo/sampler.py
+++ b/src/axolotl/core/trainers/grpo/sampler.py
@@ -1,7 +1,7 @@
 """Repeat random sampler (similar to the one implemented in
 https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
-sequence parallelism functionality; i.e., duplicating data across ranks in the same
-sequence parallel group.
+context parallelism functionality; i.e., duplicating data across ranks in the same
+context parallel group.
 """
 
 from typing import Iterator, Sized
@@ -10,26 +10,26 @@ import torch
 from torch.utils.data import Sampler
 
 
-class SequenceParallelRepeatRandomSampler(Sampler):
-    """Sampler for GRPO training with sequence parallelism.
+class ContextParallelRepeatRandomSampler(Sampler):
+    """Sampler for GRPO training with context parallelism.
 
     This sampler ensures:
-    - Ranks in the same sequence parallel (SP) group receive identical data.
+    - Ranks in the same context parallel (CP) group receive identical data.
     - Each index is repeated multiple times for sampling different completions.
     - Entire batches are repeated for reuse in multiple updates.
-    - Data is properly distributed across SP groups.
+    - Data is properly distributed across CP groups.
 
-    In the table below, the values represent dataset indices. Each SP group has
-    `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
-    SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
+    In the table below, the values represent dataset indices. Each CP group has
+    `context_parallel_degree = 2` GPUs working together on the same data. There are 2
+    CP groups (CP0 and CP1), with `world_size = 4` total GPUs.
 
-    Sequence Parallel Groups
-    |      SP0      |      SP1      |
+    Context Parallel Groups
+    |      CP0      |      CP1      |
     | GPU 0 | GPU 1 | GPU 2 | GPU 3 |
 
                  global_step   step    <---> mini_repeat_count=3
-                                        <----------> batch_size=2 per SP group
-  grad_accum=2 ▲  ▲  0           0     [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
-             ▼ |  0           1     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
+                                        <----------> batch_size=2 per CP group
+  grad_accum=2 ▲  ▲  0           0     [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
+             ▼ |  0           1     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
                |  1           2     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
 num_iterations=2 ▼  1           3     [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
@@ -45,7 +45,7 @@ class ContextParallelRepeatRandomSampler(Sampler):
         rank: Rank of current process.
         batch_size: Number of samples per batch.
         repeat_count: How many times to repeat the full sampling process.
-        sequence_parallel_degree: Number of ranks in a sequence parallel group.
+        context_parallel_degree: Number of ranks in a context parallel group.
         shuffle: Whether to shuffle the dataset.
         seed: Random seed for shuffling.
         drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ def __init__(
         rank: int,
         batch_size: int = 1,
         repeat_count: int = 1,
-        sequence_parallel_degree: int = 1,
+        context_parallel_degree: int = 1,
         shuffle: bool = True,
         seed: int = 0,
         drop_last: bool = False,
@@ -76,16 +76,16 @@
         self.world_size = world_size
         self.rank = rank
 
-        # Sequence parallelism parameters
-        self.sequence_parallel_degree = sequence_parallel_degree
-        self.num_sp_groups = world_size // sequence_parallel_degree
-        self.sp_group_id = rank // sequence_parallel_degree
+        # Context parallelism parameters
+        self.context_parallel_degree = context_parallel_degree
+        self.num_sp_groups = world_size // context_parallel_degree
+        self.sp_group_id = rank // context_parallel_degree
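+        # e.g., with world_size=4 and context_parallel_degree=2, ranks (0, 1)
+        # form CP group 0 and ranks (2, 3) form CP group 1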
 
         # Adjust dataset size for distributed sampling
         self.num_samples = len(self.dataset)
         self.total_size = self.num_samples
 
-        # Calculate effective number of samples per SP group
+        # Calculate effective number of samples per CP group
         if (
             self.drop_last
             and self.total_size % (self.num_sp_groups * self.batch_size) != 0
@@ -125,8 +125,8 @@
             padding = indices[: self.batch_size - len(indices) % self.batch_size]
             indices += padding
 
-        # Subsample based on SP group ID
-        # Each SP group gets distinct batches of data
+        # Subsample based on CP group ID
+        # Each CP group gets distinct batches of data
         batch_indices = []
         for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
             start_idx = i + self.sp_group_id * self.batch_size
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index dccc85d80..0734cdac6 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -1,4 +1,4 @@
-"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
+"""Axolotl GRPO trainers (with and without context parallelism handling)"""
 
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
 
@@ -41,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
 from trl.trainer.utils import pad
 
-from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
+from axolotl.core.trainers.grpo.sampler import 
ContextParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.monkeypatch.ring_attn import get_ring_attn_group @@ -59,8 +59,8 @@ class AxolotlGRPOTrainer( _tag_names = ["trl", "grpo", "axolotl"] -class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): - """Extend the base GRPOTrainer for sequence parallelism handling""" +class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer): + """Extend the base GRPOTrainer for context parallelism handling""" def __init__( self, @@ -97,11 +97,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, ) - # Get number of SP groups (number of processes divided by SP degree) + # Get number of CP groups (number of processes divided by CP degree) num_processes = self.accelerator.num_processes - num_sp_groups = num_processes // self.args.sequence_parallel_degree + num_sp_groups = num_processes // self.args.context_parallel_degree - # Calculate batch size per SP group (not per process) + # Calculate batch size per CP group (not per process) sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups possible_values = [ n_gen @@ -111,7 +111,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): if self.num_generations not in possible_values: raise ValueError( - f"The batch size per SP group ({num_sp_groups} x " + f"The batch size per CP group ({num_sp_groups} x " f"{self.args.per_device_train_batch_size}) must be evenly divisible by " f"the number of generations per prompt ({self.num_generations}). Given " "the current configuration, the valid values for the number of " @@ -119,7 +119,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) if self.args.eval_strategy != "no": - # If sequence parallelism is enabled, calculate batch size per SP group + # If context parallelism is enabled, calculate batch size per CP group sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr] possible_values = [ n_gen @@ -129,8 +129,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): if self.num_generations not in possible_values: raise ValueError( - f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " - f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " + f"With context parallelism (degree {self.args.context_parallel_degree}), " + f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"must be evenly divisible by the number of generations per prompt " f"({self.num_generations}). Given the current eval batch size, " f"the valid values for the number of generations are: {possible_values}." 
@@ -143,7 +143,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): self.local_world_size = 1 def train(self, *args, **kwargs): - # Initialize the SP group + # Initialize the CP group self.sp_group = get_ring_attn_group() self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -159,16 +159,16 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): * self.args.gradient_accumulation_steps ) - return SequenceParallelRepeatRandomSampler( + return ContextParallelRepeatRandomSampler( dataset=self.train_dataset, mini_repeat_count=self.num_generations, world_size=self.world_size, rank=self.rank, batch_size=effective_batch_size // self.num_generations - // self.args.sequence_parallel_degree, + // self.args.context_parallel_degree, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, - sequence_parallel_degree=self.args.sequence_parallel_degree, + context_parallel_degree=self.args.context_parallel_degree, shuffle=True, seed=self.args.seed, drop_last=True, @@ -226,11 +226,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ): self.accelerator.even_batches = False - # Return unprepared dataloader if using sequence parallelism + # Return unprepared dataloader if using context parallelism # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # slice each batch along the sequence dimension). - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_degree > 1: return dataloader # Otherwise prepare with accelerator @@ -303,21 +303,21 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - if self.args.sequence_parallel_degree > 1: - # Calculate sequence parallel group information + if self.args.context_parallel_degree > 1: + # Calculate context parallel group information world_size = self.accelerator.num_processes - sequence_parallel_degree = self.args.sequence_parallel_degree - num_sp_groups = world_size // sequence_parallel_degree + context_parallel_degree = self.args.context_parallel_degree + num_sp_groups = world_size // context_parallel_degree - # Since processes in the same SP group have the same prompts, we need to ensure - # we only take one copy of each prompt from each SP group + # Since processes in the same CP group have the same prompts, we need to ensure + # we only take one copy of each prompt from each CP group ordered_set_of_prompts = [] for sp_group_id in range(num_sp_groups): - # Get the first process from each SP group (typically the group leader) - group_leader_rank = sp_group_id * sequence_parallel_degree + # Get the first process from each CP group (typically the group leader) + group_leader_rank = sp_group_id * context_parallel_degree - # Extract prompts from this SP group, accounting for num_generations duplicates - # We only need prompts from one rank in each SP group + # Extract prompts from this CP group, accounting for num_generations duplicates + # We only need prompts from one rank in each CP group group_prompts = all_prompts_text[ group_leader_rank * len(prompts_text) : (group_leader_rank + 1) @@ -330,7 +330,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # num_generations outputs for each one. 
This is faster than generating outputs for each duplicate # prompt individually. ordered_set_of_prompts = all_prompts_text[ - :: self.num_generations * self.args.sequence_parallel_degree + :: self.num_generations * self.args.context_parallel_degree ] with profiling_context(self, "vLLM.generate"): @@ -347,28 +347,28 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) else: completion_ids = [None] * ( - len(all_prompts_text) // self.args.sequence_parallel_degree + len(all_prompts_text) // self.args.context_parallel_degree ) # Broadcast the completions from the main process to all processes completion_ids = broadcast_object_list(completion_ids, from_process=0) - # Determine the appropriate slice based on sequence parallelism - if self.args.sequence_parallel_degree > 1: - # Calculate SP group ID (which group of ranks this rank belongs to) + # Determine the appropriate slice based on context parallelism + if self.args.context_parallel_degree > 1: + # Calculate CP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size - # Calculate the start index for this SP group + # Calculate the start index for this CP group sp_group_start = sp_group_id * len(prompts) * self.local_world_size - # All ranks in the same SP group get the same data slice + # All ranks in the same CP group get the same data slice process_slice = slice( sp_group_start, sp_group_start + len(prompts), ) completion_ids = completion_ids[process_slice] else: - # Original behavior for non-sequence parallel case + # Original behavior for non-context parallel case process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), @@ -578,20 +578,20 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): advantages = advantages / (std_grouped_rewards + 1e-4) # Slice to keep only the local part of the data - if self.args.sequence_parallel_degree > 1: - # Calculate SP group ID (which group of ranks this rank belongs to) + if self.args.context_parallel_degree > 1: + # Calculate CP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size - # Calculate the start index for this SP group + # Calculate the start index for this CP group sp_group_start = sp_group_id * len(prompts) * self.local_world_size - # All ranks in the same SP group get the same data slice + # All ranks in the same CP group get the same data slice process_slice = slice( sp_group_start, sp_group_start + len(prompts), ) else: - # Original behavior for non-sequence parallel case + # Original behavior for non-context parallel case process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index ade858b46..4c4a8cdbc 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -2,10 +2,10 @@ Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in -their sequence parallel version of Flash Attention 2. +their context parallel version of Flash Attention 2. We also provide some patches for accelerate functions to prepare the dataloader for -sequence parallelism training. +context parallelism training. 
""" import inspect @@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): def register_ring_attn( - sequence_parallel_degree: int, + context_parallel_degree: int, heads_k_stride: int | None, ring_attn_func: RingAttnFunc | None, ): """Create ring attention group and substitute flash attn with ring flash attn. Args: - sequence_parallel_degree: Sequence parallelism factor. - heads_k_stride: Sequence parallelism K head stride size. Passed through to + context_parallel_degree: Context parallelism factor. + heads_k_stride: Context parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample packing is enabled, it must be a `varlen` function; otherwise, it must be a @@ -81,17 +81,17 @@ def register_ring_attn( world_size = dist.get_world_size() LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" + "Enabling ring attention context parallelism: " + f"each sequence will be processed across {context_parallel_degree} GPUs" ) - # Assign ranks to sequence parallel groups + # Assign ranks to context parallel groups group_assignments = {} - for i in range(world_size // sequence_parallel_degree): + for i in range(world_size // context_parallel_degree): ring_attn_ranks = list( range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, + i * context_parallel_degree, + (i + 1) * context_parallel_degree, ) ) group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") @@ -103,7 +103,7 @@ def register_ring_attn( if rank in ring_attn_ranks: set_ring_attn_group(group) - LOG.info(f"Sequence parallel group assignments: {group_assignments}") + LOG.info(f"Context parallel group assignments: {group_assignments}") if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: from ring_flash_attn import substitute_hf_flash_attn @@ -138,7 +138,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None): def patch_prepare_data_loader(): - """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. + """Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree. Raies: RuntimeError: If source code to patch does not exist. @@ -164,15 +164,15 @@ def patch_prepare_data_loader(): patched_function = namespace["prepare_data_loader"] accelerate.data_loader.prepare_data_loader = patched_function - LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") + LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support") -def patch_prepare_device_mesh(sequence_parallel_degree: int): +def patch_prepare_device_mesh(context_parallel_degree: int): """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh - that includes sequence parallelism with the specified degree. + that includes context parallelism with the specified degree. Args: - sequence_parallel_degree (int): The degree of sequence parallelism to use. + context_parallel_degree (int): The degree of context parallelism to use. 
""" def _prepare_device_mesh(self): @@ -187,11 +187,11 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int): ): return self.state.ds_device_mesh - # Create device mesh with sequence parallelism + # Create device mesh with context parallelism world_size = dist.get_world_size() mesh_shape = ( - world_size // sequence_parallel_degree, - sequence_parallel_degree, + world_size // context_parallel_degree, + context_parallel_degree, ) device_ids = list(range(world_size)) @@ -209,5 +209,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int): LOG.info( "Successfully patched Accelerator._prepare_device_mesh " - f"with sequence_parallel_degree={sequence_parallel_degree}" + f"with context_parallel_degree={context_parallel_degree}" ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 7003c73ae..d60f432aa 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -33,7 +33,7 @@ from axolotl.loaders import ( load_processor, load_tokenizer, ) -from axolotl.utils.ctx_managers.sequence_parallel import ContextParallelContextManager +from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except @@ -203,7 +203,7 @@ def execute_training( ) ) - if cfg.sequence_parallel_degree > 1: + if cfg.context_parallel_degree > 1: # Models to enter context parallel manager for models = [trainer.model] if hasattr(trainer, "ref_model") and trainer.ref_model: @@ -216,7 +216,7 @@ def execute_training( ContextParallelContextManager( models=models, backend=backend, - context_parallel_degree=cfg.sequence_parallel_degree, + context_parallel_degree=cfg.context_parallel_degree, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, diff --git a/src/axolotl/utils/ctx_managers/__init__.py b/src/axolotl/utils/ctx_managers/__init__.py index b92bfdf94..15eec49e4 100644 --- a/src/axolotl/utils/ctx_managers/__init__.py +++ b/src/axolotl/utils/ctx_managers/__init__.py @@ -1,6 +1,6 @@ -"""Init for context manager submodule""" +"""Init for context manager submodule.""" -# pylint: disable=unused-import -# flake8: noqa -from .sequence_parallel import ContextParallelContextManager +from .context_parallel.manager import ContextParallelContextManager + +__all__ = ["ContextParallelContextManager"] diff --git a/src/axolotl/utils/ctx_managers/context_parallel.py b/src/axolotl/utils/ctx_managers/context_parallel.py index 01c724f8d..e69de29bb 100644 --- a/src/axolotl/utils/ctx_managers/context_parallel.py +++ b/src/axolotl/utils/ctx_managers/context_parallel.py @@ -1,269 +0,0 @@ -"""Module for Axolotl trainer sequence parallelism manager and utilities""" - -import functools -import inspect -from typing import Literal - -import torch -import torch.distributed as dist -from torch.distributed.tensor.experimental import context_parallel -from torch.utils.hooks import RemovableHandle -from transformers import PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.utils import ModelOutput - -from axolotl.monkeypatch.ring_attn import ( - get_ring_attn_group, - patch_prepare_data_loader, - patch_prepare_device_mesh, - register_ring_attn, -) -from axolotl.utils.schemas.enums import RingAttnFunc -from axolotl.utils.ctx_managers.utils import get_context_parallel_manager - - -class ContextParallelContextManager: - """Context manager 
for context parallelism operations. - - This class provides a context that will automatically apply context parallelism - during model forward passes using a pre-forward hook, and gather outputs from - across the context parallelism group using a post-forward hook. - - Args: - models: List of models to apply context parallelism to pre- and post- forward - hooks. - backend: Which attention backend to use. - context_parallel_degree: Number of processes to split sequences over. - gradient_accumulation_steps: Number of steps to accumulate gradients over. - ring_attn_func: Which ring attention function to use. Currently unused. - heads_k_stride: Context parallelism K head stride size. Passed through to - `varlen_llama3` `ring_flash_attn` implementation. - """ - - def __init__( - self, - models: list[PreTrainedModel], - backend: Literal["sdp_attention", "flash_attention"], - context_parallel_degree: int, - gradient_accumulation_steps: int, - ring_attn_func: RingAttnFunc, - heads_k_stride: int | None, - ): - self.models = models - self.backend = backend - self.context_parallel_degree = context_parallel_degree - self.gradient_accumulation_steps = gradient_accumulation_steps - self.ring_attn_func = ring_attn_func - self.heads_k_stride = heads_k_stride - self._register_ring_attn() - - # Set distributed info for local rank - self.process_group = get_ring_attn_group() - self.local_rank = dist.get_rank(self.process_group) - self.local_world_size = dist.get_world_size(self.process_group) - - # Will store hook handles for removal - self.hook_handles: list[RemovableHandle] = [] - - # Store original sequence length and padding information - self.original_seq_len = 0 - self.pad_len = 0 - - # Create a partially applied version of the apply_sequence_parallelism function - self.apply_context_parallelism = functools.partial( - apply_context_parallelism, - local_rank=self.local_rank, - local_world_size=self.local_world_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - ring_attn_func=self.ring_attn_func, - ) - - # SPDA CP initialization - world_size = dist.get_world_size() - mesh_shape = ( - world_size // self.context_parallel_degree, - self.context_parallel_degree, - ) - world_mesh = dist.DeviceMesh( - "cuda", - torch.tensor(list(range(world_size))).reshape(mesh_shape), - mesh_dim_names=("dp", "cp"), - ) - self.context_parallel_managers = [] - for model in models: - ctx_manager = get_context_parallel_manager( - enabled=self.context_parallel_degree > 1, - world_mesh=world_mesh, - model=model, - ) - self.context_parallel_managers.append(ctx_manager) - - def __enter__(self): - self._register_model_hooks() - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Remove all hooks - for handle in self.hook_handles: - handle.remove() - self.hook_handles = [] - - # TODO(djsaunde): Un-patch attention and accelerate functions (low priority) - - def _register_ring_attn(self): - if self.backend == "flash_attention": - # Initialize ring attn for context parallelism - register_ring_attn( - sequence_parallel_degree=self.context_parallel_degree, - heads_k_stride=self.heads_k_stride, - ring_attn_func=self.ring_attn_func, - ) - else: - stack.enter_context(context_parallel(mesh=mesh)) - - # Patches for accelerate functionality - patch_prepare_data_loader() - patch_prepare_device_mesh( - sequence_parallel_degree=self.context_parallel_degree - ) - - def _register_model_hooks(self): - # Forward pre-hook to apply sequence parallelism - def cp_flash_pre_hook(_, args, kwargs): - # Get parameter names 
from the model's forward function - forward_params = list( - inspect.signature(self.models[0].forward).parameters.keys() - ) - - updated_kwargs = kwargs.copy() - for i, arg in enumerate(args): - if i < len(forward_params): - updated_kwargs[forward_params[i]] = arg - - # Any excess positional arguments are kept as-is - remaining_args = args[len(forward_params) :] - - # Apply sequence parallelism to updated kwargs - updated_kwargs, self.original_seq_len, self.pad_len = ( - self.apply_context_parallelism(updated_kwargs) - ) - - return remaining_args, updated_kwargs - - # Forward post-hook to gather outputs - def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput: - # Gather the sharded outputs - output = self._gather_outputs(output) - - # Remove padding if it was added - if self.pad_len > 0: - for key, value in output.items(): - if isinstance(value, torch.Tensor) and value.dim() > 1: - if value.size(1) == self.original_seq_len + self.pad_len: - # Slice to remove padding - output[key] = value[:, : self.original_seq_len].contiguous() - - return output - - def cp_sdpa_pre_hook(_, args, kwargs): - with self.context_parallel_managers[?](list(inputs.values())): - - - # Register both hooks - for model in self.models: - self.hook_handles.append( - model.register_forward_pre_hook( - cp_flash_pre_hook, with_kwargs=True - ) - ) - self.hook_handles.append( - model.register_forward_hook(cp_flash_post_hook) - ) - - def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: - """Gather sharded outputs from all ranks and reconstruct the full tensor.""" - for key, value in output.items(): - if isinstance(value, torch.Tensor) and value.dim() > 1: - output[key] = AllGatherWithGrad.apply(value, self.process_group) - - return output - - -class AllGatherWithGrad(torch.autograd.Function): - """Custom autograd function for all-gather to preserve gradients.""" - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - input_tensor: torch.Tensor, - group: dist.ProcessGroup, - ) -> torch.Tensor: - """ - Forward pass of all-gather of data with sequence dimension. - - Args: - ctx: `torch.autograd` function context. - input_tensor: Tensor from model output with sequence dimension. - group: `torch.distributed` process group. - - Returns: - Tensor from gathering the `input_tensor` from across the process group and - concatenating along the sequence dimension. - """ - ctx.group = group - ctx.rank = dist.get_rank(group) - world_size = dist.get_world_size(group) - - # Gather shape metadata - local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device) - all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)] - dist.all_gather(all_shapes, local_shape, group=group) - - # Store sequence lengths for backward pass - seq_lens = [int(shape[1].item()) for shape in all_shapes] - ctx.seq_lens = seq_lens - - # Perform all_gather operation - gathered = [ - torch.zeros( - tuple(shape.tolist()), - dtype=input_tensor.dtype, - device=input_tensor.device, - ) - for shape in all_shapes - ] - dist.all_gather(gathered, input_tensor, group=group) - - # Concatenate tensors along sequence dimension - result = torch.cat(gathered, dim=1) - - return result - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor - ) -> tuple[torch.Tensor, None]: - """ - Backward pass for all-gather operation. - - Extracts the gradient slice corresponding to this rank's original input - from the full gradient tensor. 
-
-        Args:
-            ctx: `torch.autograd` function context.
-            grad_output: Gradient from subsequent layers with respect to the
-                concatenated output tensor.
-
-        Returns:
-            Tuple containing the gradient slice for this rank's input tensor and `None`
-            for the process group parameter which doesn't require gradients.
-        """
-        rank = ctx.rank
-        seq_lens = ctx.seq_lens
-
-        # Extract gradient for this rank's chunk
-        offset = sum(seq_lens[:rank])
-        grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()
-
-        return grad_slice, None
diff --git a/src/axolotl/utils/ctx_managers/utils.py b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
similarity index 58%
rename from src/axolotl/utils/ctx_managers/utils.py
rename to src/axolotl/utils/ctx_managers/context_parallel/distributed.py
index a2a9ba725..c2aa603c6 100644
--- a/src/axolotl/utils/ctx_managers/utils.py
+++ b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
@@ -1,3 +1,37 @@
+# BSD 3-Clause License
+
+# Copyright 2024 Meta
+
+# Redistribution and use in source and binary forms, with or without modification,
+# are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this list
+# of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice, this
+# list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its contributors may
+# be used to endorse or promote products derived from this software without specific
+# prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+# DAMAGE.
+
+"""
+Distributed utils for SDPA context parallel implementation. Slightly modified from
+https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
+"""
+
 import contextlib
 from typing import Callable, Generator, Optional, Union
 
@@ -42,7 +76,6 @@ def _get_sdpa_context() -> (
 
 def get_context_parallel_manager(
     *,
-    enabled: bool = False,
     world_mesh: torch.distributed.DeviceMesh,
     model: PreTrainedModel,
 ) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
@@ -64,16 +97,16 @@
-        ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
+        ValueError: if world_mesh does not contain a "cp" dimension
     """
-    if enabled and "cp" not in world_mesh.mesh_dim_names:
+    if "cp" not in world_mesh.mesh_dim_names:
         raise ValueError(
             "Context parallel is enabled but no context parallel device mesh is provided."
         )
 
     # TODO: context parallel for multimodal models requires extra work
-    if enabled and not isinstance(model, TransformerDecoder):
+    if not isinstance(model, TransformerDecoder):
         raise ValueError("Context parallel is only supported for text models")
 
     # TODO: this is a hacky proxy for whether we use flex for chunked attention
     # remove this once flex is supported
-    if enabled and any([layer.mask_mod is not None for layer in model.layers]):
+    if any([layer.mask_mod is not None for layer in model.layers]):
         raise ValueError("Context parallel with flex attention is not yet supported")
 
     model_buffers = list(model.buffers())
@@ -81,18 +114,17 @@
     def context(model_inputs: list[torch.Tensor]):
         # Create context parallel context if enabled
         cp_context = None
-        if enabled and any([isinstance(input, BlockMask) for input in model_inputs]):
+        if any([isinstance(input, BlockMask) for input in model_inputs]):
             raise ValueError(
                 "Context parallel with flex attention is not yet supported"
             )
-        if enabled:
-            set_rotate_method("allgather")
-            cp_context = context_parallel(
-                world_mesh["cp"],
-                buffers=model_inputs + model_buffers,
-                buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
-                no_restore_buffers=set(model_inputs),
-            )
+        set_rotate_method("allgather")
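+        # Shard model inputs along the sequence dim (1) and persistent model
+        # buffers (e.g., rotary embedding caches) along dim 0 across the CP mesh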
+        cp_context = context_parallel(
+            world_mesh["cp"],
+            buffers=model_inputs + model_buffers,
+            buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+            no_restore_buffers=set(model_inputs),
+        )
 
         # Create and enter the train context with the optional cp_context
         sdpa_context = _get_sdpa_context()
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/manager.py b/src/axolotl/utils/ctx_managers/context_parallel/manager.py
new file mode 100644
index 000000000..bdc98e1cd
--- /dev/null
+++ b/src/axolotl/utils/ctx_managers/context_parallel/manager.py
@@ -0,0 +1,196 @@
+"""Module for Axolotl trainer context parallelism manager and utilities."""
+
+import functools
+import inspect
+from contextlib import ExitStack
+from typing import Callable, Literal
+
+import torch
+import torch.distributed as dist
+from torch.utils.hooks import RemovableHandle
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import ModelOutput
+
+from axolotl.monkeypatch.ring_attn import (
+    get_ring_attn_group,
+    patch_prepare_data_loader,
+    patch_prepare_device_mesh,
+    register_ring_attn,
+)
+from axolotl.utils.ctx_managers.context_parallel.distributed import (
+    get_context_parallel_manager,
+)
+from axolotl.utils.ctx_managers.context_parallel.utils import (
+    AllGatherWithGrad,
+    apply_context_parallelism,
+)
+from axolotl.utils.schemas.enums import RingAttnFunc
+
+
+class ContextParallelContextManager:
+    """Context manager for context parallelism operations.
+
+    This class provides a context that will automatically apply context parallelism
+    during model forward passes using a pre-forward hook, and gather outputs from
+    across the context parallelism group using a post-forward hook.
+
+    Args:
+        models: List of models to apply context parallelism to pre- and post- forward
+            hooks.
+        backend: Which attention backend to use.
+        context_parallel_degree: Number of processes to split sequences over.
+        gradient_accumulation_steps: Number of steps to accumulate gradients over.
+        ring_attn_func: Which ring attention function to use (flash attention
+            backend only).
+        heads_k_stride: Context parallelism K head stride size. Passed through to
+            `varlen_llama3` `ring_flash_attn` implementation.
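+
+    Example (illustrative; assumes a prepared trainer and at least 2 GPUs)::
+
+        with ContextParallelContextManager(
+            models=[trainer.model],
+            backend="flash_attention",
+            context_parallel_degree=2,
+            gradient_accumulation_steps=1,
+            ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
+            heads_k_stride=1,
+        ):
+            trainer.train()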
+    """
+
+    def __init__(
+        self,
+        models: list[PreTrainedModel],
+        backend: Literal["sdp_attention", "flash_attention"],
+        context_parallel_degree: int,
+        gradient_accumulation_steps: int,
+        ring_attn_func: RingAttnFunc,
+        heads_k_stride: int | None,
+    ):
+        self.models = models
+        self.backend = backend
+        self.context_parallel_degree = context_parallel_degree
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.ring_attn_func = ring_attn_func
+        self.heads_k_stride = heads_k_stride
+        self._register_ring_attn()
+
+        # Set distributed info for local rank
+        self.process_group = get_ring_attn_group()
+        self.local_rank = dist.get_rank(self.process_group)
+        self.local_world_size = dist.get_world_size(self.process_group)
+
+        # Will store hook handles for removal
+        self.hook_handles: list[RemovableHandle] = []
+
+        # Will store exit stacks for SDPA context parallel contexts entered in
+        # pre-forward hooks and closed in post-forward hooks
+        self._sdpa_exit_stacks: dict[int, ExitStack] = {}
+
+        # Store original sequence length and padding information
+        self.original_seq_len = 0
+        self.pad_len = 0
+
+        # Create a partially applied version of the apply_context_parallelism function
+        self.apply_context_parallelism = functools.partial(
+            apply_context_parallelism,
+            local_rank=self.local_rank,
+            local_world_size=self.local_world_size,
+            gradient_accumulation_steps=self.gradient_accumulation_steps,
+            ring_attn_func=self.ring_attn_func,
+        )
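+        # apply_context_parallelism leaves each rank holding a
+        # seq_len / local_world_size slice of every input sequence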
+
+        # SDPA CP initialization; the torchtune-style context parallel managers
+        # are only needed for the SDPA backend
+        self.context_parallel_managers = []
+        if self.backend != "flash_attention":
+            world_size = dist.get_world_size()
+            mesh_shape = (
+                world_size // self.context_parallel_degree,
+                self.context_parallel_degree,
+            )
+            world_mesh = dist.DeviceMesh(
+                "cuda",
+                torch.tensor(list(range(world_size))).reshape(mesh_shape),
+                mesh_dim_names=("dp", "cp"),
+            )
+            for model in models:
+                self.context_parallel_managers.append(
+                    get_context_parallel_manager(
+                        world_mesh=world_mesh,
+                        model=model,
+                    )
+                )
+
+    def __enter__(self):
+        self._register_model_hooks()
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
+
+        # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
+
+    def _register_ring_attn(self):
+        if self.backend == "flash_attention":
+            # Initialize ring attn for context parallelism
+            register_ring_attn(
+                context_parallel_degree=self.context_parallel_degree,
+                heads_k_stride=self.heads_k_stride,
+                ring_attn_func=self.ring_attn_func,
+            )
+
+        # Patches for accelerate functionality
+        patch_prepare_data_loader()
+        patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
+
+    def _register_model_hooks(self):
+        # Forward pre-hook to apply context parallelism
+        def cp_flash_pre_hook(_, args, kwargs):
+            # Get parameter names from the model's forward function
+            forward_params = list(
+                inspect.signature(self.models[0].forward).parameters.keys()
+            )
+
+            updated_kwargs = kwargs.copy()
+            for i, arg in enumerate(args):
+                if i < len(forward_params):
+                    updated_kwargs[forward_params[i]] = arg
+
+            # Any excess positional arguments are kept as-is
+            remaining_args = args[len(forward_params) :]
+
+            # Apply context parallelism to updated kwargs
+            updated_kwargs, self.original_seq_len, self.pad_len = (
+                self.apply_context_parallelism(updated_kwargs)
+            )
+
+            return remaining_args, updated_kwargs
+
+        # Forward post-hook to gather outputs
+        def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
+            # Gather the sharded outputs
+            output = self._gather_outputs(output)
+
+            # Remove padding if it was added
+            if self.pad_len > 0:
+                for key, value in output.items():
+                    if isinstance(value, torch.Tensor) and value.dim() > 1:
+                        if value.size(1) == self.original_seq_len + self.pad_len:
+                            # Slice to remove padding
+                            output[key] = value[:, : self.original_seq_len].contiguous()
+
+            return output
+
+        # Register both hooks
+        for i, model in enumerate(self.models):
+            if self.backend == "flash_attention":
+                self.hook_handles.append(
+                    model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
+                )
+                self.hook_handles.append(
+                    model.register_forward_hook(cp_flash_post_hook)
+                )
+            else:
+
+                def make_sdpa_pre_hook(model_idx: int) -> Callable:
+                    def cp_sdpa_pre_hook(_, args, kwargs):
+                        # Enter the CP manager with the sequence-dim input tensors
+                        # and keep it open for the duration of the forward pass
+                        buffers = [
+                            value
+                            for value in kwargs.values()
+                            if isinstance(value, torch.Tensor) and value.dim() > 1
+                        ]
+                        stack = ExitStack()
+                        stack.enter_context(
+                            self.context_parallel_managers[model_idx](buffers)
+                        )
+                        self._sdpa_exit_stacks[model_idx] = stack
+                        return args, kwargs
+
+                    return cp_sdpa_pre_hook
+
+                def make_sdpa_post_hook(model_idx: int) -> Callable:
+                    def cp_sdpa_post_hook(_, __, output: ModelOutput) -> ModelOutput:
+                        # Close the CP manager entered in the pre-hook
+                        self._sdpa_exit_stacks.pop(model_idx).close()
+                        return output
+
+                    return cp_sdpa_post_hook
+
+                self.hook_handles.append(
+                    model.register_forward_pre_hook(
+                        make_sdpa_pre_hook(i), with_kwargs=True
+                    )
+                )
+                self.hook_handles.append(
+                    model.register_forward_hook(make_sdpa_post_hook(i))
+                )
+
+    def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
+        """Gather sharded outputs from all ranks and reconstruct the full tensor."""
+        for key, value in output.items():
+            if isinstance(value, torch.Tensor) and value.dim() > 1:
+                output[key] = AllGatherWithGrad.apply(value, self.process_group)
+
+        return output
diff --git a/src/axolotl/utils/ctx_managers/context_parallel/utils.py b/src/axolotl/utils/ctx_managers/context_parallel/utils.py
index 3652250f0..03e559978 100644
--- a/src/axolotl/utils/ctx_managers/context_parallel/utils.py
+++ b/src/axolotl/utils/ctx_managers/context_parallel/utils.py
@@ -1,6 +1,7 @@
 """Utils for context parallel context manager."""
 
 import torch
+import torch.distributed as dist
 
 from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -14,7 +15,7 @@ def apply_context_parallelism(
     local_world_size: int,
     gradient_accumulation_steps: int,
     ring_attn_func: RingAttnFunc,  # pylint: disable=unused-argument
-) -> tuple[dict[str, torch.Tensor], int, iwnt]:
+) -> tuple[dict[str, torch.Tensor], int, int]:
     """
     Apply context parallelism slicing to a batch.
 
@@ -142,4 +143,83 @@
         batch["labels"] != -100
     ).sum() * gradient_accumulation_steps
 
-    return batch, original_seq_len, pad_len
\ No newline at end of file
+    return batch, original_seq_len, pad_len
+
+
+class AllGatherWithGrad(torch.autograd.Function):
+    """Custom autograd function for all-gather to preserve gradients."""
+
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        input_tensor: torch.Tensor,
+        group: dist.ProcessGroup,
+    ) -> torch.Tensor:
+        """
+        Forward pass of all-gather of data with sequence dimension.
+
+        Args:
+            ctx: `torch.autograd` function context.
+            input_tensor: Tensor from model output with sequence dimension.
+            group: `torch.distributed` process group.
+
+        Returns:
+            Tensor from gathering the `input_tensor` from across the process group and
+            concatenating along the sequence dimension. 
+        """
+        ctx.group = group
+        ctx.rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+
+        # Gather shape metadata
+        local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device)
+        all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)]
+        dist.all_gather(all_shapes, local_shape, group=group)
+
+        # Store sequence lengths for backward pass
+        seq_lens = [int(shape[1].item()) for shape in all_shapes]
+        ctx.seq_lens = seq_lens
+
+        # Perform all_gather operation
+        gathered = [
+            torch.zeros(
+                tuple(shape.tolist()),
+                dtype=input_tensor.dtype,
+                device=input_tensor.device,
+            )
+            for shape in all_shapes
+        ]
+        dist.all_gather(gathered, input_tensor, group=group)
+
+        # Concatenate tensors along sequence dimension
+        result = torch.cat(gathered, dim=1)
+
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, None]:
+        """
+        Backward pass for the all-gather operation.
+
+        Extracts the gradient slice corresponding to this rank's original input
+        from the full gradient tensor.
+
+        Args:
+            ctx: `torch.autograd` function context.
+            grad_output: Gradient from subsequent layers with respect to the
+                concatenated output tensor.
+
+        Returns:
+            Tuple containing the gradient slice for this rank's input tensor and `None`
+            for the process group parameter, which doesn't require gradients.
+        """
+        rank = ctx.rank
+        seq_lens = ctx.seq_lens
+
+        # Extract gradient for this rank's chunk
+        offset = sum(seq_lens[:rank])
+        grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()
+
+        return grad_slice, None
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index ec1a4e23d..22a597544 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -262,7 +262,7 @@ class AxolotlInputConfig(
     val_set_size: float | None = Field(default=0.0)
 
-    sequence_parallel_degree: int | None = None
+    context_parallel_degree: int | None = None
     heads_k_stride: int | None = None
     ring_attn_func: RingAttnFunc | None = None
 
@@ -1179,39 +1179,39 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
-    def check_grpo_liger_sequence_parallel(cls, data):
+    def check_grpo_liger_context_parallel(cls, data):
         if (
             data.get("rl") == "grpo"
             and data.get("trl", {})
             and data.get("trl").get("use_liger_loss")
-            and data.get("sequence_parallel_degree", 1) > 1
+            and data.get("context_parallel_degree", 1) > 1
         ):
-            raise ValueError("GRPO + SP + Liger not currently supported")
+            raise ValueError("GRPO + CP + Liger not currently supported")
 
         return data
 
     @model_validator(mode="after")
-    def check_sequence_parallel_degree(self):
-        if not self.sequence_parallel_degree:
-            self.sequence_parallel_degree = 1
-        elif self.sequence_parallel_degree > 1:
+    def check_context_parallel_degree(self):
+        if not self.context_parallel_degree:
+            self.context_parallel_degree = 1
+        elif self.context_parallel_degree > 1:
             import torch
 
             world_size = torch.cuda.device_count()
-            if not world_size >= self.sequence_parallel_degree:
+            if not world_size >= self.context_parallel_degree:
                 raise ValueError(
                     f"World size ({world_size}) must be greater "
-                    f"than or equal to SP degree ({self.sequence_parallel_degree})"
+                    f"than or equal to CP degree ({self.context_parallel_degree})"
                 )
-            if not world_size % self.sequence_parallel_degree == 0:
+            if not world_size % self.context_parallel_degree == 0:
                 raise ValueError(
-                    f"SP degree ({self.sequence_parallel_degree}) "
+                    f"CP degree ({self.context_parallel_degree}) "
                     f"must evenly divide world size ({world_size})"
                 )
 
             if not (self.flash_attention or self.sdp_attention):
                 raise ValueError(
                     "flash_attention: true or sdp_attention: true "
-                    "must be set with sequence_parallel_degree > 1"
+                    "must be set with context_parallel_degree > 1"
                 )
 
             if self.sample_packing and self.micro_batch_size > 1:
@@ -1225,17 +1225,17 @@ class AxolotlInputConfig(
                 import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
             except ImportError as exception:
                 raise ImportError(
-                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
+                    "context_parallel_degree > 1 but ring_flash_attn is not installed. "
                     "Please install it with `pip install axolotl[ring-flash-attn] "
                     "or `pip install ring-flash-attn>=0.1.4`."
                 ) from exception
 
-        # TODO: monkeypatch / callback to average losses correctly across SP ranks
-        # / fix gradient scaling across SP ranks. Losses, grads should be scaled
+        # TODO: monkeypatch / callback to average losses correctly across CP ranks
+        # / fix gradient scaling across CP ranks. Losses, grads should be scaled
         # according to the proportion of non-padding tokens per rank.
         LOG.warning(
-            "Sequence parallelism (SP) is enabled with "
-            f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+            "Context parallelism (CP) is enabled with "
+            f"context_parallel_degree={self.context_parallel_degree}. "
             "Please note that logged losses may differ slightly to the non-SP "
             "losses due to transformers Trainer implementation details. "
             "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1246,7 +1246,7 @@ class AxolotlInputConfig(
 
     @model_validator(mode="after")
     def validate_ring_attn_func(self):
-        if getattr(self, "sequence_parallel_degree", 1) == 1:
+        if getattr(self, "context_parallel_degree", 1) == 1:
             return self
 
         if self.ring_attn_func is not None:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 67f590a37..61a594e42 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 - 1
             )
             * cfg.num_epochs
-            * cfg.sequence_parallel_degree
+            * cfg.context_parallel_degree
         )
         LOG.debug(
             f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -479,7 +479,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             # on the agreed on value for sample_packing_eff_est
             total_num_steps = int(
                 math.floor(
-                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
+                    data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
                 )
            )
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 math.ceil(
                     len(train_dataset)
                     * cfg.num_epochs
-                    * cfg.sequence_parallel_degree
+                    * cfg.context_parallel_degree
                     / cfg.batch_size
                 )
             )
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index cde7b74ce..d45e45507 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -64,7 +64,7 @@ def fixture_base_cfg():
         "dataloader_num_workers": 1,
         "dataloader_pin_memory": True,
         "dataloader_prefetch_factor": 2,
-        "sequence_parallel_degree": 1,
+        "context_parallel_degree": 1,
         # Dtype
         "fp16": False,
         "bf16": False,
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index e90def2b7..42ea3d0c2 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -1,4 +1,4 @@
-"""E2E tests for sequence parallelism"""
+"""E2E tests for context parallelism""" from pathlib import Path @@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault from ...utils import check_tensorboard -class TestSequenceParallelism: - """Test case for training with sequence parallelism enabled""" +class TestContextParallelism: + """Test case for training with context parallelism enabled""" - def _run_sequence_parallel_test( + def _run_context_parallel_test( self, temp_dir, sample_packing=True, @@ -24,7 +24,7 @@ class TestSequenceParallelism: ring_attn_func=None, threshold=2.0, ): - """Helper method to run sequence parallel tests with different configurations""" + """Helper method to run context parallel tests with different configurations""" cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -66,7 +66,7 @@ class TestSequenceParallelism: "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "ring_attn_func": ring_attn_func, } ) @@ -109,7 +109,7 @@ class TestSequenceParallelism: "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", ], ) - def test_sequence_parallel_training( + def test_context_parallel_training( self, temp_dir, sample_packing, @@ -118,8 +118,8 @@ class TestSequenceParallelism: ring_attn_func, threshold, ): - """Test sequence parallel training with different configurations""" - self._run_sequence_parallel_test( + """Test context parallel training with different configurations""" + self._run_context_parallel_test( temp_dir, sample_packing=sample_packing, micro_batch_size=micro_batch_size, diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 8ea2e3ce4..f09fba920 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "sequence_len": 1024, "special_tokens": { diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 28146aca7..de840591c 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -1,4 +1,4 @@ -"""Tests for sequence parallelism functionality.""" +"""Tests for context parallelism functionality.""" # pylint: disable=redefined-outer-name,unused-argument @@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import ( register_ring_attn, set_ring_attn_group, ) -from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism +from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.trl import TRLConfig @@ -54,8 +54,8 @@ def fixture_cfg(): @pytest.fixture -def sequence_parallel_batch(): - """Create a test batch for sequence parallelism tests.""" +def context_parallel_batch(): + """Create a test batch for context parallelism tests.""" batch_size = 1 seq_len = 8 @@ -110,7 +110,7 @@ class TestRingAttention: # Call register_ring_attn with size 4 register_ring_attn( - sequence_parallel_degree=4, + context_parallel_degree=4, heads_k_stride=1, ring_attn_func=RingAttnFunc.VARLEN_LLAMA3, ) @@ -126,7 +126,7 @@ class TestRingAttention: class TestConfigValidation: - """Tests for validating sequence parallelism configurations.""" + """Tests for validating context parallelism configurations.""" 
@pytest.fixture(autouse=True) def setup_mocks(self, monkeypatch): @@ -155,24 +155,24 @@ class TestConfigValidation: [ # Valid configuration ( - {"sequence_parallel_degree": 2, "flash_attention": True}, - {"sequence_parallel_degree": 2, "flash_attention": True}, + {"context_parallel_degree": 2, "flash_attention": True}, + {"context_parallel_degree": 2, "flash_attention": True}, True, None, ), - # Default sequence_parallel_degree - ({}, {"sequence_parallel_degree": 1}, True, None), - # Invalid: sequence_parallel_degree > 1 without flash_attention + # Default context_parallel_degree + ({}, {"context_parallel_degree": 1}, True, None), + # Invalid: context_parallel_degree > 1 without flash_attention ( - {"sequence_parallel_degree": 2, "flash_attention": False}, + {"context_parallel_degree": 2, "flash_attention": False}, None, False, "flash_attention: true must be set", ), - # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 + # Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1 ( { - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "sample_packing": True, "micro_batch_size": 2, @@ -185,32 +185,32 @@ class TestConfigValidation: # Valid: Basic GRPO config ( { - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "micro_batch_size": 2, "trl": {"use_liger_loss": True}, }, { - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "micro_batch_size": 2, "trl": TRLConfig(use_liger_loss=True), }, True, - "GRPO + SP + Liger not currently supported", + "GRPO + CP + Liger not currently supported", ), # Invalid: GRPO config with Liger loss ( { "rl": "grpo", - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "micro_batch_size": 2, "trl": {"use_liger_loss": True}, }, None, False, - "GRPO + SP + Liger not currently supported", + "GRPO + CP + Liger not currently supported", ), ], ids=[ @@ -222,10 +222,10 @@ class TestConfigValidation: "grpo_with_liger_loss", ], ) - def test_sequence_parallel_config_validation( + def test_context_parallel_config_validation( self, base_cfg, config_updates, expected_values, should_pass, error_msg ): - """Test various sequence parallelism configuration scenarios.""" + """Test various context parallelism configuration scenarios.""" from axolotl.utils.schemas.config import AxolotlInputConfig # Apply updates to base config @@ -261,7 +261,7 @@ class TestConfigValidation: # Apply updates to base config cfg = base_cfg | { - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "sample_packing": sample_packing, } @@ -281,7 +281,7 @@ class TestConfigValidation: # Invalid configuration with invalid ring_attn_func cfg = base_cfg | { - "sequence_parallel_degree": 2, + "context_parallel_degree": 2, "flash_attention": True, "ring_attn_func": "INVALID_FUNC", } @@ -294,8 +294,8 @@ class TestConfigValidation: assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value) -class TestApplySequenceParallelism: - """Tests for the apply_sequence_parallelism function.""" +class TestApplyContextParallelism: + """Tests for the apply_context_parallelism function.""" @pytest.fixture(autouse=True) def mock_distributed(self, monkeypatch): @@ -324,12 +324,12 @@ class TestApplySequenceParallelism: ) @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch): + def 
test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch): """Test that function returns original batch when world size is 1.""" mock_get_ring_attn_group.return_value = 0 result, _, _ = apply_context_parallelism( - batch=sequence_parallel_batch, + batch=context_parallel_batch, local_rank=0, local_world_size=1, gradient_accumulation_steps=1, @@ -337,14 +337,14 @@ class TestApplySequenceParallelism: ) # Should return the original batch unchanged - assert result == sequence_parallel_batch + assert result == context_parallel_batch @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch): + def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch): """Test BATCH_RING sharding for rank 0 in a 2-process group.""" mock_get_ring_attn_group.return_value = 0 - batch = sequence_parallel_batch + batch = context_parallel_batch seq_len = batch["input_ids"].size(1) result, _, _ = apply_context_parallelism( @@ -366,11 +366,11 @@ class TestApplySequenceParallelism: ) @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch): + def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch): """Test BATCH_RING sharding for rank 1 in a 2-process group.""" mock_get_ring_attn_group.return_value = 0 - batch = sequence_parallel_batch + batch = context_parallel_batch seq_len = batch["input_ids"].size(1) original_input_ids = batch["input_ids"].clone() @@ -386,14 +386,14 @@ class TestApplySequenceParallelism: assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :]) # TODO(djsaunde): add back once implemented. - # def test_batch_zigzag(self, sequence_parallel_batch): + # def test_batch_zigzag(self, context_parallel_batch): # """Test BATCH_ZIGZAG sharding pattern.""" - # batch = sequence_parallel_batch + # batch = context_parallel_batch # original_input_ids = batch["input_ids"].clone() # seq_len = batch["input_ids"].size(1) # # Test rank 0 - # result_rank0 = apply_sequence_parallelism( + # result_rank0 = apply_context_parallelism( # batch={k: v.clone() for k, v in batch.items()}, # local_rank=0, # local_world_size=2, @@ -401,7 +401,7 @@ class TestApplySequenceParallelism: # ) # # Test rank 1 - # result_rank1 = apply_sequence_parallelism( + # result_rank1 = apply_context_parallelism( # batch={k: v.clone() for k, v in batch.items()}, # local_rank=1, # local_world_size=2, @@ -430,12 +430,12 @@ class TestApplySequenceParallelism: @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") def test_partial_application( - self, mock_get_ring_attn_group, sequence_parallel_batch + self, mock_get_ring_attn_group, context_parallel_batch ): """Test that we can create a partially applied version of the function.""" mock_get_ring_attn_group.return_value = 0 - batch = sequence_parallel_batch + batch = context_parallel_batch original_input_ids = batch["input_ids"].clone() # Create a partially applied function @@ -457,12 +457,10 @@ class TestApplySequenceParallelism: original_input_ids[:, : original_input_ids.shape[1] // 2], ) - def test_missing_position_ids(self, sequence_parallel_batch): + def test_missing_position_ids(self, context_parallel_batch): """Test handling of batch without position_ids.""" # Create a batch without position_ids - batch = { - k: v for k, v in sequence_parallel_batch.items() if k != "position_ids" - } + batch = {k: v for k, v in 
context_parallel_batch.items() if k != "position_ids"} original_input_ids = batch["input_ids"].clone() # This should run without error even though position_ids is missing
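
For reviewers who want to poke at the new `AllGatherWithGrad` function in isolation, the following is a minimal single-process sanity check, not part of this diff. It assumes the module path introduced above (`axolotl.utils.ctx_managers.context_parallel.utils`) and uses the CPU `gloo` backend; with `world_size=1` the gather is an identity, so gradients should round-trip unchanged.

```python
# Hypothetical sanity check; runs on CPU with a single process.
import os

import torch
import torch.distributed as dist

from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(1, 8, 4, requires_grad=True)
gathered = AllGatherWithGrad.apply(x, dist.group.WORLD)

# With a single rank, concatenating along dim=1 reproduces the input shape.
assert gathered.shape == x.shape

# backward() slices out this rank's chunk of the gradient (here, all of it).
gathered.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))

dist.destroy_process_group()
```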