minimize diffs to GRPO trainer

This commit is contained in:
Dan Saunders
2025-04-23 19:04:26 +00:00
parent 6c65eeaaf7
commit 6810f0ee19

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-lines,duplicate-code
import warnings
from collections import defaultdict
from contextlib import nullcontext
from typing import Any
@@ -15,7 +14,6 @@ from accelerate.utils import (
gather,
gather_object,
is_peft_model,
set_seed,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -25,17 +23,12 @@ from torch.utils.data import (
Sampler,
)
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
@@ -45,18 +38,13 @@ from trl.data_utils import (
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.extras.vllm_client import VLLMClient
from trl.import_utils import (
is_deepspeed_available,
is_rich_available,
is_vllm_available,
)
from trl.models import (
create_reference_model,
prepare_deepspeed,
unwrap_model_for_generation,
)
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc
from trl.trainer.utils import (
@@ -71,7 +59,7 @@ from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig, get_peft_model
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
@@ -104,191 +92,21 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
] = (None, None),
peft_config: "PeftConfig | None" = None,
):
RngLoaderMixin.__init__(self)
SchedulerMixin.__init__(self)
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
if (
isinstance(torch_dtype, torch.dtype)
or torch_dtype == "auto"
or torch_dtype is None
):
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False
if args.gradient_checkpointing
else model_init_kwargs.get("use_cache")
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
)
if peft_config is not None:
if not is_peft_available():
raise ImportError(
"PEFT is required to use `peft_config`. Run `pip install peft`."
)
model = get_peft_model(model, peft_config)
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
model = self._enable_gradient_checkpointing(model, args)
# Reference model
self.beta = args.beta
if self.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(
model_id, **model_init_kwargs
)
elif is_peft_model(model):
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
else:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(
model.config._name_or_path, padding_side="left"
)
# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
)
self.reward_funcs = reward_funcs
# Reward weights
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError(
"The number of reward processing classes must match the number of reward functions."
)
for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs)
):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(
reward_func.config._name_or_path
)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = (
reward_processing_class.eos_token
)
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
# Data collator
def data_collator(features): # No data collation is needed in GRPO
return features
# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = (
args.max_completion_length
) # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.temperature = args.temperature
self.top_p = args.top_p
self.top_k = args.top_k
self.min_p = args.min_p
self.repetition_penalty = args.repetition_penalty
self.use_vllm = args.use_vllm
# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = (
args.epsilon_high if args.epsilon_high is not None else args.epsilon
)
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = [None] * args.gradient_accumulation_steps
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
# This acts as a flag to indicate that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics: dict[str, dict[str, list]] = {
"train": defaultdict(list),
"eval": defaultdict(list),
}
self._total_train_tokens = 0
self.log_completions = args.log_completions
Trainer.__init__(
self,
# First call the superclass constructor with all arguments
super().__init__(
model=model,
reward_funcs=reward_funcs,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=reward_processing_classes,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
# Now execute your custom logic
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
@@ -303,13 +121,16 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x {self.args.per_device_train_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
f"configuration, the valid values for the number of generations are: {possible_values}."
f"The batch size per SP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
f"generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
for n_gen in range(2, sp_group_eval_batch_size + 1)
@@ -325,108 +146,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
f"the valid values for the number of generations are: {possible_values}."
)
# # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
# num_processes = self.accelerator.num_processes
# global_batch_size = args.per_device_train_batch_size * num_processes
# possible_values = [
# n_gen
# for n_gen in range(2, global_batch_size + 1)
# if (global_batch_size) % n_gen == 0
# ]
# if self.num_generations not in possible_values:
# raise ValueError(
# f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
# f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
# f"batch size, the valid values for the number of generations are: {possible_values}."
# )
# if self.args.eval_strategy != "no":
# global_batch_size = args.per_device_eval_batch_size * num_processes
# possible_values = [
# n_gen
# for n_gen in range(2, global_batch_size + 1)
# if (global_batch_size) % n_gen == 0
# ]
# if self.num_generations not in possible_values:
# raise ValueError(
# f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
# f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
# f"eval batch size, the valid values for the number of generations are: {possible_values}."
# )
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
# it's safer to set it in all cases.
set_seed(args.seed, device_specific=True)
if self.use_vllm:
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)
if self.accelerator.is_main_process:
self.vllm_client = VLLMClient(
args.vllm_server_host,
args.vllm_server_port,
connection_timeout=args.vllm_server_timeout,
)
# vLLM specific sampling arguments
self.guided_decoding_regex = args.vllm_guided_decoding_regex
self._last_loaded_step = (
0 # tag to avoid useless loading during grad accumulation
)
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
pad_token_id=processing_class.pad_token_id,
bos_token_id=processing_class.bos_token_id,
eos_token_id=processing_class.eos_token_id,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
repetition_penalty=self.repetition_penalty,
cache_implementation=args.cache_implementation,
)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags to the model
self.model.add_model_tags(self._tag_names)
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(
self.ref_model, evaluation_mode=True
)
if args.sync_ref_model:
self.add_callback(
SyncRefModelCallback(
ref_model=self.ref_model, accelerator=self.accelerator
)
)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(
reward_func, evaluation_mode=True
)
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
@@ -631,8 +350,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step
all_prompts_text = gather_object(prompts_text)
@@ -914,9 +635,11 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
# pylint: disable=no-member
self._total_train_tokens += (
self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
)
# pylint: disable=no-member
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
completion_length = (