diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index b98206135..d1933a145 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file: ```yaml # Set to a divisor (> 1) of the number of GPUs available -sequence_parallel_degree: 4 # Split sequences across 4 GPUs +context_parallel_size: 4 # Split sequences across 4 GPUs # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -30,7 +30,7 @@ heads_k_stride: 1 ring_attn_func: ``` -The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: +The `context_parallel_size` should be a divisor of the total number of GPUs. For example: - With 8 GPUs, valid values would be 2, 4, or 8 - With 4 GPUs, valid values would be 2 or 4 @@ -66,7 +66,7 @@ sequence_len: 8192 ... -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU +context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality. ## Effect on Batch Size -When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: +When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because: -- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) +- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence) - The number of batches processed per step decreases For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step -- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) +- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 diff --git a/requirements.txt b/requirements.txt index 6c167d3fa..a6659a972 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,9 @@ packaging==23.2 huggingface_hub>=0.33.0 peft==0.16.0 -transformers==4.53.2 +transformers @ git+https://github.com/winglian/transformers.git@ndp tokenizers>=0.21.1 -accelerate==1.9.0 +accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config datasets==4.0.0 deepspeed>=0.17.0 trl==0.19.1 diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index d639b3aee..3a26c1730 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -70,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: load_in_8bit=False, load_in_4bit=False, flash_attention=False, - sequence_parallel_degree=None, + context_parallel_size=None, deepspeed=None, fsdp=None, fsdp_config=None, diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index e60b0e958..8cc6eeebf 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.sequence_parallel_degree > 1 + sequence_parallel=self.cfg.context_parallel_size > 1 ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 2c6eb8c6f..0910f0b21 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -82,8 +82,8 @@ class GRPOStrategy: grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print - if cfg.sequence_parallel_degree > 1: - grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree + if cfg.context_parallel_size > 1: + grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 5c8b1a33b..2ea52998e 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """Axolotl GRPO Config for GRPO training""" - sequence_parallel_degree: int | None = None + context_parallel_size: int | None = None diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py index ebc6e19e2..df679a6d2 100644 --- a/src/axolotl/core/trainers/grpo/sampler.py +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): - Data is properly distributed across SP groups. In the table below, the values represent dataset indices. Each SP group has - `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 + `context_parallel_size = 2` GPUs working together on the same data. There are 2 SP groups (SP0 and SP1), with `world_size = 4` total GPUs. Sequence Parallel Groups @@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: Rank of current process. batch_size: Number of samples per batch. repeat_count: How many times to repeat the full sampling process. - sequence_parallel_degree: Number of ranks in a sequence parallel group. + context_parallel_size: Number of ranks in a sequence parallel group. shuffle: Whether to shuffle the dataset. seed: Random seed for shuffling. drop_last: Whether to drop the last incomplete batch. @@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: int, batch_size: int = 1, repeat_count: int = 1, - sequence_parallel_degree: int = 1, + context_parallel_size: int = 1, shuffle: bool = True, seed: int = 0, drop_last: bool = False, @@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler): self.rank = rank # Sequence parallelism parameters - self.sequence_parallel_degree = sequence_parallel_degree - self.num_sp_groups = world_size // sequence_parallel_degree - self.sp_group_id = rank // sequence_parallel_degree + self.context_parallel_size = context_parallel_size + self.num_sp_groups = world_size // context_parallel_size + self.sp_group_id = rank // context_parallel_size # Adjust dataset size for distributed sampling self.num_samples = len(self.dataset) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 70b3cf3b5..1a053497e 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -100,7 +100,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Get number of SP groups (number of processes divided by SP degree) num_processes = self.accelerator.num_processes - num_sp_groups = num_processes // self.args.sequence_parallel_degree + num_sp_groups = num_processes // self.args.context_parallel_size # Calculate batch size per SP group (not per process) sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups @@ -130,7 +130,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): if self.num_generations not in possible_values: raise ValueError( - f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " + f"With sequence parallelism (degree {self.args.context_parallel_size}), " f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"must be evenly divisible by the number of generations per prompt " f"({self.num_generations}). Given the current eval batch size, " @@ -167,9 +167,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): rank=self.rank, batch_size=effective_batch_size // self.num_generations - // self.args.sequence_parallel_degree, + // self.args.context_parallel_size, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, - sequence_parallel_degree=self.args.sequence_parallel_degree, + context_parallel_size=self.args.context_parallel_size, shuffle=True, seed=self.args.seed, drop_last=True, @@ -235,7 +235,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # slice each batch along the sequence dimension). - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: return dataloader # Otherwise prepare with accelerator @@ -308,18 +308,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate sequence parallel group information world_size = self.accelerator.num_processes - sequence_parallel_degree = self.args.sequence_parallel_degree - num_sp_groups = world_size // sequence_parallel_degree + context_parallel_size = self.args.context_parallel_size + num_sp_groups = world_size // context_parallel_size # Since processes in the same SP group have the same prompts, we need to ensure # we only take one copy of each prompt from each SP group ordered_set_of_prompts = [] for sp_group_id in range(num_sp_groups): # Get the first process from each SP group (typically the group leader) - group_leader_rank = sp_group_id * sequence_parallel_degree + group_leader_rank = sp_group_id * context_parallel_size # Extract prompts from this SP group, accounting for num_generations duplicates # We only need prompts from one rank in each SP group @@ -335,7 +335,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. ordered_set_of_prompts = all_prompts_text[ - :: self.num_generations * self.args.sequence_parallel_degree + :: self.num_generations * self.args.context_parallel_size ] with profiling_context(self, "vLLM.generate"): @@ -352,14 +352,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) else: completion_ids = [None] * ( - len(all_prompts_text) // self.args.sequence_parallel_degree + len(all_prompts_text) // self.args.context_parallel_size ) # Broadcast the completions from the main process to all processes completion_ids = broadcast_object_list(completion_ids, from_process=0) # Determine the appropriate slice based on sequence parallelism - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size @@ -583,7 +583,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): advantages = advantages / (std_grouped_rewards + 1e-4) # Slice to keep only the local part of the data - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 94ba83dd5..34f2f9068 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -16,8 +16,6 @@ Module for handling LIGER input arguments. """ -from typing import Optional - from pydantic import BaseModel, model_validator from axolotl.utils.logging import get_logger @@ -30,13 +28,13 @@ class LigerArgs(BaseModel): Input args for LIGER. """ - liger_rope: Optional[bool] = None - liger_rms_norm: Optional[bool] = None - liger_layer_norm: Optional[bool] = None - liger_swiglu: Optional[bool] = None - liger_glu_activation: Optional[bool] = None - liger_cross_entropy: Optional[bool] = None - liger_fused_linear_cross_entropy: Optional[bool] = None + liger_rope: bool | None = None + liger_rms_norm: bool | None = None + liger_layer_norm: bool | None = None + liger_swiglu: bool | None = None + liger_glu_activation: bool | None = None + liger_cross_entropy: bool | None = None + liger_fused_linear_cross_entropy: bool | None = None @model_validator(mode="before") @classmethod @@ -62,3 +60,12 @@ class LigerArgs(BaseModel): "You cannot have both `liger_glu_activation` and `tiled_mlp` set." ) return data + + @model_validator(mode="before") + @classmethod + def check_liger_rms_norm_tensor_parallel(cls, data): + if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1: + raise ValueError( + "`liger_rms_norm` is incompatible with tensor parallelism, " + "see https://github.com/linkedin/Liger-Kernel/issues/826" + ) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1ce98ef31..1041892f5 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -13,7 +13,8 @@ import peft import torch import transformers import transformers.modeling_utils -from accelerate import init_empty_weights +from accelerate import PartialState, init_empty_weights +from accelerate.utils.dataclasses import ParallelismConfig from peft import ( PeftConfig, PeftMixedModel, @@ -51,6 +52,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( get_device_count, get_device_type, + get_world_size, ) from axolotl.utils.logging import get_logger from axolotl.utils.model_shard_quant import load_sharded_model_quant @@ -182,6 +184,7 @@ class ModelLoader: def _apply_pre_model_load_setup(self): """Apply patches and setup configurations before model loading.""" + self._set_parallel_config() self._set_auto_model_loader() self._set_device_map_config() if self.cfg.revision_of_model: @@ -389,6 +392,32 @@ class ModelLoader: gc.collect() torch.cuda.empty_cache() + def _set_parallel_config(self): + """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator""" + dp_replicate_size = get_world_size() + pc_kwargs = {} + if self.cfg.dp_shard_size > 1: + pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size + dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size + if self.cfg.tensor_parallel_size > 1: + pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size + dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size + if self.cfg.context_parallel_size > 1: + pc_kwargs["cp_size"] = self.cfg.context_parallel_size + dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size + if dp_replicate_size > 1: + pc_kwargs["dp_replicate_size"] = dp_replicate_size + + parallelism_config = ParallelismConfig( + **pc_kwargs, + ) + mesh_dim_names, mesh_shape = parallelism_config.get_mesh() + device_mesh = torch.distributed.init_device_mesh( + "cuda", mesh_shape, mesh_dim_names=mesh_dim_names + ) + PartialState().parallelism_config = parallelism_config + PartialState().device_mesh = device_mesh + def _set_auto_model_loader(self): """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` (set at `__init__`). When using a multimodal model, `self.auto_model_loader` diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 533bd0f7a..2ba330d74 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -261,14 +261,14 @@ class PatchManager: def _apply_sequence_parallel_patches(self): """Apply sequence parallelism patches.""" - if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: + if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1: from axolotl.monkeypatch.ring_attn.patch import ( patch_prepare_data_loader, patch_prepare_device_mesh, ) patch_prepare_data_loader() - patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp) + patch_prepare_device_mesh(self.cfg.context_parallel_size, self.cfg.fsdp) def _apply_tiled_mlp(self, model_type: str): if self.cfg.tiled_mlp: diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 803659232..d7270679c 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -254,6 +254,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: "offload_policy": fsdp2_plugin.cpu_offload, # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), + "mesh": accelerator.state.device_mesh[ + accelerator.state.parallelism_config.model_shard_dim_names + ], } model_has_params4bit = False diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 9c9ba4553..8022455bc 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -162,14 +162,14 @@ def create_ring_flash_attention_forward( def register_ring_attn( - sequence_parallel_degree: int, + context_parallel_size: int, heads_k_stride: int | None, ring_attn_func: RingAttnFunc | None, ): """Create ring attention group and substitute flash attn with ring flash attn. Args: - sequence_parallel_degree: Sequence parallelism factor. + context_parallel_size: Sequence parallelism factor. heads_k_stride: Sequence parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample @@ -182,25 +182,25 @@ def register_ring_attn( if rank == 0: LOG.info( "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" + f"each sequence will be processed across {context_parallel_size} GPUs" ) - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " + assert context_parallel_size <= world_size, ( + f"context_parallel_size ({context_parallel_size}) " f"must be less than or equal to world_size ({world_size})" ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " + assert world_size % context_parallel_size == 0, ( + f"context_parallel_size ({context_parallel_size}) " f"must evenly divide world_size ({world_size})" ) # Assign ranks to sequence parallel groups group_assignments = {} - for i in range(world_size // sequence_parallel_degree): + for i in range(world_size // context_parallel_size): ring_attn_ranks = list( range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, + i * context_parallel_size, + (i + 1) * context_parallel_size, ) ) group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") @@ -299,12 +299,12 @@ def patch_prepare_data_loader(): LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") -def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False): +def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False): """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh that includes sequence parallelism with the specified degree. Args: - sequence_parallel_degree: The degree of sequence parallelism to use. + context_parallel_size: The degree of sequence parallelism to use. fsdp: Whether to use FSDP. """ @@ -323,8 +323,8 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False) # Create device mesh with sequence parallelism world_size = dist.get_world_size() mesh_shape = ( - world_size // sequence_parallel_degree, - sequence_parallel_degree, + world_size // context_parallel_size, + context_parallel_size, ) device_ids = list(range(world_size)) @@ -344,5 +344,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False) LOG.info( "Successfully patched Accelerator._prepare_device_mesh " - f"with sequence_parallel_degree={sequence_parallel_degree}" + f"with context_parallel_size={context_parallel_size}" ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 967179903..027ef7d37 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -202,7 +202,7 @@ def execute_training( ) ) - if cfg.sequence_parallel_degree > 1: + if cfg.context_parallel_size > 1: models = [trainer.model] if hasattr(trainer, "ref_model") and trainer.ref_model: models.append(trainer.ref_model) @@ -210,7 +210,7 @@ def execute_training( stack.enter_context( SequenceParallelContextManager( models=models, - sequence_parallel_degree=cfg.sequence_parallel_degree, + context_parallel_size=cfg.context_parallel_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 1ac805a73..50861fe28 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -167,7 +167,7 @@ class SequenceParallelContextManager: Args: models: List of models to apply sequence parallelism to pre- and post- forward hooks. - sequence_parallel_degree: Number of processes to split sequences over. + context_parallel_size: Number of processes to split sequences over. gradient_accumulation_steps: Number of steps to accumulate gradients over. ring_attn_func: Which ring attention function to use. Currently unused. heads_k_stride: Sequence parallelism K head stride size. Passed through to @@ -179,14 +179,14 @@ class SequenceParallelContextManager: def __init__( self, models: list[nn.Module], - sequence_parallel_degree: int, + context_parallel_size: int, gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, heads_k_stride: int | None, gather_outputs: bool, ): self.models = models - self.sequence_parallel_degree = sequence_parallel_degree + self.context_parallel_size = context_parallel_size self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func self.heads_k_stride = heads_k_stride @@ -231,7 +231,7 @@ class SequenceParallelContextManager: def _register_ring_attn(self): # Initialize ring attn for sequence parallelism register_ring_attn( - sequence_parallel_degree=self.sequence_parallel_degree, + context_parallel_size=self.context_parallel_size, heads_k_stride=self.heads_k_stride, ring_attn_func=self.ring_attn_func, ) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 96b694043..fb6de2b5a 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -644,7 +644,7 @@ class AxolotlInputConfig( }, ) - sequence_parallel_degree: int | None = Field( + context_parallel_size: int | None = Field( default=None, json_schema_extra={ "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 0c1a97fcd..ed2b70307 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -686,7 +686,7 @@ class RLValidationMixin: data.get("rl") == "grpo" and data.get("trl", {}) and data.get("trl").get("use_liger_loss") - and data.get("sequence_parallel_degree", 1) > 1 + and data.get("context_parallel_size", 1) > 1 ): raise ValueError("GRPO + SP + Liger not currently supported") return data @@ -1235,13 +1235,13 @@ class ComplexValidationMixin: return self @model_validator(mode="after") - def check_sequence_parallel_degree(self): - if not self.sequence_parallel_degree: - self.sequence_parallel_degree = 1 - elif self.sequence_parallel_degree > 1: + def check_context_parallel_size(self): + if not self.context_parallel_size: + self.context_parallel_size = 1 + elif self.context_parallel_size > 1: if not self.flash_attention: raise ValueError( - "flash_attention: true must be set with sequence_parallel_degree > 1" + "flash_attention: true must be set with context_parallel_size > 1" ) if self.sample_packing and self.micro_batch_size > 1: @@ -1254,14 +1254,14 @@ class ComplexValidationMixin: import ring_flash_attn # noqa: F401 # pylint:disable=unused-import except ImportError as exception: raise ImportError( - "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " + "context_parallel_size > 1 but ring_flash_attn is not installed. " "Please install it with `pip install axolotl[ring-flash-attn] " "or `pip install ring-flash-attn>=0.1.4`." ) from exception LOG.warning( "Sequence parallelism (SP) is enabled with " - f"sequence_parallel_degree={self.sequence_parallel_degree}. " + f"context_parallel_size={self.context_parallel_size}. " "Please note that logged losses may differ slightly to the non-SP " "losses due to transformers Trainer implementation details. " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " @@ -1272,7 +1272,7 @@ class ComplexValidationMixin: @model_validator(mode="after") def validate_ring_attn_func(self): - if getattr(self, "sequence_parallel_degree", 1) == 1: + if getattr(self, "context_parallel_size", 1) == 1: return self if self.ring_attn_func is not None: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8371b2dd7..90ae1a889 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): - 1 ) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size ) LOG.debug( @@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): math.floor( data_loader_len * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size ) ) @@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): math.ceil( len(train_dataset) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size * cfg.tensor_parallel_size / cfg.batch_size ) diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index 040152beb..5f1aec8ff 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -64,7 +64,7 @@ def fixture_base_cfg(): "dataloader_num_workers": 1, "dataloader_pin_memory": True, "dataloader_prefetch_factor": 2, - "sequence_parallel_degree": 1, + "context_parallel_size": 1, "tensor_parallel_size": 1, # Dtype "fp16": False, diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index 80098e684..cb3bc08ec 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -67,7 +67,7 @@ class TestSequenceParallelism: "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "ring_attn_func": ring_attn_func, "save_first_step": False, } diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index d022ae2d9..92e0f7040 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "sequence_len": 1024, "special_tokens": { diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 4a2c69d45..584718b78 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -111,7 +111,7 @@ class TestRingAttention: # Call register_ring_attn with size 4 register_ring_attn( - sequence_parallel_degree=4, + context_parallel_size=4, heads_k_stride=1, ring_attn_func=RingAttnFunc.VARLEN_LLAMA3, ) @@ -156,24 +156,24 @@ class TestConfigValidation: [ # Valid configuration ( - {"sequence_parallel_degree": 2, "flash_attention": True}, - {"sequence_parallel_degree": 2, "flash_attention": True}, + {"context_parallel_size": 2, "flash_attention": True}, + {"context_parallel_size": 2, "flash_attention": True}, True, None, ), - # Default sequence_parallel_degree - ({}, {"sequence_parallel_degree": 1}, True, None), - # Invalid: sequence_parallel_degree > 1 without flash_attention + # Default context_parallel_size + ({}, {"context_parallel_size": 1}, True, None), + # Invalid: context_parallel_size > 1 without flash_attention ( - {"sequence_parallel_degree": 2, "flash_attention": False}, + {"context_parallel_size": 2, "flash_attention": False}, None, False, "flash_attention: true must be set", ), - # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 + # Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1 ( { - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "sample_packing": True, "micro_batch_size": 2, @@ -186,13 +186,13 @@ class TestConfigValidation: # Valid: Basic GRPO config ( { - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "micro_batch_size": 2, "trl": {"use_liger_loss": True}, }, { - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "micro_batch_size": 2, "trl": TRLConfig(use_liger_loss=True), @@ -204,7 +204,7 @@ class TestConfigValidation: ( { "rl": "grpo", - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "micro_batch_size": 2, "trl": {"use_liger_loss": True}, @@ -262,7 +262,7 @@ class TestConfigValidation: # Apply updates to base config cfg = base_cfg | { - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "sample_packing": sample_packing, } @@ -282,7 +282,7 @@ class TestConfigValidation: # Invalid configuration with invalid ring_attn_func cfg = base_cfg | { - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "ring_attn_func": "INVALID_FUNC", }