From 6810f0ee19e604764b6b8e311f61bbf7ba9e477c Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 23 Apr 2025 19:04:26 +0000
Subject: [PATCH] minimize diffs to GRPO trainer

---
 src/axolotl/core/trainers/grpo/trainer.py | 315 ++--------------------
 1 file changed, 19 insertions(+), 296 deletions(-)

diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index b7e374356..0186baacc 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code
 
 import warnings
-from collections import defaultdict
 from contextlib import nullcontext
 from typing import Any
 
@@ -15,7 +14,6 @@ from accelerate.utils import (
     gather,
     gather_object,
     is_peft_model,
-    set_seed,
 )
 from datasets import Dataset, IterableDataset
 from torch import nn
@@ -25,17 +23,12 @@ from torch.utils.data import (
     Sampler,
 )
 from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Trainer,
     TrainerCallback,
     is_wandb_available,
 )
-from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_peft_available
 from trl import GRPOTrainer
@@ -45,18 +38,13 @@ from trl.data_utils import (
     maybe_apply_chat_template,
 )
 from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.extras.vllm_client import VLLMClient
 from trl.import_utils import (
     is_deepspeed_available,
     is_rich_available,
-    is_vllm_available,
 )
 from trl.models import (
-    create_reference_model,
-    prepare_deepspeed,
     unwrap_model_for_generation,
 )
-from trl.trainer.callbacks import SyncRefModelCallback
 from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc
 from trl.trainer.utils import (
@@ -71,7 +59,7 @@ from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
 
 if is_peft_available():
     # pylint: disable=unused-import
-    from peft import PeftConfig, get_peft_model
+    from peft import PeftConfig
 
 if is_deepspeed_available():
     import deepspeed
@@ -104,191 +92,21 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         ] = (None, None),
         peft_config: "PeftConfig | None" = None,
     ):
-        RngLoaderMixin.__init__(self)
-        SchedulerMixin.__init__(self)
-
-        # Args
-        if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
-            model_name = model_name.split("/")[-1]
-            args = GRPOConfig(f"{model_name}-GRPO")
-
-        # Models
-        # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
-        if isinstance(model, str):
-            model_id = model
-            torch_dtype = model_init_kwargs.get("torch_dtype")
-            if (
-                isinstance(torch_dtype, torch.dtype)
-                or torch_dtype == "auto"
-                or torch_dtype is None
-            ):
-                pass  # torch_dtype is already a torch.dtype or "auto" or None
-            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
-                torch_dtype = getattr(torch, torch_dtype)
-                model_init_kwargs["torch_dtype"] = torch_dtype
-            else:
-                raise ValueError(
-                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
-                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
-                )
-            # Disable caching if gradient checkpointing is enabled (not supported)
-            model_init_kwargs["use_cache"] = (
-                False
-                if args.gradient_checkpointing
-                else model_init_kwargs.get("use_cache")
-            )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-        else:
-            model_id = model.config._name_or_path
-            if args.model_init_kwargs is not None:
-                raise ValueError(
-                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
-                    "This argument can only be used when the `model` argument is a string."
-                )
-
-        if peft_config is not None:
-            if not is_peft_available():
-                raise ImportError(
-                    "PEFT is required to use `peft_config`. Run `pip install peft`."
-                )
-            model = get_peft_model(model, peft_config)
-
-        # Enable gradient checkpointing if requested
-        if args.gradient_checkpointing:
-            model = self._enable_gradient_checkpointing(model, args)
-
-        # Reference model
-        self.beta = args.beta
-        if self.beta == 0.0:
-            # If beta is 0.0, the reference model is not needed
-            self.ref_model = None
-        elif is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(
-                model_id, **model_init_kwargs
-            )
-        elif is_peft_model(model):
-            # If PEFT is used, the reference model is not needed since the adapter can be disabled
-            # to revert to the initial model.
-            self.ref_model = None
-        else:
-            # If PEFT configuration is not provided, create a reference model based on the initial model.
-            self.ref_model = create_reference_model(model)
-
-        # Processing class
-        if processing_class is None:
-            processing_class = AutoTokenizer.from_pretrained(
-                model.config._name_or_path, padding_side="left"
-            )
-
-        # Reward functions
-        if not isinstance(reward_funcs, list):
-            reward_funcs = [reward_funcs]
-        for i, reward_func in enumerate(reward_funcs):
-            if isinstance(reward_func, str):
-                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
-                    reward_func, num_labels=1, **model_init_kwargs
-                )
-        self.reward_funcs = reward_funcs
-
-        # Reward weights
-        if args.reward_weights is not None:
-            if len(args.reward_weights) != len(reward_funcs):
-                raise ValueError(
-                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
-                    f"functions ({len(reward_funcs)})"
-                )
-            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
-        else:
-            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
-
-        # Reward processing class
-        if reward_processing_classes is None:
-            reward_processing_classes = [None] * len(reward_funcs)
-        elif not isinstance(reward_processing_classes, list):
-            reward_processing_classes = [reward_processing_classes]
-        else:
-            if len(reward_processing_classes) != len(reward_funcs):
-                raise ValueError(
-                    "The number of reward processing classes must match the number of reward functions."
-                )
-
-        for i, (reward_processing_class, reward_func) in enumerate(
-            zip(reward_processing_classes, reward_funcs)
-        ):
-            if isinstance(reward_func, PreTrainedModel):
-                if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(
-                        reward_func.config._name_or_path
-                    )
-                if reward_processing_class.pad_token_id is None:
-                    reward_processing_class.pad_token = (
-                        reward_processing_class.eos_token
-                    )
-                # The reward model computes the reward for the latest non-padded token in the input sequence.
-                # So it's important to set the pad token ID to the padding token ID of the processing class.
-                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
-                reward_processing_classes[i] = reward_processing_class
-        self.reward_processing_classes = reward_processing_classes
-
-        # Data collator
-        def data_collator(features):  # No data collation is needed in GRPO
-            return features
-
-        # Training arguments
-        self.max_prompt_length = args.max_prompt_length
-        self.max_completion_length = (
-            args.max_completion_length
-        )  # = |o_i| in the GRPO paper
-        self.num_generations = args.num_generations  # = G in the GRPO paper
-        self.temperature = args.temperature
-        self.top_p = args.top_p
-        self.top_k = args.top_k
-        self.min_p = args.min_p
-        self.repetition_penalty = args.repetition_penalty
-        self.use_vllm = args.use_vllm
-
-        # Multi-step
-        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
-        self.epsilon_low = args.epsilon
-        self.epsilon_high = (
-            args.epsilon_high if args.epsilon_high is not None else args.epsilon
-        )
-        # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
-        self._step = 0
-        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
-        # `_get_train_sampler` and `_prepare_inputs`.
-        self._buffered_inputs = [None] * args.gradient_accumulation_steps
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
-        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
-        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
-        # This acts as a flag to indicate that the warning has already been issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        # Initialize the metrics
-        self._metrics: dict[str, dict[str, list]] = {
-            "train": defaultdict(list),
-            "eval": defaultdict(list),
-        }
-        self._total_train_tokens = 0
-        self.log_completions = args.log_completions
-
-        Trainer.__init__(
-            self,
+        # Call the superclass constructor with all shared arguments
+        super().__init__(
             model=model,
+            reward_funcs=reward_funcs,
             args=args,
-            data_collator=data_collator,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
+            reward_processing_classes=reward_processing_classes,
             callbacks=callbacks,
             optimizers=optimizers,
+            peft_config=peft_config,
         )
+
+        # Axolotl-specific setup: sequence parallelism (SP) handling
         # Get number of SP groups (number of processes divided by SP degree)
         num_processes = self.accelerator.num_processes
         num_sp_groups = num_processes // self.args.sequence_parallel_degree
@@ -303,13 +121,16 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
 
         if self.num_generations not in possible_values:
             raise ValueError(
-                f"The batch size per SP group ({num_sp_groups} x {self.args.per_device_train_batch_size}) must be evenly "
-                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
-                f"configuration, the valid values for the number of generations are: {possible_values}."
+                f"The batch size per SP group ({num_sp_groups} x "
+                f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
+                f"the number of generations per prompt ({self.num_generations}). Given "
+                "the current configuration, the valid values for the number of "
+                f"generations are: {possible_values}."
             )
+
         if self.args.eval_strategy != "no":
             # If sequence parallelism is enabled, calculate batch size per SP group
-            sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups
+            sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups  # type: ignore[union-attr]
             possible_values = [
                 n_gen
                 for n_gen in range(2, sp_group_eval_batch_size + 1)
@@ -325,108 +146,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
                     f"the valid values for the number of generations are: {possible_values}."
                 )
 
-        # # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
-        # num_processes = self.accelerator.num_processes
-        # global_batch_size = args.per_device_train_batch_size * num_processes
-        # possible_values = [
-        #     n_gen
-        #     for n_gen in range(2, global_batch_size + 1)
-        #     if (global_batch_size) % n_gen == 0
-        # ]
-        # if self.num_generations not in possible_values:
-        #     raise ValueError(
-        #         f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
-        #         f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
-        #         f"batch size, the valid values for the number of generations are: {possible_values}."
-        #     )
-        # if self.args.eval_strategy != "no":
-        #     global_batch_size = args.per_device_eval_batch_size * num_processes
-        #     possible_values = [
-        #         n_gen
-        #         for n_gen in range(2, global_batch_size + 1)
-        #         if (global_batch_size) % n_gen == 0
-        #     ]
-        #     if self.num_generations not in possible_values:
-        #         raise ValueError(
-        #             f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
-        #             f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
-        #             f"eval batch size, the valid values for the number of generations are: {possible_values}."
-        #         )
-
-        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
-        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
-        # it's safer to set it in all cases.
-        set_seed(args.seed, device_specific=True)
-
-        if self.use_vllm:
-            if not is_vllm_available():
-                raise ImportError(
-                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install vllm` to use it."
-                )
-
-            if self.accelerator.is_main_process:
-                self.vllm_client = VLLMClient(
-                    args.vllm_server_host,
-                    args.vllm_server_port,
-                    connection_timeout=args.vllm_server_timeout,
-                )
-
-            # vLLM specific sampling arguments
-            self.guided_decoding_regex = args.vllm_guided_decoding_regex
-
-            self._last_loaded_step = (
-                0  # tag to avoid useless loading during grad accumulation
-            )
-
-            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
-            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
-            # synchronize all processes after vLLM has been fully initialized.
-            self.accelerator.wait_for_everyone()
-        else:
-            self.generation_config = GenerationConfig(
-                max_new_tokens=self.max_completion_length,
-                do_sample=True,
-                pad_token_id=processing_class.pad_token_id,
-                bos_token_id=processing_class.bos_token_id,
-                eos_token_id=processing_class.eos_token_id,
-                temperature=self.temperature,
-                top_p=self.top_p,
-                top_k=self.top_k,
-                min_p=self.min_p,
-                repetition_penalty=self.repetition_penalty,
-                cache_implementation=args.cache_implementation,
-            )
-
-        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
-        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
-        # self.model_accepts_loss_kwargs to False to enable scaling.
-        self.model_accepts_loss_kwargs = False
-
-        # Add tags to the model
-        self.model.add_model_tags(self._tag_names)
-
-        if self.ref_model is not None:
-            if self.is_deepspeed_enabled:
-                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
-            else:
-                self.ref_model = self.accelerator.prepare_model(
-                    self.ref_model, evaluation_mode=True
-                )
-
-        if args.sync_ref_model:
-            self.add_callback(
-                SyncRefModelCallback(
-                    ref_model=self.ref_model, accelerator=self.accelerator
-                )
-            )
-
-        for i, reward_func in enumerate(self.reward_funcs):
-            if isinstance(reward_func, PreTrainedModel):
-                self.reward_funcs[i] = self.accelerator.prepare_model(
-                    reward_func, evaluation_mode=True
-                )
-
         # Initialize the SP group
         self.sp_group = get_ring_attn_group()
         self.local_rank = dist.get_rank(group=self.sp_group)
@@ -631,8 +350,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
-            if self.state.global_step != self._last_loaded_step:
+            # pylint: disable=access-member-before-definition
+            if self.state.global_step != self._last_loaded_step:  # type: ignore[has-type]
                 self._move_model_to_vllm()
+                # pylint: disable=attribute-defined-outside-init
                 self._last_loaded_step = self.state.global_step
 
             all_prompts_text = gather_object(prompts_text)
@@ -914,9 +635,11 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         mode = "eval" if self.control.should_evaluate else "train"
 
         if mode == "train":
+            # pylint: disable=no-member
            self._total_train_tokens += (
                 self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
             )
+        # pylint: disable=no-member
         self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
         completion_length = (