From b13b6e185f71b2800dc61708c3e5ef78ea5f5856 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 17 Apr 2025 04:06:18 +0000 Subject: [PATCH] stronger subclassing of TRL GRPO trainer; custom distributed sampler --- src/axolotl/common/datasets.py | 3 +- src/axolotl/core/trainer_builder.py | 26 +- src/axolotl/core/trainers/dpo/__init__.py | 3 +- src/axolotl/core/trainers/grpo/args.py | 4 +- src/axolotl/core/trainers/grpo/sampler.py | 124 ++++ src/axolotl/core/trainers/grpo/trainer.py | 669 ++++++++++++++++-- src/axolotl/core/training_args.py | 2 +- .../attention/ring_attn/__init__.py | 1 - .../attention/ring_attn/adapters/batch.py | 2 +- .../monkeypatch/attention/ring_attn/patch.py | 14 +- src/axolotl/train.py | 3 +- src/axolotl/utils/data/rl.py | 15 +- src/axolotl/utils/models.py | 3 +- src/axolotl/utils/schemas/config.py | 6 +- src/axolotl/utils/schemas/enums.py | 23 +- 15 files changed, 778 insertions(+), 120 deletions(-) create mode 100644 src/axolotl/core/trainers/grpo/sampler.py diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 3e712f772..397617159 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer +from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels LOG = logging.getLogger(__name__) @@ -125,7 +126,7 @@ def load_preference_datasets( total_num_steps: Optional[int] = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) - if cfg.rl == "grpo": + if cfg.rl is RLType.GRPO: total_num_steps = None if cli_args.debug or cfg.debug: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7527069ee..ad1502c8c 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -84,7 +84,7 @@ from axolotl.utils.collators import ( ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype -from axolotl.utils.schemas.enums import CustomSupportedOptimizers +from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType try: import torch._dynamo # pylint: disable=ungrouped-imports @@ -1054,7 +1054,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = None blocklist_args_kwargs = [] - if self.cfg.rl == "simpo": + if self.cfg.rl is RLType.SIMPO: training_args_cls = AxolotlCPOConfig training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["max_length"] = self.cfg.sequence_len @@ -1062,13 +1062,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha - elif self.cfg.rl == "orpo": + elif self.cfg.rl is RLType.ORPO: training_args_cls = AxolotlORPOConfig training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - elif self.cfg.rl == "kto": + elif self.cfg.rl is RLType.KTO: training_args_cls = AxolotlKTOConfig training_args_kwargs["desirable_weight"] = ( @@ -1082,14 +1082,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - elif self.cfg.rl == "grpo": + elif self.cfg.rl is RLType.GRPO: training_args_cls = 
GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() else: training_args_cls = AxolotlDPOConfig - if self.cfg.rl == "ipo": + if self.cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_completion_length"] = None @@ -1127,7 +1127,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def build(self, total_num_steps): training_args = self.build_training_arguments(total_num_steps) dpo_trainer_kwargs = {} - if self.cfg.rl == "ipo": + if self.cfg.rl is RLType.IPO: if self.cfg.dpo_label_smoothing: dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing if self.eval_dataset: @@ -1138,21 +1138,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs ) - if self.cfg.rl == "grpo": + if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class() trainer_cls_args = [self.model] trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) - elif self.cfg.rl in ["dpo", "ipo"]: + elif self.cfg.rl in [RLType.DPO, RLType.IPO]: trainer_cls = DPOStrategy.get_trainer_class() trainer_cls_args = [self.model, self.model_ref] - elif self.cfg.rl == "orpo": + elif self.cfg.rl is RLType.ORPO: trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] - elif self.cfg.rl in ["kto"]: + elif self.cfg.rl is RLType.KTO: trainer_cls = AxolotlKTOTrainer trainer_cls_args = [self.model] - elif self.cfg.rl in ["simpo"]: + elif self.cfg.rl is RLType.SIMPO: trainer_cls = AxolotlCPOTrainer trainer_cls_args = [self.model] else: @@ -1179,7 +1179,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): ) if self.cfg.fsdp: ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) - if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model: + if self.cfg.rl in [RLType.DPO, RLType.IPO] and dpo_trainer.ref_model: ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype) dpo_trainer = self.hook_post_create_trainer(dpo_trainer) diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 2d6835cf7..64f7b0c0c 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -3,6 +3,7 @@ DPO Specific Strategy for training """ from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer +from axolotl.utils.schemas.enums import RLType class DPOStrategy: @@ -23,7 +24,7 @@ class DPOStrategy: @classmethod def set_training_args_kwargs(cls, cfg): training_args_kwargs = {} - if cfg.rl == "ipo": + if cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["max_length"] = cfg.sequence_len training_args_kwargs["max_completion_length"] = None diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 5460edca9..76be88c89 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins @dataclass class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): - """ - Axolotl GRPO Config for GRPO training - """ + """Axolotl GRPO Config for GRPO training""" diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py new file mode 100644 index 
000000000..b2010e9b7 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -0,0 +1,124 @@ +""" +Repeat random sampler (akin to the one implemented in +https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds +sequence parallelism functionality; i.e., duplicating data across ranks in the same +sequencee parallel group. +""" + +from typing import Optional, Sized + +import torch +from torch.utils.data import Sampler + + +class SequenceParallelRepeatRandomSampler(Sampler): + """ + Sampler for GRPO training with sequence parallelism that ensures: + 1. Ranks in the same sequence parallel group receive identical data + 2. Each index is repeated multiple times for sampling different completions + 3. Entire batches are repeated for reuse in multiple updates + """ + + def __init__( + self, + dataset: Sized, + mini_repeat_count: int, + batch_size: int = 1, + repeat_count: int = 1, + sequence_parallel_degree: int = 1, + world_size: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ): + self.dataset = dataset + self.mini_repeat_count = mini_repeat_count + self.batch_size = batch_size + self.repeat_count = repeat_count + self.shuffle = shuffle + self.seed = seed + self.drop_last = drop_last + self.epoch = 0 + + self.world_size = world_size + self.rank = rank + + # Sequence parallelism parameters + self.sequence_parallel_degree = sequence_parallel_degree + self.num_sp_groups = world_size // sequence_parallel_degree + self.sp_group_id = rank // sequence_parallel_degree + + # Adjust dataset size for distributed sampling + self.num_samples = len(self.dataset) + self.total_size = self.num_samples + + # Calculate effective number of samples per SP group + if ( + self.drop_last + and self.total_size % (self.num_sp_groups * self.batch_size) != 0 + ): + # Drop last incomplete batch if drop_last is True + self.num_samples_per_sp_group = ( + self.total_size // self.batch_size // self.num_sp_groups + ) * self.batch_size + else: + # Round up to include last batch if drop_last is False + self.num_samples_per_sp_group = ( + (self.total_size + self.batch_size * self.num_sp_groups - 1) + // (self.batch_size * self.num_sp_groups) + * self.batch_size + ) + + def __iter__(self): + # Deterministically shuffle based on epoch and seed + if self.shuffle: + # Use same seed for all ranks in the same SP group + g = torch.Generator() + seed_value = self.seed + self.epoch + self.sp_group_id * 10000 + g.manual_seed(seed_value) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible by batch_size + if len(indices) % self.batch_size != 0: + padding = indices[: self.batch_size - len(indices) % self.batch_size] + indices += padding + + # Subsample based on SP group ID + # Each SP group gets distinct batches of data + batch_indices = [] + for i in range(0, len(indices), self.batch_size * self.num_sp_groups): + start_idx = i + self.sp_group_id * self.batch_size + end_idx = min(start_idx + self.batch_size, len(indices)) + if start_idx < len(indices): + for j in range(self.batch_size): + if start_idx + j < end_idx: + batch_indices.append(indices[start_idx + j]) + + # Make sure batch_indices is exactly batch_size * num_batches_per_sp_group + if self.drop_last: + num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size + target_len = self.batch_size * num_batches_per_sp_group + if len(batch_indices) > 
target_len: + batch_indices = batch_indices[:target_len] + + # Apply the GRPO repeat pattern + final_indices = [] + for _ in range(self.repeat_count): + for idx in batch_indices: + for _ in range(self.mini_repeat_count): + final_indices.append(idx) + + return iter(final_indices) + + def __len__(self): + # Total length including all repetitions + return ( + self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count + ) + + def set_epoch(self, epoch): + """Sets the epoch for this sampler""" + self.epoch = epoch diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index ea15088a4..1b7e36bce 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,28 +1,186 @@ """Axolotl GRPO trainer""" +import warnings from contextlib import nullcontext +from typing import Any +import datasets import torch import torch.distributed as dist -from accelerate.utils import is_deepspeed_available, is_peft_model -from trl import GRPOTrainer -from trl.extras.profiling import profiling_decorator -from trl.trainer.utils import selective_log_softmax - -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.attention.ring_attn import ( - get_ring_attn_group, +from accelerate.utils import ( + broadcast_object_list, + gather, + gather_object, + is_peft_model, ) +from torch import nn +from torch.utils.data import ( + BatchSampler, + DataLoader, + Sampler, +) +from transformers import Trainer, is_wandb_available +from transformers.trainer_utils import seed_worker +from trl import GRPOTrainer +from trl.data_utils import ( + apply_chat_template, + is_conversational, + maybe_apply_chat_template, +) +from trl.extras.profiling import profiling_context, profiling_decorator +from trl.import_utils import is_deepspeed_available, is_rich_available +from trl.models import unwrap_model_for_generation +from trl.trainer.utils import ( + pad, + print_prompt_completions_sample, + selective_log_softmax, +) + +from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler +from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group if is_deepspeed_available(): import deepspeed +if is_wandb_available(): + import wandb + class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] + def __init__(self, *args, **kwargs): + # Call parent constructor with all arguments + super().__init__(*args, **kwargs) + + # Initialize the SP group + self.sp_group = get_ring_attn_group() + self.local_rank = dist.get_rank(group=self.sp_group) + self.local_world_size = dist.get_world_size(group=self.sp_group) + + def _get_train_sampler(self) -> Sampler: + # Get distributed training info + world_size = dist.get_world_size() + rank = dist.get_rank() + + effective_batch_size = ( + self.args.per_device_train_batch_size + * world_size + * self.args.gradient_accumulation_steps + ) + + return SequenceParallelRepeatRandomSampler( + dataset=self.train_dataset, + mini_repeat_count=self.num_generations, + batch_size=effective_batch_size + // self.num_generations + // self.args.sequence_parallel_degree, + repeat_count=self.num_iterations, + sequence_parallel_degree=self.args.sequence_parallel_degree, + world_size=world_size, + rank=rank, + shuffle=True, + seed=self.args.seed, + drop_last=True, + ) + + def 
_create_dataloader_params(self, is_eval=False, custom_batch_size=None): + """Create common dataloader parameters for train or eval.""" + batch_size = custom_batch_size or ( + self.args.eval_batch_size if is_eval else self._train_batch_size + ) + + params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + + # Add persistent workers only for training + if not is_eval and hasattr(self.args, "dataloader_persistent_workers"): + params["persistent_workers"] = self.args.dataloader_persistent_workers + + # Add prefetch factor if specified + if self.args.dataloader_prefetch_factor: + params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return params + + def _prepare_dataloader( + self, dataset, sampler, is_eval=False, custom_batch_size=None + ): + """Prepare a dataloader with the given dataset and sampler.""" + # Get base parameters + dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size) + + # Add sampler configuration + if not isinstance(dataset, torch.utils.data.IterableDataset): + if isinstance(sampler, BatchSampler): + # batch_size and batch_sampler are mutually exclusive + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + if not is_eval: + dataloader_params["worker_init_fn"] = seed_worker + + # Create the dataloader + dataloader = DataLoader(dataset, **dataloader_params) + + if self.args.sample_packing and ( + (not is_eval and not self.args.pretraining) + or (is_eval and self.args.eval_sample_packing is not False) + ): + self.accelerator.even_batches = False + + # Return unprepared dataloader if using sequence parallelism + # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation + # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., + # slice each batch along the sequence dimension). 
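+ # Rationale (as implied by the sampler above, added for clarity): ranks within one
+ # SP group must receive identical batches, which `SequenceParallelRepeatRandomSampler`
+ # already guarantees while also sharding distinct batches across SP groups;
+ # `accelerator.prepare_data_loader` would re-shard the batches per rank and break
+ # that invariant, hence the unprepared DataLoader is returned below.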
+ if self.args.sequence_parallel_degree > 1: + return dataloader + + # Otherwise prepare with accelerator + return self.accelerator.prepare_data_loader(dataloader) + + def get_train_dataloader(self) -> DataLoader: + """Get dataloader for training""" + train_dataset = self.train_dataset + data_collator = self.data_collator # type: ignore + + # Initialize SP group attributes if sequence parallelism is enabled + if self.args.sequence_parallel_degree > 1: + self.sp_group = get_ring_attn_group() + self.local_rank = dist.get_rank(group=self.sp_group) + self.local_world_size = dist.get_world_size(group=self.sp_group) + + # Handle dataset preprocessing + if isinstance(train_dataset, datasets.Dataset): + # Add debug print before any modifications + if self.args.sample_packing and not self.args.pretraining: + train_dataset = train_dataset.remove_columns(["length"]) + if not self.args.sample_packing or self.args.pretraining: + train_dataset = self._remove_unused_columns( + train_dataset, description="training" + ) + else: + self.data_collator = self._get_collator_with_removed_columns( + data_collator, + description="training", + ) + + # Get sampler and create dataloader + sampler = self._get_train_sampler() + dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False) + + return dataloader + @profiling_decorator def _move_model_to_vllm(self): # For DeepSpeed ZeRO-3, we need to gather all parameters before operations @@ -70,20 +228,376 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): if self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() + def _generate_and_score_completions( + self, inputs: dict[str | torch.Tensor | Any] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + prompts = [x["prompt"] for x in inputs] + prompts_text = [ + maybe_apply_chat_template(example, self.processing_class)["prompt"] + for example in inputs + ] + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + ) + prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs) + + prompt_ids, prompt_mask = ( + prompt_inputs["input_ids"], + prompt_inputs["attention_mask"], + ) + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. 
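+ # For example, with num_generations=2 the gathered list looks like
+ # [p0, p0, p1, p1, ...]; the stride below keeps one copy of each prompt,
+ # i.e. [p0, p1, ...].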
+ ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + completion_ids = [None] * len(all_prompts_text) + + # Broadcast the completions from the main process to all processes + completion_ids = broadcast_object_list(completion_ids, from_process=0) + + # Determine the appropriate slice based on sequence parallelism + if self.args.sequence_parallel_degree > 1: + # Calculate SP group ID (which group of ranks this rank belongs to) + sp_group_id = self.accelerator.process_index // self.local_world_size + + # Calculate the start index for this SP group + sp_group_start = sp_group_id * len(prompts) * self.local_world_size + + # All ranks in the same SP group get the same data slice + # This ensures identical inputs for sequence-parallel processing + process_slice = slice( + sp_group_start, + sp_group_start + len(prompts) * self.local_world_size, + ) + + # Take the full SP group's worth of completions + completion_ids = completion_ids[process_slice] + else: + # Original behavior for non-sequence-parallel case + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [ + torch.tensor(ids, device=device) for ids in completion_ids + ] + completion_ids = pad( + completion_ids, padding_value=self.processing_class.pad_token_id + ) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate( + prompt_ids, + attention_mask=prompt_mask, + generation_config=self.generation_config, + ) + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full( + (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device + ) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand( + is_eos.size(0), -1 + ) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # Concatenate prompt_mask with completion_mask for logit computation + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + logits_to_keep = completion_ids.size( + 1 + ) # we only need to compute the logits for the completion tokens + + with torch.no_grad(): + # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its + #
computation here, and use per_token_logps.detach() instead. + if self.num_iterations > 1: + old_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, logits_to_keep + ) + else: + old_per_token_logps = None + + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + + # Decode the generated completions + completions_text = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = ( + prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + ) + completions.append( + [{"role": "assistant", "content": bootstrap + completion}] + ) + else: + completions = completions_text + + rewards_per_func = torch.zeros( + len(prompts), len(self.reward_funcs), device=device + ) + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes) + ): + if isinstance( + reward_func, nn.Module + ): # Module instead of PretrainedModel for compat with compiled models + reward_func_name = ( + f"reward {reward_func.config._name_or_path.split('/')[-1]}" + ) + else: + reward_func_name = reward_func.__name__ + with profiling_context(self, reward_func_name): + if isinstance( + reward_func, nn.Module + ): # Module instead of PretrainedModel for compat with compiled models + if is_conversational(inputs[0]): + messages = [ + {"messages": p + c} for p, c in zip(prompts, completions) + ] + texts = [ + apply_chat_template(x, reward_processing_class)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + reward_inputs = Trainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[ + :, 0 + ] # Shape (B*G,) + else: + # Repeat all input columns (but "prompt" and "completion") to match the number of generations + keys = [ + key for key in inputs[0] if key not in ["prompt", "completion"] + ] + reward_kwargs = { + key: [example[key] for example in inputs] for key in keys + } + output_reward_func = reward_func( + prompts=prompts, completions=completions, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [ + reward if reward is not None else torch.nan + for reward in output_reward_func + ] + + rewards_per_func[:, i] = torch.tensor( + output_reward_func, dtype=torch.float32, device=device + ) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = ( + torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + ) + row_reward_kwargs = { + key: value[nan_row_idx] for key, value in reward_kwargs.items() + } + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + warnings.warn( + f"All reward functions returned None for the following kwargs: 
{row_reward_kwargs}. " + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + + # Apply weights to each reward function's output and sum + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave( + self.num_generations, dim=0 + ) + std_grouped_rewards = std_grouped_rewards.repeat_interleave( + self.num_generations, dim=0 + ) + advantages = rewards - mean_grouped_rewards + if self.args.scale_rewards: + advantages = advantages / (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Log the metrics + mode = "eval" if self.control.should_evaluate else "train" + + if mode == "train": + self._total_train_tokens += ( + self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item() + ) + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + completion_length = ( + self.accelerator.gather_for_metrics(completion_mask.sum(1)) + .float() + .mean() + .item() + ) + self._metrics[mode]["completion_length"].append(completion_length) + + # Calculate mean reward per function, but only for samples where the function was applied + for i, reward_func in enumerate(self.reward_funcs): + if isinstance( + reward_func, nn.Module + ): # Module instead of PretrainedModel for compat with compiled models + reward_func_name = reward_func.config._name_or_path.split("/")[-1] + else: + reward_func_name = reward_func.__name__ + # Only calculate mean for samples where this reward function was applied (non-NaN values) + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards) + self._metrics[mode]["reward"].append(rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) + + if ( + self.log_completions + and self.state.global_step % self.args.logging_steps == 0 + ): + prompts_to_log = gather_object(prompts_text) + completions_to_log = gather_object(completions_text) + rewards_to_log = rewards.tolist() + + if self.accelerator.is_main_process: + if is_rich_available(): + print_prompt_completions_sample( + prompts_to_log, + completions_to_log, + rewards_to_log, + self.state.global_step, + ) + if ( + self.args.report_to + and "wandb" in self.args.report_to + and wandb.run is not None + ): + import pandas as pd + + # For logging + table = { + "step": [str(self.state.global_step)] * len(rewards), + "prompt": prompts_to_log, + "completion": completions_to_log, + "reward": rewards.tolist(), + } + df = pd.DataFrame(table) + wandb.log({"completions": wandb.Table(dataframe=df)}) + + return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "old_per_token_logps": old_per_token_logps, + "ref_per_token_logps": ref_per_token_logps, + "advantages": 
advantages, + } + # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if dist.get_rank() == 0: - import ipdb; ipdb.set_trace() - dist.barrier() - - if dist.get_rank() == 1: - import ipdb; ipdb.set_trace() - dist.barrier() - if self.args.sequence_parallel_degree > 1: - sp_group = get_ring_attn_group() - self.local_rank = dist.get_rank(group=sp_group) - self.local_world_size = dist.get_world_size(group=sp_group) + print(f"{self.local_rank}: input_ids.shape: {input_ids.shape}") + print(f"{self.local_rank}: input_ids[0, :20]: {input_ids[0, :20]}") + print(f"{self.local_rank}: input_ids[0, -20:]: {input_ids[0, -20:]}") # Pad sequence if needed total_seq_len = input_ids.shape[1] @@ -123,7 +637,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): # Calculate if this rank contains any tokens we need to keep tokens_before_our_slice = self.local_rank * slice_size print(f"{self.local_rank}: slice_size: {slice_size}") - print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}") + print( + f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}" + ) if tokens_before_our_slice < logits_to_keep: # How many tokens from our slice are needed tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice @@ -132,59 +648,76 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): # This rank doesn't contain any tokens we need to keep logits_to_keep = 0 - print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}") + print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}") - # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits - logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + logits_to_keep=logits_to_keep + 1, + ).logits + logits = logits[ + :, :-1, : + ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - print(f"{self.local_rank}: logits.shape: {logits.shape}") + print(f"{self.local_rank}: logits.shape: {logits.shape}") - # First, let all ranks know the shape of each rank's tensor - local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device) - all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)] - dist.all_gather(all_shapes, local_shape, group=sp_group) - - # Use a list-based approach to collect logits of different sizes - if self.local_rank == 0: - # Root process allocates space for receiving - gathered_logits = [] - for shape in all_shapes: - b, s, v = shape.tolist() - gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)) - else: - gathered_logits = None - - # Gather to rank 0 - dist.gather(logits, gathered_logits, dst=0, group=sp_group) - - # On rank 0, concatenate and distribute the result - if self.local_rank == 0: - concatenated_logits = torch.cat(gathered_logits, dim=1) - # Trim to keep only what we need - if concatenated_logits.shape[1] > logits_to_keep: - concatenated_logits = concatenated_logits[:, -logits_to_keep:, :] - else: - concatenated_logits = torch.zeros( - (logits.shape[0], 
logits_to_keep, logits.shape[2]), - dtype=logits.dtype, - device=logits.device + # First, let all ranks know the shape of each rank's tensor + local_shape = torch.tensor( + [logits.shape[0], logits.shape[1], logits.shape[2]], + device=logits.device, ) - - # Broadcast the result back to all ranks - dist.broadcast(concatenated_logits, src=0, group=sp_group) - logits = concatenated_logits + all_shapes = [ + torch.zeros_like(local_shape) for _ in range(self.local_world_size) + ] + dist.all_gather(all_shapes, local_shape, group=self.sp_group) - input_ids = input_ids[:, -logits_to_keep:] - # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. - # See https://github.com/huggingface/trl/issues/2770 - logits = logits[:, -logits_to_keep:] - # Divide logits by sampling temperature. - # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details - logits = logits / self.temperature + # Use a list-based approach to collect logits of different sizes + if self.local_rank == 0: + # Root process allocates space for receiving + gathered_logits = [] + for shape in all_shapes: + b, s, v = shape.tolist() + gathered_logits.append( + torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device) + ) + else: + gathered_logits = None - dist.barrier() + # Gather to rank 0 + dist.gather(logits, gathered_logits, dst=0, group=self.sp_group) - return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + # On rank 0, concatenate and distribute the result + if self.local_rank == 0: + concatenated_logits = torch.cat(gathered_logits, dim=1) + # Trim to keep only what we need + if concatenated_logits.shape[1] > logits_to_keep: + concatenated_logits = concatenated_logits[:, -logits_to_keep:, :] + else: + concatenated_logits = torch.zeros( + (logits.shape[0], logits_to_keep, logits.shape[2]), + dtype=logits.dtype, + device=logits.device, + ) - # super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + # Broadcast the result back to all ranks + dist.broadcast(concatenated_logits, src=0, group=self.sp_group) + logits = concatenated_logits + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. + # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + # Divide logits by sampling temperature. 
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + + dist.barrier() + + return selective_log_softmax( + logits, input_ids + ) # compute logprobs for the input tokens + else: + return super()._get_per_token_logps( + model, input_ids, attention_mask, logits_to_keep + ) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 3fe32f507..0b14e7661 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,7 +9,7 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig -from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc +from axolotl.utils.schemas.enums import RingAttnFunc @dataclass diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py index 055607e92..a50ad456e 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py @@ -4,7 +4,6 @@ # flake8: noqa from .patch import ( - RingAttnFunc, get_ring_attn_group, register_ring_attn, set_ring_attn_group, diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py index a88c9f6f1..13daf7451 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py @@ -28,7 +28,7 @@ from transformers.modeling_flash_attention_utils import ( ) from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc +from axolotl.utils.schemas.enums import RingAttnFunc RING_ATTN_FUNC_MAPPING = { RingAttnFunc.BATCH_RING: ring_flash_attn_func, diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py index b5587ddca..1087d1605 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py @@ -6,14 +6,13 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch their sequence parallel version of Flash Attention 2.
""" -from enum import Enum - import torch import torch.distributed as dist from accelerate.logging import get_logger from axolotl.logging_config import configure_logging from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.schemas.enums import RingAttnFunc configure_logging() LOG = get_logger(__name__) @@ -43,17 +42,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): RING_ATTN_GROUP = ring_attn_group -class RingAttnFunc(str, Enum): - """Enum class for supported `ring-flash-attn` implementations""" - - # VARLEN_RING = "varlen_ring" - # VARLEN_ZIGZAG = "varlen_zigzag" - VARLEN_LLAMA3 = "varlen_llama3" - BATCH_RING = "batch_ring" - BATCH_ZIGZAG = "batch_zigzag" - BATCH_STRIPE = "batch_stripe" - - def register_ring_attn( sequence_parallel_degree: int, heads_k_stride: int | None, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d116ea4fd..3829dbcef 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -34,6 +34,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer try: @@ -108,7 +109,7 @@ def setup_reference_model( Reference model if needed for RL training, `None` otherwise. """ model_ref = None - if cfg.rl and cfg.rl != "orpo": + if cfg.rl and cfg.rl != RLType.ORPO: if cfg.adapter and not cfg.rl_adapter_ref_model: # use built-in trl autounwrap LOG.debug("Passing model_ref: None to RL trainer") diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 4c7b71292..e1cd4e0e8 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.models import load_tokenizer +from axolotl.utils.schemas.enums import RLType -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) def _get_path(ds_hash, cfg): @@ -80,7 +81,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): def drop_long_rl_seq( sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name ): - if rl in ("dpo", "ipo", "orpo", "simpo"): + if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -100,7 +101,7 @@ def drop_long_rl_seq( len_prompt + len_rejected ) <= sequence_len - if rl == "kto": + if rl is RLType.KTO: if not (sample.get("prompt") and sample.get("completion")): raise ValueError("Prompt and completion keys are required for KTO datasets") @@ -114,7 +115,7 @@ def drop_long_rl_seq( return (len_prompt + len_completion) <= sequence_len - if rl == "grpo": + if rl is RLType.GRPO: return True raise ValueError("Unknown RL type") @@ -137,9 +138,9 @@ def load_prepare_preference_datasets(cfg): if _type: if isinstance(_type, DictDefault): _type = "user_defined.default" - if _cfg.rl == "orpo": + if _cfg.rl is RLType.ORPO: ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) - elif _cfg.rl == "kto": + elif _cfg.rl is RLType.KTO: ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) else: ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) @@ -150,7 +151,7 @@ def load_prepare_preference_datasets(cfg): 
split_datasets[i] = map_dataset( cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) - elif _cfg.rl == "kto": + elif _cfg.rl is RLType.KTO: ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) map_kwargs = {} if isinstance(ds_transform_fn, tuple): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d7105daba..4e9f87d94 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -72,6 +72,7 @@ from axolotl.utils.distributed import ( from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant +from axolotl.utils.schemas.enums import RLType LOG = logging.getLogger(__name__) @@ -1340,7 +1341,7 @@ class ModelLoader: # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config if ( self.cfg.adapter - and self.cfg.rl in ["dpo", "ipo", "kto"] + and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO] and not self.cfg.merge_lora ): _, lora_config = load_lora( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index d3f7ae887..37aefaabc 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -29,7 +29,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters -from axolotl.utils.schemas.enums import ChatTemplate, RLType +from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -261,7 +261,7 @@ class AxolotlInputConfig( sequence_parallel_degree: int | None = None heads_k_stride: int | None = None - ring_attn_func: str | None = None + ring_attn_func: RingAttnFunc | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None @@ -785,7 +785,7 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_simpo_warmup(self): - if self.rl == "simpo" and self.warmup_ratio: + if self.rl is RLType.SIMPO and self.warmup_ratio: raise ValueError( "warmup_ratio is not supported with the simpo trainer. 
Please use `warmup_steps` instead" ) diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index d8d9f2834..942a96591 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -6,12 +6,12 @@ from enum import Enum class RLType(str, Enum): """RL trainer type configuration subset""" - dpo = "dpo" # pylint: disable=invalid-name - grpo = "grpo" # pylint: disable=invalid-name - ipo = "ipo" # pylint: disable=invalid-name - orpo = "orpo" # pylint: disable=invalid-name - kto = "kto" # pylint: disable=invalid-name - simpo = "simpo" # pylint: disable=invalid-name + DPO = "dpo" # pylint: disable=invalid-name + GRPO = "grpo" # pylint: disable=invalid-name + IPO = "ipo" # pylint: disable=invalid-name + ORPO = "orpo" # pylint: disable=invalid-name + KTO = "kto" # pylint: disable=invalid-name + SIMPO = "simpo" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -53,3 +53,14 @@ class CustomSupportedOptimizers(str, Enum): ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name muon = "muon" # pylint: disable=invalid-name + + +class RingAttnFunc(str, Enum): + """Enum class for supported `ring-flash-attn` implementations""" + + # VARLEN_RING = "varlen_ring" + # VARLEN_ZIGZAG = "varlen_zigzag" + VARLEN_LLAMA3 = "varlen_llama3" + BATCH_RING = "batch_ring" + BATCH_ZIGZAG = "batch_zigzag" + BATCH_STRIPE = "batch_stripe"
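
Below is a minimal, illustrative sketch (not part of the patch) of how the SequenceParallelRepeatRandomSampler introduced above behaves. The toy dataset and parameter values are arbitrary; it assumes the patched axolotl package and torch are importable, and that the sampler's constructor arguments are as defined in the new sampler.py.

# sp_sampler_demo.py -- toy check of the sampler's SP-group and repeat behavior
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler

dataset = list(range(16))  # stand-in for a prompt dataset


def indices_for(rank: int) -> list[int]:
    sampler = SequenceParallelRepeatRandomSampler(
        dataset=dataset,
        mini_repeat_count=2,         # e.g. num_generations completions per prompt
        batch_size=4,                # prompts handled by one SP group per step
        repeat_count=1,              # e.g. num_iterations updates per batch
        sequence_parallel_degree=2,  # two ranks share each sequence
        world_size=4,                # four ranks -> two SP groups
        rank=rank,
        shuffle=True,
        seed=0,
        drop_last=True,
    )
    return list(iter(sampler))


# Ranks 0 and 1 form SP group 0, so they must yield identical index streams.
assert indices_for(0) == indices_for(1)

# Every index is emitted mini_repeat_count times back to back, so the trainer
# can sample that many completions for the same prompt.
stream = indices_for(0)
assert all(stream[i] == stream[i + 1] for i in range(0, len(stream), 2))

# Rank 2 belongs to SP group 1 and draws a different, differently shuffled slice.
print("group 0:", stream)
print("group 1:", indices_for(2))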
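A short, hedged illustration of the enum comparisons this patch standardizes on (`cfg.rl is RLType.GRPO` instead of `cfg.rl == "grpo"`): because RLType mixes in str, plain string equality still holds, but identity checks only pass once the config value has been coerced to the enum member, which the pydantic config schema is expected to do.

from axolotl.utils.schemas.enums import RLType

rl_from_yaml = "grpo"                       # what a user writes in their config
assert RLType(rl_from_yaml) is RLType.GRPO  # coercion returns the singleton member
assert RLType.GRPO == rl_from_yaml          # str mixin keeps plain string equality
assert rl_from_yaml is not RLType.GRPO      # a raw string never passes an `is` check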