diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ad1502c8c..61eea48d8 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1126,23 +1126,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
-        dpo_trainer_kwargs = {}
+        trainer_kwargs = {}
         if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
-                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+                trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
-            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+            trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            dpo_trainer_kwargs["peft_config"] = self.peft_config
+            trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+            trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
             )
         if self.cfg.rl is RLType.GRPO:
             trainer_cls = GRPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model]
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
-            dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
+            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             trainer_cls = DPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model, self.model_ref]
@@ -1160,33 +1160,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         sig = inspect.signature(trainer_cls)
         if "tokenizer" in sig.parameters.keys():
-            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+            trainer_kwargs["tokenizer"] = self.tokenizer
         else:
-            dpo_trainer_kwargs["processing_class"] = self.tokenizer
+            trainer_kwargs["processing_class"] = self.tokenizer
         if self.cfg.datasets is not None and (
             trainer_cls is DPOStrategy.get_trainer_class()
         ):
-            dpo_trainer_kwargs["dataset_tags"] = [
+            trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
-        dpo_trainer = trainer_cls(
+        trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
             train_dataset=self.train_dataset,
             callbacks=self.get_callbacks(),
-            **dpo_trainer_kwargs,
+            **trainer_kwargs,
         )

         if self.cfg.fsdp:
-            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
-            if self.cfg.rl in [RLType.DPO, RLType.IPO] and dpo_trainer.ref_model:
-                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
+            ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
+                ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)

-        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
-        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
-            dpo_trainer.add_callback(callback)
+        trainer = self.hook_post_create_trainer(trainer)
+        for callback in self.get_post_trainer_create_callbacks(trainer):
+            trainer.add_callback(callback)

-        return dpo_trainer
+        return trainer


 class HFPPOTrainerBuilder(TrainerBuilderBase):
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 1b7e36bce..d672770e9 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -1,6 +1,7 @@
 """Axolotl GRPO trainer"""

 import warnings
+from collections import defaultdict
 from contextlib import nullcontext
 from typing import Any

@@ -12,15 +13,29 @@ from accelerate.utils import (
     gather,
     gather_object,
     is_peft_model,
+    set_seed,
 )
+from datasets import Dataset, IterableDataset
 from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
     Sampler,
 )
-from transformers import Trainer, is_wandb_available
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    Trainer,
+    TrainerCallback,
+    is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import seed_worker
+from transformers.utils import is_peft_available
 from trl import GRPOTrainer
 from trl.data_utils import (
     apply_chat_template,
@@ -28,8 +43,20 @@ from trl.data_utils import (
     maybe_apply_chat_template,
 )
 from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.import_utils import is_deepspeed_available, is_rich_available
-from trl.models import unwrap_model_for_generation
+from trl.extras.vllm_client import VLLMClient
+from trl.import_utils import (
+    is_deepspeed_available,
+    is_rich_available,
+    is_vllm_available,
+)
+from trl.models import (
+    create_reference_model,
+    prepare_deepspeed,
+    unwrap_model_for_generation,
+)
+from trl.trainer.callbacks import SyncRefModelCallback
+from trl.trainer.grpo_config import GRPOConfig
+from trl.trainer.grpo_trainer import RewardFunc
 from trl.trainer.utils import (
     pad,
     print_prompt_completions_sample,
@@ -40,6 +67,9 @@ from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampl
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group

+if is_peft_available():
+    from peft import PeftConfig, get_peft_model
+
 if is_deepspeed_available():
     import deepspeed

@@ -52,9 +82,341 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):

     _tag_names = ["trl", "grpo", "axolotl"]

-    def __init__(self, *args, **kwargs):
-        # Call parent constructor with all arguments
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        model: str | PreTrainedModel,
+        reward_funcs: RewardFunc | list[RewardFunc],
+        args: GRPOConfig | None = None,
+        train_dataset: Dataset | IterableDataset | None = None,
+        eval_dataset: (
+            Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
+        ) = None,
+        processing_class: PreTrainedTokenizerBase | None = None,
+        reward_processing_classes: (
+            PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
+        ) = None,
+        callbacks: list[TrainerCallback] | None = None,
+        optimizers: tuple[
+            torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
+        ] = (None, None),
+        peft_config: "PeftConfig | None" = None,
+    ):
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = GRPOConfig(f"{model_name}-GRPO")
+
+        # Models
+        # Trained model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            torch_dtype = model_init_kwargs.get("torch_dtype")
+            if (
+                isinstance(torch_dtype, torch.dtype)
+                or torch_dtype == "auto"
+                or torch_dtype is None
+            ):
+                pass  # torch_dtype is already a torch.dtype or "auto" or None
+            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                torch_dtype = getattr(torch, torch_dtype)
+                model_init_kwargs["torch_dtype"] = torch_dtype
+            else:
+                raise ValueError(
+                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                )
+            # Disable caching if gradient checkpointing is enabled (not supported)
+            model_init_kwargs["use_cache"] = (
+                False
+                if args.gradient_checkpointing
+                else model_init_kwargs.get("use_cache")
+            )
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                raise ValueError(
+                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                    "This argument can only be used when the `model` argument is a string."
+                )
+
+        if peft_config is not None:
+            if not is_peft_available():
+                raise ImportError(
+                    "PEFT is required to use `peft_config`. Run `pip install peft`."
+                )
+            model = get_peft_model(model, peft_config)
+
+        # Enable gradient checkpointing if requested
+        if args.gradient_checkpointing:
+            model = self._enable_gradient_checkpointing(model, args)
+
+        # Reference model
+        self.beta = args.beta
+        if self.beta == 0.0:
+            # If beta is 0.0, the reference model is not needed
+            self.ref_model = None
+        elif is_deepspeed_zero3_enabled():
+            self.ref_model = AutoModelForCausalLM.from_pretrained(
+                model_id, **model_init_kwargs
+            )
+        elif is_peft_model(model):
+            # If PEFT is used, the reference model is not needed since the adapter can be disabled
+            # to revert to the initial model.
+            self.ref_model = None
+        else:
+            # If PEFT configuration is not provided, create a reference model based on the initial model.
+            self.ref_model = create_reference_model(model)
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(
+                model.config._name_or_path, padding_side="left"
+            )
+
+        # Reward functions
+        if not isinstance(reward_funcs, list):
+            reward_funcs = [reward_funcs]
+        for i, reward_func in enumerate(reward_funcs):
+            if isinstance(reward_func, str):
+                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                    reward_func, num_labels=1, **model_init_kwargs
+                )
+        self.reward_funcs = reward_funcs
+
+        # Reward weights
+        if args.reward_weights is not None:
+            if len(args.reward_weights) != len(reward_funcs):
+                raise ValueError(
+                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+                    f"functions ({len(reward_funcs)})"
+                )
+            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+        else:
+            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+        # Reward processing class
+        if reward_processing_classes is None:
+            reward_processing_classes = [None] * len(reward_funcs)
+        elif not isinstance(reward_processing_classes, list):
+            reward_processing_classes = [reward_processing_classes]
+        else:
+            if len(reward_processing_classes) != len(reward_funcs):
+                raise ValueError(
+                    "The number of reward processing classes must match the number of reward functions."
+                )
+
+        for i, (reward_processing_class, reward_func) in enumerate(
+            zip(reward_processing_classes, reward_funcs)
+        ):
+            if isinstance(reward_func, PreTrainedModel):
+                if reward_processing_class is None:
+                    reward_processing_class = AutoTokenizer.from_pretrained(
+                        reward_func.config._name_or_path
+                    )
+                if reward_processing_class.pad_token_id is None:
+                    reward_processing_class.pad_token = (
+                        reward_processing_class.eos_token
+                    )
+                # The reward model computes the reward for the latest non-padded token in the input sequence.
+                # So it's important to set the pad token ID to the padding token ID of the processing class.
+                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                reward_processing_classes[i] = reward_processing_class
+        self.reward_processing_classes = reward_processing_classes
+
+        # Data collator
+        def data_collator(features):  # No data collation is needed in GRPO
+            return features
+
+        # Training arguments
+        self.max_prompt_length = args.max_prompt_length
+        self.max_completion_length = (
+            args.max_completion_length
+        )  # = |o_i| in the GRPO paper
+        self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.temperature = args.temperature
+        self.top_p = args.top_p
+        self.top_k = args.top_k
+        self.min_p = args.min_p
+        self.repetition_penalty = args.repetition_penalty
+        self.use_vllm = args.use_vllm
+
+        # Multi-step
+        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
+        self.epsilon_low = args.epsilon
+        self.epsilon_high = (
+            args.epsilon_high if args.epsilon_high is not None else args.epsilon
+        )
+        # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
+        self._step = 0
+        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
+        # `_get_train_sampler` and `_prepare_inputs`.
+        self._buffered_inputs = [None] * args.gradient_accumulation_steps
+
+        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
+        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+        # This acts as a flag to indicate that the warning has already been issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
+        self.log_completions = args.log_completions
+
+        Trainer.__init__(
+            self,
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            callbacks=callbacks,
+            optimizers=optimizers,
+        )
+
+        # Get number of SP groups (number of processes divided by SP degree)
+        num_processes = self.accelerator.num_processes
+        num_sp_groups = num_processes // self.args.sequence_parallel_degree
+
+        # Calculate batch size per SP group (not per process)
+        sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
+        possible_values = [
+            n_gen
+            for n_gen in range(2, sp_group_batch_size + 1)
+            if (sp_group_batch_size) % n_gen == 0
+        ]
+
+        if self.num_generations not in possible_values:
+            raise ValueError(
+                f"The batch size per SP group ({num_sp_groups} x {self.args.per_device_train_batch_size}) must be evenly "
+                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+                f"configuration, the valid values for the number of generations are: {possible_values}."
+            )
+        if self.args.eval_strategy != "no":
+            # If sequence parallelism is enabled, calculate batch size per SP group
+            sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups
+            possible_values = [
+                n_gen
+                for n_gen in range(2, sp_group_eval_batch_size + 1)
+                if (sp_group_eval_batch_size) % n_gen == 0
+            ]
+
+            if self.num_generations not in possible_values:
+                raise ValueError(
+                    f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
+                    f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
+                    f"must be evenly divisible by the number of generations per prompt "
+                    f"({self.num_generations}). Given the current eval batch size, "
+                    f"the valid values for the number of generations are: {possible_values}."
+                )
+
+        # # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
+        # num_processes = self.accelerator.num_processes
+        # global_batch_size = args.per_device_train_batch_size * num_processes
+        # possible_values = [
+        #     n_gen
+        #     for n_gen in range(2, global_batch_size + 1)
+        #     if (global_batch_size) % n_gen == 0
+        # ]
+        # if self.num_generations not in possible_values:
+        #     raise ValueError(
+        #         f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+        #         f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+        #         f"batch size, the valid values for the number of generations are: {possible_values}."
+        #     )
+        # if self.args.eval_strategy != "no":
+        #     global_batch_size = args.per_device_eval_batch_size * num_processes
+        #     possible_values = [
+        #         n_gen
+        #         for n_gen in range(2, global_batch_size + 1)
+        #         if (global_batch_size) % n_gen == 0
+        #     ]
+        #     if self.num_generations not in possible_values:
+        #         raise ValueError(
+        #             f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+        #             f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+        #             f"eval batch size, the valid values for the number of generations are: {possible_values}."
+        #         )
+
+        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+        # it's safer to set it in all cases.
+        set_seed(args.seed, device_specific=True)
+
+        if self.use_vllm:
+            if not is_vllm_available():
+                raise ImportError(
+                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+                    "`pip install vllm` to use it."
+                )
+
+            if self.accelerator.is_main_process:
+                self.vllm_client = VLLMClient(
+                    args.vllm_server_host,
+                    args.vllm_server_port,
+                    connection_timeout=args.vllm_server_timeout,
+                )
+
+            # vLLM specific sampling arguments
+            self.guided_decoding_regex = args.vllm_guided_decoding_regex
+
+            self._last_loaded_step = (
+                0  # tag to avoid useless loading during grad accumulation
+            )
+
+            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+            # synchronize all processes after vLLM has been fully initialized.
+            self.accelerator.wait_for_everyone()
+        else:
+            self.generation_config = GenerationConfig(
+                max_new_tokens=self.max_completion_length,
+                do_sample=True,
+                pad_token_id=processing_class.pad_token_id,
+                bos_token_id=processing_class.bos_token_id,
+                eos_token_id=processing_class.eos_token_id,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                min_p=self.min_p,
+                repetition_penalty=self.repetition_penalty,
+                cache_implementation=args.cache_implementation,
+            )
+
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
+        # Add tags to the model
+        self.model.add_model_tags(self._tag_names)
+
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(
+                    self.ref_model, evaluation_mode=True
+                )
+
+        if args.sync_ref_model:
+            self.add_callback(
+                SyncRefModelCallback(
+                    ref_model=self.ref_model, accelerator=self.accelerator
+                )
+            )
+
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                self.reward_funcs[i] = self.accelerator.prepare_model(
+                    reward_func, evaluation_mode=True
+                )

         # Initialize the SP group
         self.sp_group = get_ring_attn_group()
@@ -255,6 +617,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]

+        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+        all_prompts_text = gather_object(prompts_text)
+
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
@@ -262,14 +627,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step

-            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
-            all_prompts_text = gather_object(prompts_text)
-
             if self.accelerator.is_main_process:
                 # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
-                ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
+                # ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
+                ordered_set_of_prompts = all_prompts_text[
+                    :: self.num_generations * self.args.sequence_parallel_degree
+                ]
                 with profiling_context(self, "vLLM.generate"):
                     completion_ids = self.vllm_client.generate(
                         prompts=ordered_set_of_prompts,
@@ -297,33 +662,19 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
                 sp_group_start = sp_group_id * len(prompts) * self.local_world_size

                 # All ranks in the same SP group get the same data slice
-                # This ensures identical inputs for sequence-parallel processing
                 process_slice = slice(
                     sp_group_start,
-                    sp_group_start + len(prompts) * self.local_world_size,
+                    sp_group_start + len(prompts),
                 )
-
-                # Take the full SP group's worth of completions
                 completion_ids = completion_ids[process_slice]
             else:
-                # Original behavior for non-sequence-parallel case
+                # Original behavior for non-sequence parallel case
                 process_slice = slice(
                     self.accelerator.process_index * len(prompts),
                     (self.accelerator.process_index + 1) * len(prompts),
                 )
                 completion_ids = completion_ids[process_slice]

-            if dist.get_rank() == 0:
-                import ipdb
-
-                ipdb.set_trace()
-            dist.barrier()
-            if dist.get_rank() == 1:
-                import ipdb
-
-                ipdb.set_trace()
-            dist.barrier()
-
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [
                 torch.tensor(ids, device=device) for ids in completion_ids
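
The sequence-parallel arithmetic behind the new `num_generations` check in `__init__` and the widened vLLM prompt stride above can be sanity-checked in isolation. Below is a minimal standalone sketch with made-up numbers; the variable names mirror the config fields used in the patch, but nothing here imports axolotl or TRL:

# Standalone sketch of the sequence-parallel (SP) batch arithmetic in the patch.
# Illustrative numbers only; not part of the diff above.

num_processes = 8                    # accelerator.num_processes
sequence_parallel_degree = 2         # ring-attention group size
per_device_train_batch_size = 4
num_generations = 8                  # G in the GRPO paper

# Ranks are split into SP groups; every rank in a group sees the same batch.
num_sp_groups = num_processes // sequence_parallel_degree           # -> 4
sp_group_batch_size = per_device_train_batch_size * num_sp_groups   # -> 16

# The constructor requires num_generations to divide the per-SP-group batch.
valid = [g for g in range(2, sp_group_batch_size + 1) if sp_group_batch_size % g == 0]
assert num_generations in valid, f"pick one of {valid}"

# gather_object() yields len(prompts) * num_processes prompt strings. Per the
# patch, each unique prompt appears num_generations times and once more per SP
# rank, so the main process recovers unique prompts with this stride:
gathered_prompts = per_device_train_batch_size * num_processes      # -> 32
stride = num_generations * sequence_parallel_degree                 # -> 16
unique_prompts_per_step = gathered_prompts // stride                # -> 2
print(num_sp_groups, sp_group_batch_size, unique_prompts_per_step)

With these numbers, eight processes at SP degree 2 form four SP groups with a group batch of 16, so num_generations must be a divisor of 16, and the 32 gathered prompt strings collapse to two unique prompts per vLLM generate call.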