From 6aa41740df7623d1f8995a1efd3b668f4a57c5cf Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 21 May 2025 11:20:20 -0400 Subject: [PATCH] SP dataloader patching + removing custom sampler / dataloader logic (#2686) * utilize accelerate prepare_data_loader with patching * lint * cleanup, fix * update to support DPO quirk * small change * coderabbit commits, cleanup, remove dead code * quarto fix * patch fix * review comments * moving monkeypatch up one level * fix --- _quarto.yml | 1 - docs/multi-gpu.qmd | 15 +- docs/sequence_parallelism.qmd | 6 +- examples/qwen2/dpo.yaml | 1 - src/axolotl/core/trainer_builder.py | 9 - src/axolotl/core/trainers/base.py | 50 +--- src/axolotl/core/trainers/dpo/trainer.py | 157 +----------- src/axolotl/core/trainers/grpo/trainer.py | 2 +- src/axolotl/core/trainers/mixins/__init__.py | 1 - .../core/trainers/mixins/sequence_parallel.py | 87 ------- src/axolotl/core/training_args.py | 13 - .../attention/ring_attn/__init__.py | 11 - .../monkeypatch/attention/ring_attn/patch.py | 131 ---------- src/axolotl/monkeypatch/ring_attn/__init__.py | 22 ++ .../ring_attn/adapters/__init__.py | 0 .../ring_attn/adapters/batch.py | 0 src/axolotl/monkeypatch/ring_attn/patch.py | 223 ++++++++++++++++++ .../utils/ctx_managers/sequence_parallel.py | 24 +- src/axolotl/utils/models.py | 22 +- tests/e2e/patched/test_sp.py | 6 +- 20 files changed, 304 insertions(+), 477 deletions(-) delete mode 100644 src/axolotl/core/trainers/mixins/sequence_parallel.py delete mode 100644 src/axolotl/monkeypatch/attention/ring_attn/__init__.py delete mode 100644 src/axolotl/monkeypatch/attention/ring_attn/patch.py create mode 100644 src/axolotl/monkeypatch/ring_attn/__init__.py rename src/axolotl/monkeypatch/{attention => }/ring_attn/adapters/__init__.py (100%) rename src/axolotl/monkeypatch/{attention => }/ring_attn/adapters/batch.py (100%) create mode 100644 src/axolotl/monkeypatch/ring_attn/patch.py diff --git a/_quarto.yml b/_quarto.yml index dc5071838..c09aecaea 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -60,7 +60,6 @@ quartodoc: - core.trainers.mixins.optimizer - core.trainers.mixins.rng_state_loader - core.trainers.mixins.scheduler - - core.trainers.mixins.sequence_parallel - title: Context Managers desc: Context managers for altering trainer behaviors contents: diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index 55eaca6c3..fee7d17e5 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the allows one to split up sequences across GPUs, which is useful in the event that a single sequence causes OOM errors during model training. -First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`, -or from source with `pip install .[ring-flash-attn]`. - -Your Axolotl YAML config should contain the following lines: - -```{.yaml} -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism - -# Optional; strides across the key dimension. Larger values use more memory but will make training faster. -heads_k_stride: 1 -``` - -See our [dedicated guide](sequence_parallelism.qmd) for more details. +See our [dedicated guide](sequence_parallelism.qmd) for more information. 
### FSDP + QLoRA {#sec-fsdp-qlora} diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index 1bff17ce9..b98206135 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -41,7 +41,7 @@ When sequence parallelism is enabled: 1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids -3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences +3. Position IDs are adjusted to maintain proper relative positions 4. The trainer uses special ring communication patterns for attention operations ## Requirements @@ -67,9 +67,11 @@ sequence_len: 8192 ... sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 +# Optional; one of "varlen_llama3" or "batch_ring". Defaults to +# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise. +ring_attn_func: ... ``` diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index 3547c6c98..bd896c2b3 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name - chat_template: qwen_25 rl: dpo datasets: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d82e4d20b..878dd176a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -798,11 +798,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.kd_top_k_before_softmax ) - training_arguments_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func - if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig elif self.cfg.process_reward_model: @@ -1083,10 +1078,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name - training_args_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_args_cls = None blocklist_args_kwargs = [] if self.cfg.rl is RLType.SIMPO: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 2f0ce6894..d5cfc23df 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import ( OptimizerMixin, RngLoaderMixin, SchedulerMixin, - SequenceParallelMixin, ) from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = logging.getLogger(__name__) -class AxolotlTrainer( - SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer -): +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] @@ -68,10 +65,6 @@ class AxolotlTrainer( if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - # Initialize sequence parallelism if enabled - if self.args.sequence_parallel_degree > 1: - 
self._setup_sequence_parallel() - def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access @@ -122,8 +115,8 @@ class AxolotlTrainer( def _get_train_sampler(self) -> Sampler | None: """ - Helper method to get the sampler for training. Handles cases for sequence - parallelism, sample packing, and curriculum sampling (sequential). + Helper method to get the sampler for training. Handles cases for sample packing + and curriculum sampling (sequential). Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -132,9 +125,7 @@ class AxolotlTrainer( use_sample_packing = self.args.sample_packing and not self.args.pretraining # Determine the base sampler first - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_train_sampler(self.train_dataset) - elif self.args.curriculum_sampling: + if self.args.curriculum_sampling: base_sampler = SequentialSampler(self.train_dataset) elif use_sample_packing: base_sampler = RandomSampler(self.train_dataset) @@ -153,8 +144,7 @@ class AxolotlTrainer( def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: """ - Helper method to get the sampler for evaluation. Handles sequence parallelism - and sample packing cases. + Helper method to get the sampler for evaluation. Handles sample packing case. Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -168,9 +158,7 @@ class AxolotlTrainer( ) # Determine the base sampler - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_eval_sampler(eval_dataset) - elif use_multipack: + if use_multipack: base_sampler = SequentialSampler(eval_dataset) else: return super()._get_eval_sampler(eval_dataset) @@ -236,14 +224,6 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False - # Return unprepared dataloader if using sequence parallelism - # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation - # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., - # slice each batch along the sequence dimension). 
- if self.args.sequence_parallel_degree > 1: - return dataloader - - # Otherwise prepare with accelerator return self.accelerator.prepare_data_loader(dataloader) def get_train_dataloader(self) -> DataLoader: @@ -287,12 +267,7 @@ class AxolotlTrainer( return dataloader - # Handle sample packing or sequence parallelism - if ( - self.args.sample_packing - and self.args.eval_sample_packing is not False - or self.args.sequence_parallel_degree > 1 - ): + if self.args.sample_packing and self.args.eval_sample_packing is not False: # Get appropriate data collator self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.eval_data_collator @@ -302,17 +277,6 @@ class AxolotlTrainer( if "length" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns(["length"]) - # Handle dataset preprocessing for SP - if self.args.sequence_parallel_degree > 1: - if isinstance(eval_dataset, datasets.Dataset): - eval_dataset = self._remove_unused_columns( - eval_dataset, description="evaluation" - ) - else: - self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init - self.data_collator, description="evaluation" - ) - # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise batch_size = ( self.args.eval_batch_size diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 1ce7deea7..c2c80c0bc 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,31 +1,15 @@ -""" -DPO trainer for axolotl -""" +"""DPO trainer for axolotl""" import gc -import random from functools import wraps -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union -import pandas as pd import torch -import wandb -from accelerate import PartialState -from datasets import Dataset, IterableDataset from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.utils.data import DataLoader -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizerBase, - ProcessorMixin, - Trainer, -) -from transformers.trainer_utils import EvalLoopOutput +from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt -from trl.trainer.utils import log_table_to_comet_experiment +from trl import DPOTrainer from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.utils import ( @@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled(): class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ + """Extend the base DPOTrainer for axolotl helpers.""" tag_names = ["axolotl", "dpo"] @@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): @wraps(DPOTrainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + Overwrite the `push_to_hub` method in order to force-add the tags when pushing + the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` + for more details. 
""" kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs @@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): return super().push_to_hub(*args, **kwargs) - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: Union[ - PreTrainedTokenizerBase, - BaseImageProcessor, - FeatureExtractionMixin, - ProcessorMixin, - ], - args: DPOConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Build the kwargs for the `map` function - map_kwargs: Dict[str, Any] = {"writer_batch_size": 10} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc - map_kwargs["num_proc"] = args.dataset_num_proc - - with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) - - # Apply the chat template if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - @staticmethod def tokenize_row( features, @@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): gc.collect() torch.cuda.empty_cache() return loss - - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. - Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. 
- """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample( - range(num_samples), k=self.args.eval_batch_size - ) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded, ref_output_decoded = ( - self.generate_from_model_and_ref(self.model, random_batch) - ) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], - policy_output_decoded, - ref_output_decoded, - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super( # pylint: disable=bad-super-call - DPOTrainer, self - ).evaluation_loop( - dataloader, - description, - prediction_loss_only, - ignore_keys, - metric_key_prefix, - ) - - return initial_output diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 8a89de333..a603ed860 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -43,7 +43,7 @@ from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group +from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group if is_peft_available(): # pylint: disable=unused-import diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..a71cb321a 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,3 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py deleted file mode 100644 index 0f30458cd..000000000 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for Axolotl trainer sequence parallelism mixin""" - -import torch.distributed as dist -from datasets import Dataset -from torch.utils.data import DistributedSampler, Sampler - -from axolotl.monkeypatch.attention.ring_attn import ( - get_ring_attn_group, -) - - -class SequenceParallelMixin: - """ - Mixin class for sequence parallelism support in trainers. - - This mixin provides functionality for handling sequence parallelism, - specifically for creating appropriate data samplers. 
- """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def _setup_sequence_parallel(self): - """Set up sequence parallelism environment.""" - self.ring_attn_group = get_ring_attn_group() - - def _create_sequence_parallel_sampler( - self, - dataset: Dataset, - shuffle: bool = True, - is_eval: bool = False, - ) -> DistributedSampler: - """ - Helper method to create sampler for sequence parallelism (SP). - - We create a distributed sampler with rank equal to the SP group ID, which - means that all ranks in the SP group receive the same sample / set of samples - per training step. We also set the number of replicas equal to the number of - SP groups, which is a bit of a hack / unintended use, but works! - - Args: - dataset: Dataset to sample from. - shuffle: Whether to shuffle the dataset. - is_eval: Whether we are creating a sampler for evaluation or training. - - Returns: - Distributed sampler. - """ - num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree - sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree - - return DistributedSampler( - dataset, - num_replicas=num_sp_groups, - rank=sp_group_id, - seed=self.args.seed if shuffle else None, - shuffle=shuffle, - drop_last=not is_eval, - ) - - def _sp_get_train_sampler(self, dataset) -> Sampler | None: - """ - Get a training sampler configured for sequence parallelism. - - Args: - dataset: The training dataset - - Returns: - Configured sequence parallel sampler. - """ - return self._create_sequence_parallel_sampler( - dataset, - shuffle=not self.args.curriculum_sampling, - ) - - def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: - """ - Get an evaluation sampler configured for sequence parallelism. - - Args: - eval_dataset: The evaluation dataset. - - Returns: - Configured sequence parallel sampler. 
- """ - return self._create_sequence_parallel_sampler( - eval_dataset, shuffle=False, is_eval=True - ) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index a81c33801..9c93f77c7 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,8 +9,6 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig -from axolotl.utils.schemas.enums import RingAttnFunc - @dataclass class AxolotlTrainingMixins: @@ -216,17 +214,6 @@ class AxolotlTrainingMixins: }, ) - sequence_parallel_degree: Optional[int] = field( - default=1, - metadata={"help": "The number of workers to use in sequence parallelism"}, - ) - ring_attn_func: Optional[RingAttnFunc] = field( - default=None, - metadata={ - "help": "The ring-flash-attn function to use in sequence parallelism" - }, - ) - adam_beta3: Optional[float] = field( default=None, metadata={ diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py deleted file mode 100644 index a50ad456e..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Init for ring attention monkeypatch module""" - -# pylint: disable=unused-import -# flake8: noqa - -from .patch import ( - get_ring_attn_group, - register_ring_attn, - set_ring_attn_group, - update_ring_attn_params, -) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py deleted file mode 100644 index 8cbba338a..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Ring attention group registration and flash attention patching. - -Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) -package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in -their sequence parallel version of Flash Attention 2. -""" - -import torch -import torch.distributed as dist -from accelerate.logging import get_logger - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids -from axolotl.utils.schemas.enums import RingAttnFunc - -LOG = get_logger(__name__) - - -RING_ATTN_GROUP = None - - -def get_ring_attn_group() -> dist.ProcessGroup: - """ - Getter for ring attention group on this rank. - - Returns: - The process group for ring attention for this rank. - """ - return RING_ATTN_GROUP - - -def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): - """ - Setter for ring attention group on this rank. - - Args: - Process group for ring attention. - """ - global RING_ATTN_GROUP # pylint: disable=global-statement - RING_ATTN_GROUP = ring_attn_group - - -def register_ring_attn( - sequence_parallel_degree: int, - heads_k_stride: int | None, - ring_attn_func: RingAttnFunc | None, -): - """ - Create ring attention group and substitute flash attn with ring flash attn. - - Args: - sequence_parallel_degree: Sequence parallelism factor. - heads_k_stride: Sequence parallelism K head stride size. Passed - through to `ring_flash_attn.substitute_hf_flash_attn`. - ring_attn_func: `ring_flash_attn` ring attention implemention. If sample - packing is enabled, it must be a `varlen` function; otherwise, it must be a - `batch` function. 
- """ - if get_ring_attn_group() is not None: - LOG.info("Ring attention already registered, exiting early...") - return - - LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must be less than or equal to world_size ({world_size})" - ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must evenly divide world_size ({world_size})" - ) - - # Assign ranks to sequence parallel groups - group_assignments = {} - for i in range(world_size // sequence_parallel_degree): - ring_attn_ranks = list( - range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, - ) - ) - group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - - # Track which GPUs are in which groups - for r in ring_attn_ranks: - group_assignments[r] = i - - if rank in ring_attn_ranks: - set_ring_attn_group(group) - - # Log the GPU group assignments - if rank == 0: - LOG.info(f"Sequence parallel group assignments: {group_assignments}") - - if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: - from ring_flash_attn import substitute_hf_flash_attn - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 - ) - elif ring_attn_func is RingAttnFunc.BATCH_RING: - from axolotl.monkeypatch.attention.ring_attn.adapters.batch import ( - substitute_hf_flash_attn, - ) - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), - ring_attn_func=ring_attn_func, - ) - - -def update_ring_attn_params(position_ids: torch.Tensor | None): - """ - Calculate the cumulative sequence lengths for the current forward pass and pass the - value to the substituted `ring_flash_attn`. - - Args: - position_ids: Optional tensor of position IDs (for sample packed data). 
- """ - from ring_flash_attn import update_ring_flash_attn_params - - cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) - update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) diff --git a/src/axolotl/monkeypatch/ring_attn/__init__.py b/src/axolotl/monkeypatch/ring_attn/__init__.py new file mode 100644 index 000000000..5833b9ce4 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/__init__.py @@ -0,0 +1,22 @@ +"""Init for ring attention monkeypatch module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .patch import ( + get_ring_attn_group, + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + set_ring_attn_group, + update_ring_attn_params, +) + +__all__ = ( + "get_ring_attn_group", + "patch_prepare_data_loader", + "patch_prepare_device_mesh", + "register_ring_attn", + "set_ring_attn_group", + "update_ring_attn_params", +) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py b/src/axolotl/monkeypatch/ring_attn/adapters/__init__.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py rename to src/axolotl/monkeypatch/ring_attn/adapters/__init__.py diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py rename to src/axolotl/monkeypatch/ring_attn/adapters/batch.py diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py new file mode 100644 index 000000000..4329d9f13 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -0,0 +1,223 @@ +"""Ring attention group registration and flash attention patching. + +Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) +package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in +their sequence parallel version of Flash Attention 2. + +We also provide some patches for accelerate functions to prepare the dataloader for +sequence parallelism training. 
+"""
+
+import inspect
+
+import accelerate
+import torch
+import torch.distributed as dist
+from accelerate.logging import get_logger
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.schemas.enums import RingAttnFunc
+
+LOG = get_logger(__name__)
+
+
+RING_ATTN_GROUP = None
+
+ORIGINAL_PREPARE_DATALOADER_CODE = """        submesh_fsdp_size = 1
+        submesh_dp_size = 1
+        submesh_tp_size = 1
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        process_index = process_index // submesh_tp_size"""
+
+NEW_PREPARE_DATALOADER_CODE = """        submesh_fsdp_size = 1
+        submesh_dp_size = 1
+        submesh_tp_size = 1
+        submesh_cp_size = 1
+        if "cp" in torch_device_mesh.mesh_dim_names:
+            submesh_cp_size = torch_device_mesh["cp"].size()
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
+
+
+def get_ring_attn_group() -> dist.ProcessGroup:
+    """Getter for ring attention group on this rank."""
+    return RING_ATTN_GROUP
+
+
+def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
+    """Setter for ring attention group on this rank."""
+    global RING_ATTN_GROUP  # pylint: disable=global-statement
+    RING_ATTN_GROUP = ring_attn_group
+
+
+def register_ring_attn(
+    sequence_parallel_degree: int,
+    heads_k_stride: int | None,
+    ring_attn_func: RingAttnFunc | None,
+):
+    """Create ring attention group and substitute flash attn with ring flash attn.
+
+    Args:
+        sequence_parallel_degree: Sequence parallelism factor.
+        heads_k_stride: Sequence parallelism K head stride size. Passed
+            through to `ring_flash_attn.substitute_hf_flash_attn`.
+        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
+            packing is enabled, it must be a `varlen` function; otherwise, it must be a
+            `batch` function.
+    """
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    if rank == 0:
+        LOG.info(
+            "Enabling ring attention sequence parallelism: "
+            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
+        )
+
+    assert sequence_parallel_degree <= world_size, (
+        f"sequence_parallel_degree ({sequence_parallel_degree}) "
+        f"must be less than or equal to world_size ({world_size})"
+    )
+    assert world_size % sequence_parallel_degree == 0, (
+        f"sequence_parallel_degree ({sequence_parallel_degree}) "
+        f"must evenly divide world_size ({world_size})"
+    )
+
+    # Assign ranks to sequence parallel groups
+    group_assignments = {}
+    for i in range(world_size // sequence_parallel_degree):
+        ring_attn_ranks = list(
+            range(
+                i * sequence_parallel_degree,
+                (i + 1) * sequence_parallel_degree,
+            )
+        )
+        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
+
+        # Track which GPUs are in which groups
+        for r in ring_attn_ranks:
+            group_assignments[r] = i
+
+        if rank in ring_attn_ranks:
+            set_ring_attn_group(group)
+
+    # Log the GPU group assignments
+    if rank == 0:
+        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
+
+    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
+        from ring_flash_attn import substitute_hf_flash_attn
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func is RingAttnFunc.BATCH_RING:
+        from axolotl.monkeypatch.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )
+
+
+def update_ring_attn_params(position_ids: torch.Tensor | None):
+    """
+    Calculate the cumulative sequence lengths for the current forward pass and pass the
+    value to the substituted `ring_flash_attn`.
+
+    Args:
+        position_ids: Optional tensor of position IDs (for sample packed data).
+    """
+    from ring_flash_attn import update_ring_flash_attn_params
+
+    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
+    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
+    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
+
+def patch_prepare_data_loader():
+    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
+
+    Raises:
+        RuntimeError: If the source code to patch cannot be found.
+    """
+    original_fn = accelerate.data_loader.prepare_data_loader
+    original_source = inspect.getsource(original_fn)
+
+    if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
+        raise RuntimeError(
+            "SP patch failed - target snippet not found. "
+            "Check accelerate's version or update the patch."
+        )
+
+    patched_source = original_source.replace(
+        ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
+    )
+
+    # Create a new function from the patched source
+    namespace = {}
+    exec(  # pylint: disable=exec-used  # nosec B102
+        patched_source, accelerate.data_loader.__dict__, namespace
+    )
+    patched_function = namespace["prepare_data_loader"]
+
+    accelerate.data_loader.prepare_data_loader = patched_function
+    LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
+
+
+def patch_prepare_device_mesh(sequence_parallel_degree: int):
+    """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
+    that includes sequence parallelism with the specified degree.
+
+    Args:
+        sequence_parallel_degree (int): The degree of sequence parallelism to use.
+ """ + + def _prepare_device_mesh(self): + """Prepare the device mesh for distributed training. The dataloader will + determine how to load data based on the device mesh. + """ + if self.state.torch_tp_plugin: + return self.state.torch_tp_plugin.torch_device_mesh + if ( + self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED + and hasattr(self.state, "ds_device_mesh") + ): + return self.state.ds_device_mesh + + # Create device mesh with sequence parallelism + world_size = dist.get_world_size() + mesh_shape = ( + world_size // sequence_parallel_degree, + sequence_parallel_degree, + ) + device_ids = list(range(world_size)) + + # Note that we use "cp" instead of "sp" to match the PyTorch native "context + # parallelism" implementation naming + return dist.DeviceMesh( + "cuda", + torch.tensor(device_ids).reshape(mesh_shape), + mesh_dim_names=("dp", "cp"), + ) + + # Replace the original method with our new method + # pylint: disable=protected-access + accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh + + LOG.info( + "Successfully patched Accelerator._prepare_device_mesh " + f"with sequence_parallel_degree={sequence_parallel_degree}" + ) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 66044f7f0..6e4f9bada 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -1,6 +1,7 @@ """Module for Axolotl trainer sequence parallelism manager and utilities""" import functools +import inspect import torch import torch.distributed as dist @@ -9,7 +10,7 @@ from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput -from axolotl.monkeypatch.attention.ring_attn.patch import ( +from axolotl.monkeypatch.ring_attn.patch import ( get_ring_attn_group, update_ring_attn_params, ) @@ -206,12 +207,25 @@ class SequenceParallelContextManager: def __enter__(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): - # Apply sequence parallelism to kwargs and get original sequence length and padding info - kwargs, self.original_seq_len, self.pad_len = ( - self.apply_sequence_parallelism(batch=kwargs) + # Get parameter names from the model's forward function + forward_params = list( + inspect.signature(self.models[0].forward).parameters.keys() ) - return args, kwargs + updated_kwargs = kwargs.copy() + for i, arg in enumerate(args): + if i < len(forward_params): + updated_kwargs[forward_params[i]] = arg + + # Any excess positional arguments are kept as-is + remaining_args = args[len(forward_params) :] + + # Apply sequence parallelism to updated kwargs + updated_kwargs, self.original_seq_len, self.pad_len = ( + self.apply_sequence_parallelism(updated_kwargs) + ) + + return remaining_args, updated_kwargs # Forward post-hook to gather outputs def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 316fbec8c..6236f78e8 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -59,6 +59,7 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) +from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from 
axolotl.utils.chat_templates import get_chat_template_from_config @@ -681,16 +682,25 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + from axolotl.monkeypatch.ring_attn import ( + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + ) # Initialize ring attn for sequence parallelism. This must be done after # model init but before the first forward pass, since it modifies flash # attn to use ring comm for SP training across multiple GPUs. - register_ring_attn( - sequence_parallel_degree=self.cfg.sequence_parallel_degree, - heads_k_stride=self.cfg.heads_k_stride, - ring_attn_func=self.cfg.ring_attn_func, - ) + if get_ring_attn_group() is None: # If already set, this is already patched + register_ring_attn( + sequence_parallel_degree=self.cfg.sequence_parallel_degree, + heads_k_stride=self.cfg.heads_k_stride, + ring_attn_func=self.cfg.ring_attn_func, + ) + patch_prepare_data_loader() + patch_prepare_device_mesh( + sequence_parallel_degree=self.cfg.sequence_parallel_degree + ) def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 8efe62940..83faa779f 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -10,7 +10,7 @@ import pytest import torch from accelerate.state import PartialState -from axolotl.monkeypatch.attention.ring_attn import ( +from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, register_ring_attn, set_ring_attn_group, @@ -313,13 +313,13 @@ class TestApplySequenceParallelism: # Mock the process group monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", + "axolotl.monkeypatch.ring_attn.get_ring_attn_group", MagicMock, ) # Mock update_ring_attn_params monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", + "axolotl.monkeypatch.ring_attn.update_ring_attn_params", lambda **kwargs: None, )
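
Note (not part of the patch): the combined effect of `patch_prepare_device_mesh` and `patch_prepare_data_loader` is that every rank in the same sequence-parallel ("cp") group reads the same data-parallel shard and splits each sequence across the group, matching the process groups built in `register_ring_attn`. A minimal standalone sketch of that rank-to-shard mapping, assuming a world size of 8 and `sequence_parallel_degree: 4` (both values are illustrative, not from the patch):

```python
# Sketch only: mirrors the patched line in accelerate's prepare_data_loader,
#     process_index = process_index // (submesh_tp_size * submesh_cp_size)
# with the ("dp", "cp") mesh shape built by patch_prepare_device_mesh.

world_size = 8
sequence_parallel_degree = 4            # "cp" mesh dimension (assumed)
submesh_tp_size = 1                     # no tensor parallelism in this sketch
submesh_cp_size = sequence_parallel_degree

for rank in range(world_size):
    # All ranks sharing a data-parallel shard form one sequence-parallel group.
    data_shard = rank // (submesh_tp_size * submesh_cp_size)
    print(f"rank {rank} -> data-parallel shard {data_shard}")

# Expected: ranks 0-3 map to shard 0 and ranks 4-7 map to shard 1, so each
# group of 4 GPUs receives identical batches and splits every sequence across
# its members, consistent with the group assignments in register_ring_attn.
```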