From c25990fd4f3031af9ce34a37d48fb626d4831267 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 14 May 2025 02:09:20 +0000 Subject: [PATCH] additional RL trainers SP support --- src/axolotl/core/trainers/base.py | 6 +- src/axolotl/core/trainers/dpo/trainer.py | 184 +++++------------- src/axolotl/core/trainers/grpo/trainer.py | 3 - .../utils/ctx_managers/sequence_parallel.py | 25 ++- 4 files changed, 72 insertions(+), 146 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 2f0ce6894..25baed7bc 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -156,6 +156,9 @@ class AxolotlTrainer( Helper method to get the sampler for evaluation. Handles sequence parallelism and sample packing cases. + Args: + eval_dataset: Evaluation dataset. + Returns: If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. @@ -237,9 +240,6 @@ class AxolotlTrainer( self.accelerator.even_batches = False # Return unprepared dataloader if using sequence parallelism - # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation - # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., - # slice each batch along the sequence dimension). if self.args.sequence_parallel_degree > 1: return dataloader diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 1ce7deea7..ea608a582 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,33 +1,25 @@ -""" -DPO trainer for axolotl -""" +"""DPO trainer for Axolotl""" import gc -import random from functools import wraps -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union -import pandas as pd import torch -import wandb -from accelerate import PartialState -from datasets import Dataset, IterableDataset +from datasets import Dataset from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.utils.data import DataLoader +from torch.utils.data import Sampler from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizerBase, - ProcessorMixin, Trainer, ) -from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt -from trl.trainer.utils import log_table_to_comet_experiment +from trl import DPOTrainer -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import ( + RngLoaderMixin, + SchedulerMixin, + SequenceParallelMixin, +) from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, @@ -37,10 +29,10 @@ if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp -class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ +class AxolotlDPOTrainer( + RngLoaderMixin, SchedulerMixin, SequenceParallelMixin, DPOTrainer +): + """Extend the base DPOTrainer for axolotl helpers""" tag_names = ["axolotl", "dpo"] @@ -95,64 +87,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): return super().push_to_hub(*args, **kwargs) - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: 
Union[ - PreTrainedTokenizerBase, - BaseImageProcessor, - FeatureExtractionMixin, - ProcessorMixin, - ], - args: DPOConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Build the kwargs for the `map` function - map_kwargs: Dict[str, Any] = {"writer_batch_size": 10} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc - map_kwargs["num_proc"] = args.dataset_num_proc - - with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) - - # Apply the chat template if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - @staticmethod def tokenize_row( features, @@ -193,68 +127,48 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): torch.cuda.empty_cache() return loss - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: + def _get_train_sampler(self) -> Sampler | None: """ - Overriding built-in evaluation loop to store metrics for each batch. - Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + Helper method to get the sampler for training. Handles cases for sequence + parallelism, sample packing, and curriculum sampling (sequential). - Works both with or without labels. + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. 
""" + import torch.distributed as dist - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample( - range(num_samples), k=self.args.eval_batch_size - ) + if dist.get_rank() == 0: + import ipdb - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) + ipdb.set_trace() + dist.barrier() + if dist.get_rank() == 1: + import ipdb - policy_output_decoded, ref_output_decoded = ( - self.generate_from_model_and_ref(self.model, random_batch) - ) + ipdb.set_trace() + dist.barrier() - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], - policy_output_decoded, - ref_output_decoded, - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) + if self.args.sequence_parallel_degree > 1: + return self._sp_get_train_sampler(self.train_dataset) - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) + return super()._get_train_sampler() - # Base evaluation - initial_output = super( # pylint: disable=bad-super-call - DPOTrainer, self - ).evaluation_loop( - dataloader, - description, - prediction_loss_only, - ignore_keys, - metric_key_prefix, - ) + def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: + """ + Helper method to get the sampler for evaluation. Handles sequence parallelism + and sample packing cases. - return initial_output + Args: + eval_dataset: Evaluation dataset. + + Returns: + If the dataset is non-empty, a sampler is returned, the type of which + depends on the passed training args. + """ + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.args.sequence_parallel_degree > 1: + return self._sp_get_eval_sampler(eval_dataset) + + return super()._get_eval_sampler(eval_dataset) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index bc3d140b1..8205f14e4 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -266,9 +266,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): self.accelerator.even_batches = False # Return unprepared dataloader if using sequence parallelism - # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation - # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., - # slice each batch along the sequence dimension). 
         if self.args.sequence_parallel_degree > 1:
             return dataloader

diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py
index 66044f7f0..48bd53c23 100644
--- a/src/axolotl/utils/ctx_managers/sequence_parallel.py
+++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py
@@ -1,6 +1,7 @@
 """Module for Axolotl trainer sequence parallelism manager and utilities"""

 import functools
+import inspect

 import torch
 import torch.distributed as dist
@@ -32,7 +33,7 @@ def apply_sequence_parallelism(
     to only keep the last N tokens in the sequence during generation.

     Args:
-        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
+        batch: Dictionary of model arguments (e.g., input_ids, attention_mask, etc.).
         local_rank: Local rank in the sequence parallel group.
         local_world_size: World size of the sequence parallel group.
         gradient_accumulation_steps: Number of steps to accumulate gradients over.
@@ -206,12 +207,26 @@ class SequenceParallelContextManager:

     def __enter__(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
-            # Apply sequence parallelism to kwargs and get original sequence length and padding info
-            kwargs, self.original_seq_len, self.pad_len = (
-                self.apply_sequence_parallelism(batch=kwargs)
+            # Convert all positional args to kwargs using the model's forward
+            # function signature
+            updated_kwargs = kwargs.copy()

+            # Get parameter names from the model's forward function
+            forward_params = list(
+                inspect.signature(self.models[0].forward).parameters.keys()
             )
-            return args, kwargs
+            # Map each positional arg to its parameter name
+            for i, arg in enumerate(args):
+                if i < len(forward_params):
+                    param_name = forward_params[i]
+                    updated_kwargs[param_name] = arg
+
+            # Apply sequence parallelism to the merged kwargs and capture the
+            # original sequence length and padding info
+            updated_kwargs, self.original_seq_len, self.pad_len = (
+                self.apply_sequence_parallelism(updated_kwargs)
+            )
+
+            # All positional args were folded into kwargs, so return empty args
+            return (), updated_kwargs

         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
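
For reference, the args -> kwargs normalization added to `sequence_parallel_pre_hook`
can be exercised in isolation. A minimal sketch, assuming a hypothetical `DummyModel`
(the module, argument names, and tensor shapes below are illustrative and not part of
this patch):

    import inspect

    import torch
    from torch import nn


    class DummyModel(nn.Module):
        """Hypothetical stand-in for the wrapped model."""

        def forward(self, input_ids, attention_mask=None, position_ids=None):
            return input_ids


    model = DummyModel()

    # Positional and keyword arguments as a forward pre-hook would receive them.
    args = (torch.ones(1, 8, dtype=torch.long),)
    kwargs = {"attention_mask": torch.ones(1, 8, dtype=torch.long)}

    # Same mapping logic as the pre-hook: each positional arg is matched to a
    # parameter name from the forward signature, so sequence parallelism can
    # operate on a single kwargs dict regardless of how the model was called.
    updated_kwargs = kwargs.copy()
    forward_params = list(inspect.signature(model.forward).parameters.keys())
    for i, arg in enumerate(args):
        if i < len(forward_params):
            updated_kwargs[forward_params[i]] = arg

    assert set(updated_kwargs) == {"input_ids", "attention_mask"}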