diff --git a/docs/config.qmd b/docs/config.qmd
index 73aada105..787632c50 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -620,6 +620,11 @@ ddp_timeout:
 ddp_bucket_cap_mb:
 ddp_broadcast_buffers:
 
+# Sequence parallelism
+# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
+# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
+sequence_parallel_degree:
+
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
 
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 5fe5b84a8..908947876 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -763,8 +763,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.kd_top_k_before_softmax
 
         training_arguments_kwargs[
-            "sequence_parallel_size"
-        ] = self.cfg.sequence_parallel_size
+            "sequence_parallel_degree"
+        ] = self.cfg.sequence_parallel_degree
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -911,7 +911,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             collator = DataCollatorForSeq2Seq
 
         kwargs["return_tensors"] = "pt"
-        kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
+        kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
 
         return collator(
             *collator_args,
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 05a6f92f0..ebc51d5fa 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -423,9 +423,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         # Handle sequence parallelism
-        if self.args.sequence_parallel_size > 1:
-            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
-            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
+        if self.args.sequence_parallel_degree > 1:
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
 
             # Create base sampler for SP groups
             base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -472,10 +472,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
         # Handle sequence parallelism
-        if self.args.sequence_parallel_size > 1:
+        if self.args.sequence_parallel_degree > 1:
             # Create sampler for SP groups
-            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
-            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
 
             # Create distributed sampler for the SP group
             base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -570,7 +570,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
             # Don't prepare dataloader for sequence parallelism
             # We use a distributed sampler in this case
-            if self.args.sequence_parallel_size > 1:
+            if self.args.sequence_parallel_degree > 1:
                 return dataloader
 
             return self.accelerator.prepare_data_loader(dataloader)
@@ -625,11 +625,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             # Don't prepare dataloader for sequence parallelism
             # We use a distributed sampler in this case
-            if self.args.sequence_parallel_size > 1:
+            if self.args.sequence_parallel_degree > 1:
                 return dataloader
 
             return self.accelerator.prepare_data_loader(dataloader)
 
-        if self.args.sequence_parallel_size > 1:
+        if self.args.sequence_parallel_degree > 1:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
@@ -949,14 +949,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         """
        Perform a training step on a batch of inputs.
         """
-        if self.args.sequence_parallel_size > 1:
+        if self.args.sequence_parallel_degree > 1:
             # At this point, inputs should already be partitioned by the sequence
             # parallel data collator
             batch_size = inputs["input_ids"].shape[0]
             seq_len = inputs["input_ids"].shape[1]
 
             # Calculate the full sequence length across all GPUs in this SP group
-            total_seq_len = seq_len * self.args.sequence_parallel_size
+            total_seq_len = seq_len * self.args.sequence_parallel_degree
 
             # Pass the partitioned sequence information to ring flash attention
             self._update_ring_flash_attn_params(
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 57ad638d6..82a62c049 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -207,7 +207,7 @@ class AxolotlTrainingMixins:
         },
     )
 
-    sequence_parallel_size: Optional[int] = field(
+    sequence_parallel_degree: Optional[int] = field(
         default=1,
         metadata={"help": "The number of workers to use in sequence parallelism"},
     )
diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index e45ca249f..eb146609e 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -23,21 +23,21 @@ def set_ring_attn_group(ring_attn_group: Any):
     RING_ATTN_GROUP = ring_attn_group
 
 
-def register_ring_attn(sequence_parallel_size: int):
+def register_ring_attn(sequence_parallel_degree: int):
     """
     Create ring attention group and substitute flash attn with ring flash attn.
 
     Args:
-        sequence_parallel_size: Sequence parallelism factor.
+        sequence_parallel_degree: Sequence parallelism factor.
     """
     LOG.info(
         "Enabling ring attention sequence parallelism: "
-        f"each sequence will be processed across {sequence_parallel_size} GPUs"
+        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
     )
 
     world_size = dist.get_world_size()
-    assert world_size % sequence_parallel_size == 0, (
-        f"sequence_parallel_size ({sequence_parallel_size}) "
+    assert world_size % sequence_parallel_degree == 0, (
+        f"sequence_parallel_degree ({sequence_parallel_degree}) "
         f"must evenly divide world_size ({world_size})"
     )
 
@@ -45,11 +45,11 @@ def register_ring_attn(sequence_parallel_size: int):
     rank = dist.get_rank()
     group_assignments = {}
 
-    for i in range(world_size // sequence_parallel_size):
+    for i in range(world_size // sequence_parallel_degree):
         ring_attn_ranks = list(
             range(
-                i * sequence_parallel_size,
-                (i + 1) * sequence_parallel_size,
+                i * sequence_parallel_degree,
+                (i + 1) * sequence_parallel_degree,
             )
         )
         group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -65,4 +65,4 @@ def register_ring_attn(sequence_parallel_size: int):
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
-    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
+    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 0a7c79434..9ce9d84ae 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -95,7 +95,7 @@ class DataCollatorForSeq2Seq:
             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
-        sequence_parallel_size (`int`):
+        sequence_parallel_degree (`int`):
             The degree of sequence parallelism. Default to 1 for no sequence parallelism.
     """
 
@@ -107,10 +107,10 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
-    sequence_parallel_size: int = 1
+    sequence_parallel_degree: int = 1
 
     def __post_init__(self):
-        if self.sequence_parallel_size > 1:
+        if self.sequence_parallel_degree > 1:
             # Get information about our position in the SP group
             sp_group = get_ring_attn_group()
             self.rank = dist.get_rank()
@@ -183,7 +183,7 @@ class DataCollatorForSeq2Seq:
             )
             features["decoder_input_ids"] = decoder_input_ids
 
-        if self.sequence_parallel_size > 1:
+        if self.sequence_parallel_degree > 1:
             features = self.apply_sequence_parallelism(features)
 
         return features
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e88d063e6..a38e4aa6c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -552,14 +552,13 @@ class ModelLoader:
 
             patch_self_attn_lora(self.cfg)
 
-        if self.cfg.sequence_parallel_size > 1:
+        if self.cfg.sequence_parallel_degree > 1:
             from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
 
-            # Initialize ring attention for sequence parallelism if enabled.
-            # This must be done after model initialization but before the first forward pass,
-            # as it modifies the flash attention implementation to use ring communication
-            # patterns for efficient sequence-parallel training across multiple GPUs.
-            register_ring_attn(self.cfg.sequence_parallel_size)
+            # Initialize ring attn for sequence parallelism. This must be done after
+            # model init but before the first forward pass, since it modifies flash
+            # attn to use ring comm for SP training across multiple GPUs.
+            register_ring_attn(self.cfg.sequence_parallel_degree)
 
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 89b3612fb..3a8e73a4d 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -245,7 +245,7 @@ class AxolotlInputConfig(
 
     val_set_size: float | None = Field(default=0.0)
 
-    sequence_parallel_size: int | None = 1
+    sequence_parallel_degree: int | None = 1
 
     special_tokens: SpecialTokensConfig | None = None
     tokens: list[str] | None = None
@@ -1107,10 +1107,10 @@ class AxolotlInputConfig(
    @model_validator(mode="before")
     @classmethod
     def check_sequence_parallel_config(cls, data):
-        if data.get("sequence_parallel_size") > 1:
+        if data.get("sequence_parallel_degree") > 1:
             if not data.get("flash_attention"):
                 raise ValueError(
-                    "flash_attention: true must be set with sequence_parallel_size > 1"
+                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                 )
 
         return data
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 0e04039a0..d2b211bbc 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
+    elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             **filter_map_kwargs,
             **drop_long_kwargs,
         )
-    if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
+    if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
         if eval_dataset:
             eval_dataset = eval_dataset.map(
                 add_position_ids,
@@ -443,7 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                     - 1
                 )
                 * cfg.num_epochs
-                * cfg.sequence_parallel_size
+                * cfg.sequence_parallel_degree
             )
             LOG.debug(
                 f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -476,7 +476,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             # on the agreed on value for sample_packing_eff_est
             total_num_steps = int(
                 math.floor(
-                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
+                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
                 )
             )
 
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             math.ceil(
                 len(train_dataset)
                 * cfg.num_epochs
-                * cfg.sequence_parallel_size
+                * cfg.sequence_parallel_degree
                 / cfg.batch_size
             )
         )
diff --git a/tests/e2e/multigpu/test_sequence_parallelism.py b/tests/e2e/multigpu/test_sequence_parallelism.py
index c57c76caf..19edfb0ad 100644
--- a/tests/e2e/multigpu/test_sequence_parallelism.py
+++ b/tests/e2e/multigpu/test_sequence_parallelism.py
@@ -25,7 +25,7 @@ def test_integration_with_config():
             ],
             "load_in_8bit": False,
             "sequence_len": 1024,
-            "sequence_parallel_size": 2,
+            "sequence_parallel_degree": 2,
             "flash_attention": True,
             "sample_packing": True,
             "pad_to_sequence_len": True,
@@ -58,17 +58,17 @@ def test_integration_with_config():
     normalize_config(cfg)
 
     # Verify sequence parallelism settings were properly processed
-    assert cfg.sequence_parallel_size == 2
+    assert cfg.sequence_parallel_degree == 2
     assert cfg.flash_attention is True
 
-    # Check if the sequence_parallel_size was propagated to the training args
+    # Check if the sequence_parallel_degree was propagated to the training args
     from axolotl.core.training_args import AxolotlTrainingArguments
     # pylint: disable=unexpected-keyword-arg
     training_args = AxolotlTrainingArguments(
-        output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
+        output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
     )
 
-    assert training_args.sequence_parallel_size == 2
+    assert training_args.sequence_parallel_degree == 2
 
 
 def test_ring_attn_group_creation():
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
        pytest.skip(f"Need an even number of GPUs, but got {world_size}")
 
     # Register with sequence parallel size of 2
-    register_ring_attn(sequence_parallel_size=2)
+    register_ring_attn(sequence_parallel_degree=2)
 
     # Get the ring attention group
     group = get_ring_attn_group()
diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sequence_parallelism.py
index 126af595f..48d264ece 100644
--- a/tests/e2e/patched/test_sequence_parallelism.py
+++ b/tests/e2e/patched/test_sequence_parallelism.py
@@ -94,7 +94,7 @@ class TestRingAttention:
         mock_new_group.return_value = mock_group
 
         # Call register_ring_attn with size 4
-        register_ring_attn(sequence_parallel_size=4)
+        register_ring_attn(sequence_parallel_degree=4)
 
         # Verify the number of calls without examining the arguments
         assert mock_new_group.call_count == 2
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
     # Import the actual model class with appropriate mocks
     from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
-    # Valid configuration: sequence_parallel_size > 1 and flash_attention is True
+    # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
     cfg = cfg | {
-        "sequence_parallel_size": 2,
+        "sequence_parallel_degree": 2,
         "flash_attention": True,
     }
 
     # Should validate without errors
     config = AxolotlInputConfig(**cfg)
-    assert config.sequence_parallel_size == 2
+    assert config.sequence_parallel_degree == 2
     assert config.flash_attention is True
 
 
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
     """Test that invalid sequence parallelism configurations fail validation."""
     from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
-    # Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
+    # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
     cfg = cfg | {
-        "sequence_parallel_size": 2,
+        "sequence_parallel_degree": 2,
        "flash_attention": False,
     }
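Usage note (not part of the patch): the renamed `sequence_parallel_degree` option is read from the YAML training config. A minimal illustrative sketch, with values borrowed from the test config in this diff; the degree must evenly divide the number of GPUs, and `flash_attention: true` is required whenever it is greater than 1:

```yaml
# Illustrative sketch only -- adjust to your own setup.
sequence_len: 1024
sequence_parallel_degree: 2   # must be a divisor of the number of GPUs (checked against world_size)
flash_attention: true         # required when sequence_parallel_degree > 1
sample_packing: true
pad_to_sequence_len: true
```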