diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 4347b863a..1fa0899b7 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -28,8 +28,14 @@ class GRPOStrategy: @classmethod def get_trainer_class( - cls, sequence_parallel: bool, async_grpo: bool = False, - ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer] | type[AxolotlAsyncGRPOTrainer]: + cls, + sequence_parallel: bool, + async_grpo: bool = False, + ) -> ( + type[AxolotlGRPOTrainer] + | type[AxolotlGRPOSequenceParallelTrainer] + | type[AxolotlAsyncGRPOTrainer] + ): if sequence_parallel: return AxolotlGRPOSequenceParallelTrainer if async_grpo: @@ -37,7 +43,9 @@ class GRPOStrategy: return AxolotlGRPOTrainer @classmethod - def get_training_args_class(cls, async_grpo: bool = False) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]: + def get_training_args_class( + cls, async_grpo: bool = False + ) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]: if async_grpo: return AxolotlAsyncGRPOConfig return AxolotlGRPOConfig @@ -150,16 +158,40 @@ class GRPOStrategy: if getattr(trl, "streaming_min_groups", None) is not None: grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups if getattr(trl, "vllm_importance_sampling_correction", None) is not None: - grpo_args_kwargs["vllm_importance_sampling_correction"] = trl.vllm_importance_sampling_correction + grpo_args_kwargs["vllm_importance_sampling_correction"] = ( + trl.vllm_importance_sampling_correction + ) if getattr(trl, "vllm_importance_sampling_mode", None) is not None: - grpo_args_kwargs["vllm_importance_sampling_mode"] = trl.vllm_importance_sampling_mode + grpo_args_kwargs["vllm_importance_sampling_mode"] = ( + trl.vllm_importance_sampling_mode + ) if getattr(trl, "vllm_importance_sampling_cap", None) is not None: - grpo_args_kwargs["vllm_importance_sampling_cap"] = trl.vllm_importance_sampling_cap + 
grpo_args_kwargs["vllm_importance_sampling_cap"] = ( + trl.vllm_importance_sampling_cap + ) if getattr(trl, "off_policy_mask_threshold", None) is not None: - grpo_args_kwargs["off_policy_mask_threshold"] = trl.off_policy_mask_threshold + grpo_args_kwargs["off_policy_mask_threshold"] = ( + trl.off_policy_mask_threshold + ) if getattr(trl, "use_bias_correction_kl", None) is not None: grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl + # Fast Async GRPO fields + if getattr(trl, "reward_num_workers", None) is not None: + grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers + if getattr(trl, "replay_buffer_size", None) is not None: + grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size + if getattr(trl, "replay_recompute_logps", None) is not None: + grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps + if getattr(trl, "reroll_start_fraction", None) is not None: + grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction + if getattr(trl, "reroll_max_groups", None) is not None: + grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups + if getattr(trl, "skip_zero_advantage_batches", None) is not None: + grpo_args_kwargs["skip_zero_advantage_batches"] = ( + trl.skip_zero_advantage_batches + ) + return grpo_args_kwargs @classmethod diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index c6d276f1e..f1dd5a6e7 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -6,8 +6,8 @@ from dataclasses import dataclass from trl import GRPOConfig +from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig from axolotl.core.training_args import AxolotlTrainingMixins -from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOConfig @dataclass @@ -18,7 +18,7 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): @dataclass -class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, AsyncGRPOConfig): 
+class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig): """Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction.""" context_parallel_size: int | None = None diff --git a/src/axolotl/monkeypatch/trainer/async_grpo.py b/src/axolotl/core/trainers/grpo/async_trainer.py similarity index 76% rename from src/axolotl/monkeypatch/trainer/async_grpo.py rename to src/axolotl/core/trainers/grpo/async_trainer.py index 1bc63b474..169e7de9f 100644 --- a/src/axolotl/monkeypatch/trainer/async_grpo.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -27,21 +27,20 @@ import queue import threading import time from abc import ABC, abstractmethod -from collections import defaultdict, deque +from collections import deque from contextlib import nullcontext from dataclasses import dataclass, field from typing import Any import torch from torch.utils.data import DataLoader, Dataset - from trl.trainer import GRPOConfig, GRPOTrainer from trl.trainer.utils import ( + RepeatSampler, nanmax, nanmin, nanstd, pad, - RepeatSampler, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, @@ -49,7 +48,11 @@ from trl.trainer.utils import ( ) try: - from trl.data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages + from trl.data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, + ) except ImportError: from trl.chat_template_utils import apply_chat_template from trl.data_utils import is_conversational, prepare_multimodal_messages @@ -63,6 +66,7 @@ except ImportError: def disable_gradient_checkpointing(model, kwargs): yield + try: from accelerate.utils import gather_object except ImportError: @@ -80,6 +84,7 @@ except ImportError: # Config # --------------------------------------------------------------------------- + @dataclass class AsyncGRPOConfig(GRPOConfig): """GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields. 
@@ -92,13 +97,17 @@ class AsyncGRPOConfig(GRPOConfig): # --- Data producer --- use_data_producer: bool = field( default=False, - metadata={"help": "Use the GRPODataProducer protocol for online data generation."}, + metadata={ + "help": "Use the GRPODataProducer protocol for online data generation." + }, ) # --- Async data production --- async_prefetch: bool = field( default=False, - metadata={"help": "Generate rollouts in a background thread while training on the previous rollout."}, + metadata={ + "help": "Generate rollouts in a background thread while training on the previous rollout." + }, ) prefetch_depth: int = field( default=1, @@ -106,13 +115,17 @@ class AsyncGRPOConfig(GRPOConfig): ) vllm_sync_interval: int = field( default=1, - metadata={"help": "Sync model weights to vLLM every N optimizer steps (async mode only)."}, + metadata={ + "help": "Sync model weights to vLLM every N optimizer steps (async mode only)." + }, ) # --- Streaming scoring --- streaming_partial_batch: bool = field( default=False, - metadata={"help": "Score prompt groups incrementally instead of the full batch at once."}, + metadata={ + "help": "Score prompt groups incrementally instead of the full batch at once." + }, ) streaming_min_groups: int = field( default=1, @@ -122,11 +135,15 @@ class AsyncGRPOConfig(GRPOConfig): # --- vLLM importance sampling correction --- vllm_importance_sampling_correction: bool = field( default=True, - metadata={"help": "Apply IS correction for distribution mismatch between vLLM and training model."}, + metadata={ + "help": "Apply IS correction for distribution mismatch between vLLM and training model." + }, ) vllm_importance_sampling_mode: str = field( default="token_truncate", - metadata={"help": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."}, + metadata={ + "help": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask." 
+ }, ) vllm_importance_sampling_cap: float = field( default=3.0, @@ -185,15 +202,21 @@ class ProducerConfig: if self.mini_epochs < 1: raise ValueError(f"mini_epochs must be >= 1, got {self.mini_epochs}") if self.max_rollouts is not None and self.max_rollouts < 1: - raise ValueError(f"max_rollouts must be >= 1 or None, got {self.max_rollouts}") + raise ValueError( + f"max_rollouts must be >= 1 or None, got {self.max_rollouts}" + ) if self.num_iterations < 1: raise ValueError(f"num_iterations must be >= 1, got {self.num_iterations}") if self.steps_per_generation is not None and self.steps_per_generation < 1: - raise ValueError(f"steps_per_generation must be >= 1 or None, got {self.steps_per_generation}") + raise ValueError( + f"steps_per_generation must be >= 1 or None, got {self.steps_per_generation}" + ) if self.prefetch_depth < 1: raise ValueError(f"prefetch_depth must be >= 1, got {self.prefetch_depth}") if self.sync_warmup_rollouts < 0: - raise ValueError(f"sync_warmup_rollouts must be >= 0, got {self.sync_warmup_rollouts}") + raise ValueError( + f"sync_warmup_rollouts must be >= 0, got {self.sync_warmup_rollouts}" + ) class DataProducer(ABC): @@ -240,7 +263,9 @@ class AsyncDataProducer: datasets in a background thread. 
""" - def __init__(self, inner: DataProducer, background_produce_kwargs: dict | None = None): + def __init__( + self, inner: DataProducer, background_produce_kwargs: dict | None = None + ): self._inner = inner self._depth = inner.config.prefetch_depth self._warmup_remaining = inner.config.sync_warmup_rollouts @@ -270,7 +295,9 @@ class AsyncDataProducer: bg_kwargs = {**kwargs, **self._background_kwargs} for i in range(1, self._depth + 1): self._queue.append( - self._executor.submit(self._inner.produce, model, global_step + i, **bg_kwargs) + self._executor.submit( + self._inner.produce, model, global_step + i, **bg_kwargs + ) ) self._initialized = True return dataset @@ -302,6 +329,7 @@ class AsyncDataProducer: class DataProducerCallback: """Marker class: if a DataProducer also inherits from this, the Trainer will automatically register it as a callback.""" + pass @@ -414,6 +442,7 @@ class GRPODataProducer(BaseDataProducer): def _init_prompt_dataloader(self) -> None: from functools import partial + from transformers.trainer_utils import seed_worker trainer = self._trainer @@ -498,6 +527,7 @@ class GRPODataProducer(BaseDataProducer): # Trainer # --------------------------------------------------------------------------- + class AsyncGRPOTrainer(GRPOTrainer): """GRPOTrainer with async prefetch, streaming scoring, and IS correction. @@ -510,8 +540,16 @@ class AsyncGRPOTrainer(GRPOTrainer): # Ensure custom attributes exist (stock GRPOTrainer.__init__ may not set them). 
for attr, cfg_key, default in [ - ("vllm_importance_sampling_correction", "vllm_importance_sampling_correction", True), - ("vllm_importance_sampling_mode", "vllm_importance_sampling_mode", "token_truncate"), + ( + "vllm_importance_sampling_correction", + "vllm_importance_sampling_correction", + True, + ), + ( + "vllm_importance_sampling_mode", + "vllm_importance_sampling_mode", + "token_truncate", + ), ("vllm_importance_sampling_cap", "vllm_importance_sampling_cap", 3.0), ("off_policy_mask_threshold", "off_policy_mask_threshold", None), ]: @@ -552,7 +590,8 @@ class AsyncGRPOTrainer(GRPOTrainer): prompt_dataset=self.train_dataset, num_generations=self.num_generations, generation_batch_size=getattr( - args, "generation_batch_size", + args, + "generation_batch_size", self._train_batch_size * args.gradient_accumulation_steps, ), train_batch_size=args.per_device_train_batch_size, @@ -577,7 +616,8 @@ class AsyncGRPOTrainer(GRPOTrainer): def _setup_async(self): """Create background thread pool, prompt iterator, and pre-fill the async queue.""" gen_batch_size = getattr( - self.args, "generation_batch_size", + self.args, + "generation_batch_size", self._train_batch_size * self.args.gradient_accumulation_steps, ) # RepeatSampler groups prompts with num_generations repetitions each. 
@@ -677,7 +717,10 @@ class AsyncGRPOTrainer(GRPOTrainer): if "images" in inputs[0]: images = [ex.get("images") for ex in inputs] elif "image" in inputs[0]: - images = [[ex.get("image")] if ex.get("image") is not None else None for ex in inputs] + images = [ + [ex.get("image")] if ex.get("image") is not None else None + for ex in inputs + ] else: images = None if images is not None and all(img == [] for img in images): @@ -687,7 +730,8 @@ class AsyncGRPOTrainer(GRPOTrainer): if not is_conversational(inputs[0]): raise ValueError("Multimodal training requires conversational prompts.") prompts = [ - prepare_multimodal_messages(p, il) for p, il in zip(prompts, images, strict=True) + prepare_multimodal_messages(p, il) + for p, il in zip(prompts, images, strict=True) ] # --- Generate completions --- @@ -704,17 +748,29 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Pad to tensors --- prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = pad( + prompt_ids, padding_value=self.pad_token_id, padding_side="left" + ) prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_ids = [ + torch.tensor(ids, device=device) for ids in completion_ids_list + ] + completion_mask = [ + torch.ones_like(ids, dtype=torch.long) for ids in completion_ids + ] + completion_ids = pad( + completion_ids, padding_value=self.pad_token_id, padding_side="right" + ) completion_mask = pad(completion_mask, padding_value=0, padding_side="right") if sampling_per_token_logps_list is not None: - sampling_logps = [torch.tensor(lp, 
device=device) for lp in sampling_per_token_logps_list] - sampling_per_token_logps = pad(sampling_logps, padding_value=0.0, padding_side="right") + sampling_logps = [ + torch.tensor(lp, device=device) for lp in sampling_per_token_logps_list + ] + sampling_per_token_logps = pad( + sampling_logps, padding_value=0.0, padding_side="right" + ) else: sampling_per_token_logps = None @@ -727,7 +783,10 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Mask truncated completions --- if self.mask_truncated_completions: eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_trunc = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + is_trunc = torch.tensor( + [ids[-1] not in eos_and_pad for ids in completion_ids_list], + device=device, + ) completion_mask = completion_mask * (~is_trunc).unsqueeze(1).int() if tool_mask is not None: tool_mask = tool_mask * (~is_trunc).unsqueeze(1).int() @@ -737,7 +796,10 @@ class AsyncGRPOTrainer(GRPOTrainer): if images is not None: prompts_text = [ apply_chat_template( - {"prompt": p}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + {"prompt": p}, + self.processing_class, + tools=self.tools, + **self.chat_template_kwargs, )["prompt"] for p in prompts ] @@ -756,7 +818,9 @@ class AsyncGRPOTrainer(GRPOTrainer): for ttid_key in ("token_type_ids", "mm_token_type_ids"): if ttid_key in forward_kwargs: tt = forward_kwargs[ttid_key] - forward_kwargs[ttid_key] = torch.cat([tt, tt.new_zeros(completion_ids.shape)], dim=1) + forward_kwargs[ttid_key] = torch.cat( + [tt, tt.new_zeros(completion_ids.shape)], dim=1 + ) # Merge extra_fields from rollout_func into inputs if extra_fields: @@ -791,8 +855,14 @@ class AsyncGRPOTrainer(GRPOTrainer): output["tool_mask"] = tool_mask if images is not None: output["num_images"] = num_images - for k in ("pixel_values", "image_grid_thw", "pixel_attention_mask", - "image_sizes", "token_type_ids", "mm_token_type_ids"): + for k in ( + "pixel_values", + 
"image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): if k in forward_kwargs: output[k] = forward_kwargs[k] return output @@ -831,23 +901,36 @@ class AsyncGRPOTrainer(GRPOTrainer): # Multimodal forward kwargs forward_kwargs = {} - for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", - "image_sizes", "token_type_ids", "mm_token_type_ids"): + for key in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): if key in data: forward_kwargs[key] = data[key] num_images = data.get("num_images") # --- Policy logprobs --- logprob_batch_size = min(batch_size * 4, len(prompt_ids)) - with disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + with disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ): generate_every = self.args.steps_per_generation * self.num_iterations if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False) + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) ): old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, prompt_completion_ids, attention_mask, - logits_to_keep, logprob_batch_size, - num_images=num_images, **forward_kwargs, + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + logprob_batch_size, + num_images=num_images, + **forward_kwargs, ) data["old_per_token_logps"] = old_per_token_logps else: @@ -857,18 +940,30 @@ class AsyncGRPOTrainer(GRPOTrainer): if self.beta != 0.0: if self.ref_model is not None: ref_logps, _ = self._get_per_token_logps_and_entropies( - self.ref_model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, - num_images=num_images, **forward_kwargs, + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + 
batch_size, + num_images=num_images, + **forward_kwargs, ) else: model = self.accelerator.unwrap_model(self.model) - adapter_name = "ref" if hasattr(model, "peft_config") and "ref" in model.peft_config else None + adapter_name = ( + "ref" + if hasattr(model, "peft_config") and "ref" in model.peft_config + else None + ) with use_adapter(model, adapter_name=adapter_name): ref_logps, _ = self._get_per_token_logps_and_entropies( - self.model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, - num_images=num_images, **forward_kwargs, + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, ) data["ref_per_token_logps"] = ref_logps @@ -880,7 +975,11 @@ class AsyncGRPOTrainer(GRPOTrainer): and "sampling_per_token_logps" in data ): sampling_logps = data["sampling_per_token_logps"] - is_mask = completion_mask if "tool_mask" not in data else completion_mask * data["tool_mask"] + is_mask = ( + completion_mask + if "tool_mask" not in data + else completion_mask * data["tool_mask"] + ) per_token_logps_diff = (old_per_token_logps - sampling_logps) * is_mask is_mode = getattr(self, "vllm_importance_sampling_mode", "token_truncate") @@ -899,19 +998,35 @@ class AsyncGRPOTrainer(GRPOTrainer): data["importance_sampling_ratio"] = is_ratio # --- Rewards --- - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards( + inputs, prompts, completions, completion_ids_list + ) # --- Advantages --- if self.multi_objective_aggregation == "sum_then_normalize": - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations) + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + mean_grouped = ( + rewards.view(-1, num_generations) + .mean(dim=1) + 
.repeat_interleave(num_generations) + ) if self.scale_rewards in ("group", "none"): if num_generations > 1: - std_rewards = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations) + std_rewards = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) else: std_rewards = torch.zeros_like(rewards) elif self.scale_rewards == "batch": - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + std_rewards = ( + rewards.std().expand_as(rewards) + if rewards.numel() > 1 + else torch.zeros_like(rewards) + ) else: raise ValueError(f"Invalid scale_rewards: {self.scale_rewards}") advantages = rewards - mean_grouped @@ -922,15 +1037,27 @@ class AsyncGRPOTrainer(GRPOTrainer): elif self.multi_objective_aggregation == "normalize_then_sum": grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) mean_k = torch.nanmean(grouped, dim=1, keepdim=True) - std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) + std_k = ( + nanstd(grouped, dim=1, keepdim=True) + if num_generations > 1 + else torch.zeros_like(mean_k) + ) reward_k = (grouped - mean_k) / (std_k + 1e-4) reward_k = reward_k.view(-1, len(self.reward_funcs)) - rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum( + dim=1 + ) + std_rewards = ( + rewards.std().expand_as(rewards) + if rewards.numel() > 1 + else torch.zeros_like(rewards) + ) advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) else: - raise ValueError(f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}") + raise ValueError( + f"Invalid multi_objective_aggregation: 
{self.multi_objective_aggregation}" + ) # Slice for local process process_slice = slice( @@ -943,12 +1070,18 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Metrics --- for i, name in enumerate(self.reward_func_names): - self._metrics[mode][f"rewards/{name}/mean"].append(torch.nanmean(rewards_per_func[:, i]).item()) - self._metrics[mode][f"rewards/{name}/std"].append(nanstd(rewards_per_func[:, i]).item()) + self._metrics[mode][f"rewards/{name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) agg_rewards = rewards_per_func.nansum(dim=1) self._metrics[mode]["reward"].append(agg_rewards.mean().item()) self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append( + is_std_zero.float().mean().item() + ) # Token counting total_prompt = self.accelerator.gather(prompt_mask.sum()) @@ -959,20 +1092,36 @@ class AsyncGRPOTrainer(GRPOTrainer): # Completion length metrics comp_lengths = completion_mask.sum(dim=1) agg_lengths = self.accelerator.gather(comp_lengths) - self._metrics[mode]["completions/mean_length"].append(agg_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_lengths.float().max().item()) + self._metrics[mode]["completions/mean_length"].append( + agg_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_length"].append( + agg_lengths.float().min().item() + ) + self._metrics[mode]["completions/max_length"].append( + agg_lengths.float().max().item() + ) eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_trunc = torch.tensor([ids[-1].item() not in eos_and_pad for ids in completion_ids], device=device) + is_trunc = torch.tensor( + [ids[-1].item() not in 
eos_and_pad for ids in completion_ids], device=device + ) agg_trunc = self.accelerator.gather(is_trunc) - self._metrics[mode]["completions/clipped_ratio"].append(agg_trunc.float().mean().item()) + self._metrics[mode]["completions/clipped_ratio"].append( + agg_trunc.float().mean().item() + ) term_lengths = agg_lengths[~agg_trunc] if len(term_lengths) == 0: term_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_lengths.float().max().item()) + self._metrics[mode]["completions/mean_terminated_length"].append( + term_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term_lengths.float().min().item() + ) + self._metrics[mode]["completions/max_terminated_length"].append( + term_lengths.float().max().item() + ) # IS metrics if "importance_sampling_ratio" in data and "sampling_per_token_logps" in data: @@ -981,8 +1130,16 @@ class AsyncGRPOTrainer(GRPOTrainer): mask = completion_mask.bool() delta = torch.abs(old_lp - samp_lp) delta_m = delta[mask] - md = torch.mean(delta_m) if delta_m.numel() > 0 else torch.tensor(0.0, device=device) - xd = torch.max(delta_m) if delta_m.numel() > 0 else torch.tensor(0.0, device=device) + md = ( + torch.mean(delta_m) + if delta_m.numel() > 0 + else torch.tensor(0.0, device=device) + ) + xd = ( + torch.max(delta_m) + if delta_m.numel() > 0 + else torch.tensor(0.0, device=device) + ) self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( self.accelerator.gather(md).mean().item() ) @@ -1007,8 +1164,12 @@ class AsyncGRPOTrainer(GRPOTrainer): ) # Log prompt/completion texts - prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) - completions_text = 
self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + prompts_text = self.processing_class.batch_decode( + prompt_ids, skip_special_tokens=True + ) + completions_text = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) if gather_object is not None: self._logs["prompt"].extend(gather_object(prompts_text)) self._logs["completion"].extend(gather_object(completions_text)) @@ -1025,7 +1186,15 @@ class AsyncGRPOTrainer(GRPOTrainer): @torch.no_grad() def _compute_streaming_group_scores( - self, data, s_start, s_end, inputs, prompts, completions, completion_ids_list, is_last_chunk, + self, + data, + s_start, + s_end, + inputs, + prompts, + completions, + completion_ids_list, + is_last_chunk, ): """Score a chunk of prompt groups: rewards, policy logprobs, advantages. @@ -1043,34 +1212,57 @@ class AsyncGRPOTrainer(GRPOTrainer): chunk_completion_ids = data["completion_ids"][s_start:s_end] chunk_prompt_mask = data["prompt_mask"][s_start:s_end] chunk_completion_mask = data["completion_mask"][s_start:s_end] - prompt_completion_ids = torch.cat([chunk_prompt_ids, chunk_completion_ids], dim=1) + prompt_completion_ids = torch.cat( + [chunk_prompt_ids, chunk_completion_ids], dim=1 + ) attention_mask = torch.cat([chunk_prompt_mask, chunk_completion_mask], dim=1) logits_to_keep = chunk_completion_ids.size(1) # Slice multimodal forward kwargs for this chunk forward_kwargs = {} - for key in ("pixel_values", "image_grid_thw", "pixel_attention_mask", - "image_sizes", "token_type_ids", "mm_token_type_ids"): + for key in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): if key in data: val = data[key] - if isinstance(val, torch.Tensor) and val.dim() > 0 and val.size(0) == len(data["prompt_ids"]): + if ( + isinstance(val, torch.Tensor) + and val.dim() > 0 + and val.size(0) == len(data["prompt_ids"]) + ): forward_kwargs[key] = val[s_start:s_end] else: 
forward_kwargs[key] = val num_images = data.get("num_images") - if num_images is not None and hasattr(num_images, "__getitem__") and len(num_images) == len(data["prompt_ids"]): + if ( + num_images is not None + and hasattr(num_images, "__getitem__") + and len(num_images) == len(data["prompt_ids"]) + ): num_images = num_images[s_start:s_end] logprob_batch_size = min(batch_size * 4, chunk_size) - with disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + with disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ): generate_every = self.args.steps_per_generation * self.num_iterations if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False) + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) ): old_logps, _ = self._get_per_token_logps_and_entropies( - self.model, prompt_completion_ids, attention_mask, - logits_to_keep, logprob_batch_size, - num_images=num_images, **forward_kwargs, + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + logprob_batch_size, + num_images=num_images, + **forward_kwargs, ) if "old_per_token_logps" not in data: total = len(data["prompt_ids"]) @@ -1082,11 +1274,15 @@ class AsyncGRPOTrainer(GRPOTrainer): # Compute IS ratio for this chunk if "sampling_per_token_logps" in data: samp_chunk = data["sampling_per_token_logps"][s_start:s_end] - is_mask = chunk_completion_mask if "tool_mask" not in data else ( - chunk_completion_mask * data["tool_mask"][s_start:s_end] + is_mask = ( + chunk_completion_mask + if "tool_mask" not in data + else (chunk_completion_mask * data["tool_mask"][s_start:s_end]) ) diff = (old_logps - samp_chunk) * is_mask - is_mode = getattr(self, "vllm_importance_sampling_mode", "token_truncate") + is_mode = getattr( + self, "vllm_importance_sampling_mode", "token_truncate" + ) is_cap = getattr(self, 
"vllm_importance_sampling_cap", 3.0) seq_is = is_mode in ("sequence_mask", "sequence_truncate") logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff @@ -1098,23 +1294,39 @@ class AsyncGRPOTrainer(GRPOTrainer): if "importance_sampling_ratio" not in data: total = len(data["prompt_ids"]) shape = (total, 1) if seq_is else (total, is_ratio.size(1)) - data["importance_sampling_ratio"] = torch.ones(*shape, device=device, dtype=is_ratio.dtype) + data["importance_sampling_ratio"] = torch.ones( + *shape, device=device, dtype=is_ratio.dtype + ) data["importance_sampling_ratio"][s_start:s_end] = is_ratio # Reference logprobs if self.beta != 0.0: if self.ref_model is not None: ref_logps, _ = self._get_per_token_logps_and_entropies( - self.ref_model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, ) else: model = self.accelerator.unwrap_model(self.model) - adapter_name = "ref" if hasattr(model, "peft_config") and "ref" in model.peft_config else None + adapter_name = ( + "ref" + if hasattr(model, "peft_config") and "ref" in model.peft_config + else None + ) with use_adapter(model, adapter_name=adapter_name): ref_logps, _ = self._get_per_token_logps_and_entropies( - self.model, prompt_completion_ids, attention_mask, - logits_to_keep, batch_size, num_images=num_images, **forward_kwargs, + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, ) if "ref_per_token_logps" not in data: total = len(data["prompt_ids"]) @@ -1124,14 +1336,26 @@ class AsyncGRPOTrainer(GRPOTrainer): data["ref_per_token_logps"][s_start:s_end] = ref_logps # --- Rewards --- - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards( + inputs, 
prompts, completions, completion_ids_list + ) # --- Advantages (group-level normalization) --- if self.multi_objective_aggregation == "sum_then_normalize": - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations) + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + mean_g = ( + rewards.view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations) + ) if num_generations > 1: - std_r = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations) + std_r = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) else: std_r = torch.zeros_like(rewards) advantages = rewards - mean_g @@ -1142,15 +1366,33 @@ class AsyncGRPOTrainer(GRPOTrainer): elif self.multi_objective_aggregation == "normalize_then_sum": grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) mean_k = torch.nanmean(grouped, dim=1, keepdim=True) - std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) - reward_k = ((grouped - mean_k) / (std_k + 1e-4)).view(-1, len(self.reward_funcs)) - rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - std_r = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations) - mean_r = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations) + std_k = ( + nanstd(grouped, dim=1, keepdim=True) + if num_generations > 1 + else torch.zeros_like(mean_k) + ) + reward_k = ((grouped - mean_k) / (std_k + 1e-4)).view( + -1, len(self.reward_funcs) + ) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum( + dim=1 + ) + std_r = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) + mean_r = ( + rewards.view(-1, num_generations) + .mean(dim=1) + 
.repeat_interleave(num_generations) + ) advantages = (rewards - mean_r) / (std_r + 1e-4) is_std_zero = torch.isclose(std_r, torch.zeros_like(std_r)) else: - raise ValueError(f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}") + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}" + ) process_slice = slice( self.accelerator.process_index * len(prompts), @@ -1164,12 +1406,18 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Chunk metrics --- for i, name in enumerate(self.reward_func_names): - self._metrics[mode][f"rewards/{name}/mean"].append(torch.nanmean(rewards_per_func[:, i]).item()) - self._metrics[mode][f"rewards/{name}/std"].append(nanstd(rewards_per_func[:, i]).item()) + self._metrics[mode][f"rewards/{name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) agg_rewards = rewards_per_func.nansum(dim=1) self._metrics[mode]["reward"].append(agg_rewards.mean().item()) self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append( + is_std_zero.float().mean().item() + ) # --- Full-batch metrics on last chunk --- if is_last_chunk: @@ -1183,22 +1431,37 @@ class AsyncGRPOTrainer(GRPOTrainer): comp_lengths = all_completion_mask.sum(dim=1) agg_lengths = self.accelerator.gather(comp_lengths) - self._metrics[mode]["completions/mean_length"].append(agg_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_lengths.float().max().item()) + self._metrics[mode]["completions/mean_length"].append( + agg_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_length"].append( + agg_lengths.float().min().item() + ) + 
self._metrics[mode]["completions/max_length"].append( + agg_lengths.float().max().item() + ) eos_and_pad = [self.eos_token_id, self.pad_token_id] is_trunc = torch.tensor( - [ids[-1].item() not in eos_and_pad for ids in all_completion_ids], device=device + [ids[-1].item() not in eos_and_pad for ids in all_completion_ids], + device=device, ) agg_trunc = self.accelerator.gather(is_trunc) - self._metrics[mode]["completions/clipped_ratio"].append(agg_trunc.float().mean().item()) + self._metrics[mode]["completions/clipped_ratio"].append( + agg_trunc.float().mean().item() + ) term = agg_lengths[~agg_trunc] if len(term) == 0: term = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term.float().max().item()) + self._metrics[mode]["completions/mean_terminated_length"].append( + term.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term.float().min().item() + ) + self._metrics[mode]["completions/max_terminated_length"].append( + term.float().max().item() + ) # IS metrics if ( @@ -1211,27 +1474,41 @@ class AsyncGRPOTrainer(GRPOTrainer): samp_lp = data["sampling_per_token_logps"] mask = all_completion_mask.bool() delta = torch.abs(old_lp - samp_lp)[mask] - md = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - xd = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + md = ( + torch.mean(delta) + if delta.numel() > 0 + else torch.tensor(0.0, device=device) + ) + xd = ( + torch.max(delta) + if delta.numel() > 0 + else torch.tensor(0.0, device=device) + ) self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( self.accelerator.gather(md).mean().item() ) self._metrics[mode]["sampling/sampling_logp_difference/max"].append( 
self.accelerator.gather(xd).max().item() ) - is_mode = getattr(self, "vllm_importance_sampling_mode", "token_truncate") + is_mode = getattr( + self, "vllm_importance_sampling_mode", "token_truncate" + ) isr = data["importance_sampling_ratio"] - flat = isr.flatten() if is_mode in ("sequence_mask", "sequence_truncate") else isr[mask] + flat = ( + isr.flatten() + if is_mode in ("sequence_mask", "sequence_truncate") + else isr[mask] + ) if flat.numel() > 0: - self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - nanmin(self.accelerator.gather(torch.min(flat))).item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(torch.mean(flat)).nanmean().item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - nanmax(self.accelerator.gather(torch.max(flat))).item() - ) + self._metrics[mode][ + "sampling/importance_sampling_ratio/min" + ].append(nanmin(self.accelerator.gather(torch.min(flat))).item()) + self._metrics[mode][ + "sampling/importance_sampling_ratio/mean" + ].append(self.accelerator.gather(torch.mean(flat)).nanmean().item()) + self._metrics[mode][ + "sampling/importance_sampling_ratio/max" + ].append(nanmax(self.accelerator.gather(torch.max(flat))).item()) def _score_streaming(self, rollout: dict) -> list[dict]: """Score a rollout using streaming group scoring. 
Returns list of micro-batches.""" @@ -1257,7 +1534,9 @@ class AsyncGRPOTrainer(GRPOTrainer): s_end = chunk_end_g * num_gen self._compute_streaming_group_scores( - data=data, s_start=s_start, s_end=s_end, + data=data, + s_start=s_start, + s_end=s_end, inputs=inputs[s_start:s_end], prompts=prompts[s_start:s_end], completions=completions[s_start:s_end], @@ -1269,7 +1548,7 @@ class AsyncGRPOTrainer(GRPOTrainer): chunk_size = s_end - s_start perm = torch.randperm(chunk_size) for mb_off in range(0, chunk_size, batch_size): - mb_idx = perm[mb_off:mb_off + batch_size] + mb_idx = perm[mb_off : mb_off + batch_size] abs_idx = mb_idx + s_start mb = {} for key in data: @@ -1315,7 +1594,8 @@ class AsyncGRPOTrainer(GRPOTrainer): # Produce a new rollout self._maybe_sync_vllm_weights() rollout_dataset = self.data_producer.produce( - self.model, self.state.global_step, + self.model, + self.state.global_step, processing_class=self.processing_class, accelerator=self.accelerator, args=self.args, @@ -1375,10 +1655,18 @@ class AsyncGRPOTrainer(GRPOTrainer): # ------------------------------------------------------------------ @staticmethod - def get_off_policy_mask(advantages, per_token_logps, sampling_per_token_logps, mask, off_policy_threshold): + def get_off_policy_mask( + advantages, + per_token_logps, + sampling_per_token_logps, + mask, + off_policy_threshold, + ): """OPSM from DeepSeek-V3.2: drop sequences with negative advantage + high KL.""" kl_div = sampling_per_token_logps - per_token_logps.detach() - seq_kl = (kl_div * mask).sum(dim=1, keepdim=True) / mask.sum(dim=1, keepdim=True).clamp(min=1.0) + seq_kl = (kl_div * mask).sum(dim=1, keepdim=True) / mask.sum( + dim=1, keepdim=True + ).clamp(min=1.0) is_pos_adv = advantages >= 0 is_low_kl = seq_kl <= off_policy_threshold return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) @@ -1386,14 +1674,24 @@ class AsyncGRPOTrainer(GRPOTrainer): def _compute_loss(self, model, inputs): """Override to add IS ratio correction and off-policy 
sequence masking.""" prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + completion_ids, completion_mask = ( + inputs["completion_ids"], + inputs["completion_mask"], + ) input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) - mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] + mask = ( + completion_mask + if "tool_mask" not in inputs + else completion_mask * inputs["tool_mask"] + ) per_token_logps, entropies = self._get_per_token_logps_and_entropies( - model, input_ids, attention_mask, logits_to_keep, + model, + input_ids, + attention_mask, + logits_to_keep, compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), @@ -1405,7 +1703,9 @@ class AsyncGRPOTrainer(GRPOTrainer): ) if self.top_entropy_quantile < 1.0: - entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile) + entropy_mask = self.get_high_entropy_mask( + entropies, mask, 1 - self.top_entropy_quantile + ) else: entropy_mask = None @@ -1414,12 +1714,18 @@ class AsyncGRPOTrainer(GRPOTrainer): advantages = advantages.unsqueeze(1) old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + old_per_token_logps = ( + per_token_logps.detach() + if old_per_token_logps is None + else old_per_token_logps + ) # --- OPSM (off-policy sequence mask) --- off_policy_mask = None if getattr(self, "off_policy_mask_threshold", None) is not None: - sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps) + sampling_per_token_logps = inputs.get( + "sampling_per_token_logps", old_per_token_logps + ) off_policy_mask = self.get_off_policy_mask( 
advantages=advantages, per_token_logps=per_token_logps, @@ -1430,11 +1736,17 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Importance weights --- log_ratio = per_token_logps - old_per_token_logps - is_level = getattr(self, "importance_sampling_level", getattr(self.args, "importance_sampling_level", "token")) + is_level = getattr( + self, + "importance_sampling_level", + getattr(self.args, "importance_sampling_level", "token"), + ) if is_level == "token": log_importance_weights = log_ratio elif is_level == "sequence": - log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) + log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp( + min=1.0 + ) log_importance_weights = log_importance_weights.unsqueeze(-1) else: raise ValueError(f"Unknown importance sampling level: {is_level}") @@ -1444,7 +1756,11 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- KL divergence --- if self.beta != 0.0: ref_per_token_logps = inputs["ref_per_token_logps"] - per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) + - (ref_per_token_logps - per_token_logps) + - 1 + ) if getattr(self.args, "use_bias_correction_kl", False): per_token_kl = per_token_kl * coef_1 @@ -1460,7 +1776,11 @@ class AsyncGRPOTrainer(GRPOTrainer): coef_1_c = coef_1 per_token_loss = -torch.min(coef_1_c * advantages, coef_2 * advantages) elif self.loss_type == "sapo": - temps = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg) + temps = torch.where( + advantages > 0, + self.args.sapo_temperature_pos, + self.args.sapo_temperature_neg, + ) soft = torch.sigmoid(temps * (coef_1 - 1)) * 4 / temps per_token_loss = -soft * advantages else: @@ -1485,14 +1805,24 @@ class AsyncGRPOTrainer(GRPOTrainer): # --- Aggregate loss --- mode = "train" if self.model.training else "eval" - normalizer = 
self.current_gradient_accumulation_steps if mode == "train" else 1.0 + normalizer = ( + self.current_gradient_accumulation_steps if mode == "train" else 1.0 + ) if self.loss_type in ("grpo", "sapo"): - loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() / normalizer + loss = ( + (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) + ).mean() / normalizer elif self.loss_type == "bnpo": - loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) / normalizer + loss = ( + (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) / normalizer + ) elif self.loss_type == "dr_grpo": - loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) / normalizer + loss = ( + (per_token_loss * mask).sum() + / (per_token_loss.size(0) * self.max_completion_length) + / normalizer + ) elif self.loss_type in ("cispo", "dapo"): norm = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * mask).sum() / norm @@ -1505,14 +1835,22 @@ class AsyncGRPOTrainer(GRPOTrainer): completion_token_count = mask.sum().clamp(min=1.0) def masked_batch_mean(x): - return x.mean() if x.shape[1] == 1 else (x * mask).sum() / completion_token_count + return ( + x.mean() + if x.shape[1] == 1 + else (x * mask).sum() / completion_token_count + ) if self.beta != 0.0: mean_kl = masked_batch_mean(per_token_kl) - self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + self._metrics[mode]["kl"].append( + self.accelerator.gather(mean_kl).nanmean().item() + ) mean_entropy = masked_batch_mean(entropies) - self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + self._metrics[mode]["entropy"].append( + self.accelerator.gather(mean_entropy).nanmean().item() + ) if self.loss_type in ("grpo", "bnpo", "dr_grpo", "dapo", "luspo"): is_low = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) @@ -1528,11 +1866,15 @@ class AsyncGRPOTrainer(GRPOTrainer): 
self._metrics[mode]["clip_ratio/high_mean"].append(g_high.nanmean().item()) self._metrics[mode]["clip_ratio/high_max"].append(nanmax(g_high).item()) g_clip = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append(g_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/region_mean"].append( + g_clip.nanmean().item() + ) elif self.loss_type == "cispo": is_cispo = (coef_1 > self.epsilon_high) & (advantages > 0) cr = masked_batch_mean(is_cispo.float()) - self._metrics[mode]["cispo_clip_ratio"].append(self.accelerator.gather(cr).nanmean().item()) + self._metrics[mode]["cispo_clip_ratio"].append( + self.accelerator.gather(cr).nanmean().item() + ) return loss diff --git a/src/axolotl/core/trainers/grpo/fast_async_trainer.py b/src/axolotl/core/trainers/grpo/fast_async_trainer.py new file mode 100644 index 000000000..76820d5f1 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/fast_async_trainer.py @@ -0,0 +1,690 @@ +# Copyright 2020-2026 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Experimental GRPO extensions: parallel reward workers, replay buffer, +deferred re-roll, and zero-advantage skipping. + +These features are built as subclasses of GRPOTrainer and GRPODataProducer, +using the hook system (_compute_rewards_for_batch, _post_advantage_hook, +_pre_produce_hook) defined in the base classes. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from dataclasses import dataclass, field + +import torch +from torch import nn + +from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOConfig, + AsyncGRPOTrainer, + GRPODataProducer, +) +from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Extended config +# --------------------------------------------------------------------------- + + +@dataclass +class FastAsyncGRPOConfig(AsyncGRPOConfig): + """GRPOConfig with additional experimental parameters.""" + + reward_num_workers: int = field( + default=1, + metadata={ + "help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its " + "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across " + "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions." + }, + ) + replay_buffer_size: int = field( + default=0, + metadata={ + "help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout " + "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups " + "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True." + }, + ) + replay_recompute_logps: bool = field( + default=True, + metadata={ + "help": "When True (default), recompute old_per_token_logps for replayed groups using the current " + "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. " + "Only relevant when replay_buffer_size > 0." + }, + ) + reroll_start_fraction: float = field( + default=0.5, + metadata={ + "help": "Fraction of total training steps after which deferred re-rolling begins. 
Zero-signal prompts " + "(where all rewards in a group are identical) are buffered and re-injected into later batches when the " + "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True." + }, + ) + reroll_max_groups: int = field( + default=1, + metadata={ + "help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values " + "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True." + }, + ) + skip_zero_advantage_batches: bool = field( + default=True, + metadata={ + "help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning " + "signal). This saves ~0.6s per skipped batch by avoiding the forward/backward pass entirely. The step is " + "logged with skipped_zero_adv_batches=1 for monitoring." + }, + ) + + +# --------------------------------------------------------------------------- +# Extended data producer with re-roll injection +# --------------------------------------------------------------------------- + + +class RerollDataProducer(GRPODataProducer): + """GRPODataProducer that injects re-roll candidates into prompt batches. + + Reads from the trainer's ``_reroll_buffer`` (populated by + ``FastAsyncGRPOTrainer._post_advantage_hook``) and replaces the + last N prompt groups with previously-failed prompts. 
+ """ + + def _pre_produce_hook(self, inputs: list, global_step: int) -> list: + trainer = self._trainer + reroll_buf = getattr(trainer, "_reroll_buffer", None) + reroll_lock = getattr(trainer, "_reroll_lock", None) + if reroll_buf is None or reroll_lock is None: + return inputs + + max_steps = getattr(trainer.args, "max_steps", -1) + start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0) + max_groups = getattr(trainer.args, "reroll_max_groups", 1) + reroll_start_step = ( + max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf") + ) + + if global_step < reroll_start_step: + return inputs + + with reroll_lock: + n_to_take = min(max_groups, len(reroll_buf)) + reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)] + + if reroll_prompts: + num_gen = self._num_generations + n_groups = len(inputs) // num_gen + for i, reroll_prompt in enumerate(reroll_prompts): + group_idx = n_groups - 1 - i + if group_idx < 0: + break + start = group_idx * num_gen + for j in range(num_gen): + inputs[start + j] = reroll_prompt + logger.info( + f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups " + f"with deferred re-roll candidates ({len(reroll_buf)} remaining)" + ) + + return inputs + + +# --------------------------------------------------------------------------- +# Persistent reward subprocess pool +# --------------------------------------------------------------------------- + + +def _persistent_reward_worker(conn): + """Long-lived reward worker. 
Receives work items, returns results.""" + while True: + try: + msg = conn.recv() + except EOFError: + break + if msg is None: # Shutdown signal + break + ( + reward_funcs, + prompts, + completions, + completion_ids_list, + inputs, + reward_func_names, + ) = msg + try: + keys = [ + key + for key in inputs[0] + if key not in ["prompt", "completion", "completion_ids"] + ] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + results = [] + for reward_func, reward_func_name in zip( + reward_funcs, reward_func_names, strict=True + ): + output = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + results.append( + [float(r) if r is not None else float("nan") for r in output] + ) + conn.send(results) + except Exception: + conn.send(None) + + +# --------------------------------------------------------------------------- +# Extended trainer +# --------------------------------------------------------------------------- + + +class FastAsyncGRPOTrainer(AsyncGRPOTrainer): + """GRPOTrainer with experimental extensions. + + Adds: + - Parallel reward subprocess workers (``reward_num_workers``) + - Replay buffer for high-signal group reuse (``replay_buffer_size``) + - Deferred re-roll of failed prompts (``reroll_start_fraction``) + - Zero-advantage micro-batch skipping + """ + + def __init__(self, *args, **kwargs): + # These must be initialized before super().__init__() because + # _create_data_producer (called during super().__init__) needs them. + self._reroll_buffer: list = [] + self._reroll_lock = threading.Lock() + + # Temporarily suppress the base class's Liger + OPSM validation check, + # since this subclass supports it via a custom compute_liger_loss override. 
+ grpo_args = kwargs.get("args") + if grpo_args is None: + for a in args: + if hasattr(a, "off_policy_mask_threshold"): + grpo_args = a + break + saved_threshold = None + if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False): + saved_threshold = grpo_args.off_policy_mask_threshold + grpo_args.off_policy_mask_threshold = None + + super().__init__(*args, **kwargs) + + if saved_threshold is not None: + grpo_args.off_policy_mask_threshold = saved_threshold + self.off_policy_mask_threshold = saved_threshold + + # Replay buffer + if getattr(self.args, "replay_buffer_size", 0) > 0: + self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size) + else: + self._replay_buffer = None + self._replay_recompute_logps = getattr( + self.args, "replay_recompute_logps", True + ) + + # Reward worker pool (lazy-initialized) + self._reward_workers = None + + # -- Factory override: use RerollDataProducer ---------------------------- + + def _create_data_producer(self, args, train_dataset): + """Override to use RerollDataProducer for re-roll prompt injection.""" + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncDataProducer, + ProducerConfig, + ) + + producer_config = ProducerConfig( + mini_epochs=args.num_iterations, + max_rollouts=None, + eval_during_produce=False, + empty_cache_before_produce=True, + empty_cache_after_produce=True, + async_prefetch=args.async_prefetch, + prefetch_depth=args.prefetch_depth, + ) + data_producer = RerollDataProducer( + config=producer_config, + prompt_dataset=train_dataset, + num_generations=self.num_generations, + generation_batch_size=args.generation_batch_size, + train_batch_size=args.per_device_train_batch_size, + steps_per_generation=args.steps_per_generation, + shuffle_dataset=self.shuffle_dataset, + seed=args.seed, + ) + if args.async_prefetch: + data_producer = AsyncDataProducer( + data_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + return data_producer + + # -- Reward 
worker pool -------------------------------------------------- + + def _get_reward_workers(self): + """Return a list of persistent reward worker subprocesses (lazy-initialized).""" + import multiprocessing as _mp + + num_workers = getattr(self.args, "reward_num_workers", 1) + if num_workers < 1: + num_workers = 1 + + if self._reward_workers is not None: + alive = all(proc.is_alive() for conn, proc in self._reward_workers) + if alive and len(self._reward_workers) == num_workers: + return self._reward_workers + self._shutdown_reward_workers() + + workers = [] + for _ in range(num_workers): + parent_conn, child_conn = _mp.Pipe() + proc = _mp.Process( + target=_persistent_reward_worker, args=(child_conn,), daemon=True + ) + proc.start() + child_conn.close() + workers.append((parent_conn, proc)) + + self._reward_workers = workers + return workers + + def _shutdown_reward_workers(self): + """Shut down all persistent reward workers.""" + if self._reward_workers is None: + return + for conn, proc in self._reward_workers: + try: + conn.send(None) + proc.join(timeout=5) + except Exception: + pass + try: + conn.close() + except Exception: + pass + self._reward_workers = None + + # -- Hook overrides ------------------------------------------------------ + + def _compute_rewards_for_batch( + self, inputs, prompts, completions, completion_ids_list + ): + """Dispatch rewards to parallel subprocess workers when possible.""" + from accelerate.utils import gather + + reward_can_bg = all( + not isinstance(rf, nn.Module) and not asyncio.iscoroutinefunction(rf) + for rf in self.reward_funcs + ) + num_workers = getattr(self.args, "reward_num_workers", 1) + + if not reward_can_bg or num_workers <= 1: + return self._calculate_rewards( + inputs, prompts, completions, completion_ids_list + ) + + workers = self._get_reward_workers() + device = self.accelerator.device + num_generations = self.num_generations + num_prompts = len(prompts) + num_groups = num_prompts // num_generations + + # 
Shard by prompt groups across workers + groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers)) + workers_used = [] + for w_idx, (conn, proc) in enumerate(workers): + g_start = w_idx * groups_per_worker + g_end = min((w_idx + 1) * groups_per_worker, num_groups) + if g_start >= num_groups: + break + s_start = g_start * num_generations + s_end = g_end * num_generations + conn.send( + ( + self.reward_funcs, + prompts[s_start:s_end], + completions[s_start:s_end], + completion_ids_list[s_start:s_end], + inputs[s_start:s_end], + self.reward_func_names, + ) + ) + workers_used.append(conn) + + # Collect results + all_worker_results = [] + any_failed = False + for conn in workers_used: + result = conn.recv() + if result is None: + any_failed = True + break + all_worker_results.append(result) + + if not any_failed: + rewards_per_func = torch.zeros( + num_prompts, len(self.reward_funcs), device=device + ) + offset = 0 + for worker_result in all_worker_results: + chunk_size = len(worker_result[0]) + for i, result in enumerate(worker_result): + rewards_per_func[offset : offset + chunk_size, i] = torch.tensor( + result, dtype=torch.float32, device=device + ) + offset += chunk_size + return gather(rewards_per_func) + + # Fallback to main thread on failure + return self._calculate_rewards( + inputs, prompts, completions, completion_ids_list + ) + + def _post_advantage_hook( + self, + data: dict, + rewards_per_func, + advantages, + inputs: list, + num_generations: int, + mode: str, + s_start: int | None = None, + s_end: int | None = None, + is_last_chunk: bool = True, + ) -> None: + """Replay buffer store/replace + re-roll buffering.""" + from trl.models.utils import disable_gradient_checkpointing + + # -- Replay buffer: store high-signal groups -- + if self._replay_buffer is not None: + local_grouped = rewards_per_func.view( + -1, num_generations, len(self.reward_funcs) + ) + per_group_std = local_grouped.std(dim=1) + has_signal = (per_group_std > 0).any(dim=1) 
+ + if has_signal.any(): + grouped_adv = advantages.view(-1, num_generations) + replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1) + offset = s_start or 0 + for group_idx in has_signal.nonzero(as_tuple=True)[0]: + gi = group_idx.item() + start = offset + gi * num_generations + end = start + num_generations + group_data = {} + for key in data: + val = data[key] + if ( + isinstance(val, torch.Tensor) + and val.dim() > 0 + and val.size(0) >= end + ): + group_data[key] = val[start:end].clone() + self._replay_buffer.add(replay_scores[gi].item(), group_data) + + # Replace zero-signal groups (only in deferred path, not streaming) + if s_start is None: + no_signal = ~has_signal + n_replaced = 0 + replaced_ranges = [] + if no_signal.any() and len(self._replay_buffer) > 0: + for group_idx in no_signal.nonzero(as_tuple=True)[0]: + sampled = self._replay_buffer.sample(1) + if sampled is None: + break + sampled_group = sampled[0] + gi = group_idx.item() + start, end = gi * num_generations, (gi + 1) * num_generations + for key, val in sampled_group.items(): + if key in data and isinstance(data[key], torch.Tensor): + src = val.to(data[key].device) + tgt_seq_len = ( + data[key].size(1) if data[key].dim() > 1 else None + ) + if tgt_seq_len is not None: + if src.size(1) <= tgt_seq_len: + data[key][start:end] = 0 + data[key][start:end, : src.size(1)] = src + else: + data[key][start:end] = src[:, :tgt_seq_len] + else: + data[key][start:end] = src + replaced_ranges.append((start, end)) + n_replaced += 1 + + # Recompute old_per_token_logps for replayed groups + if ( + n_replaced > 0 + and self._replay_recompute_logps + and "old_per_token_logps" in data + ): + with ( + torch.no_grad(), + disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ), + ): + for r_start, r_end in replaced_ranges: + r_ids = torch.cat( + [ + data["prompt_ids"][r_start:r_end], + data["completion_ids"][r_start:r_end], + ], + dim=1, + ) + r_mask = torch.cat( + [ 
def compute_liger_loss(self, unwrapped_model, inputs):
    """Liger loss with zero-advantage skipping and off-policy sequence masking (OPSM).

    The base class Liger path doesn't support OPSM because the fused kernel
    doesn't expose per-token logprobs needed for the KL computation. This
    override computes them via a per-sequence lm_head matmul (no grad, low
    memory) and applies the OPSM to the loss mask before calling the kernel.

    Args:
        unwrapped_model: Model exposing ``lm_head.weight`` / ``lm_head.bias``;
            also forwarded to ``self._get_last_hidden_state``.
        inputs: Batch dict. Reads ``advantages``, ``prompt_ids``,
            ``prompt_mask``, ``completion_ids``, ``completion_mask`` and,
            when present, ``tool_mask``, ``sampling_per_token_logps``,
            ``old_per_token_logps``, ``ref_per_token_logps``,
            ``importance_sampling_ratio`` and the multimodal tensors.

    Returns:
        A scalar loss tensor. On the OPSM path it is divided by
        ``self.current_gradient_accumulation_steps`` in train mode. On the
        zero-advantage path it is a fresh ``0.0`` leaf tensor.
    """
    if self.args.skip_zero_advantage_batches and torch.all(
        inputs["advantages"] == 0
    ):
        # Every advantage is zero: record the skip and return a constant
        # loss instead of running the forward pass.
        mode = "train" if self.model.training else "eval"
        self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
        return torch.tensor(
            0.0, device=inputs["advantages"].device, requires_grad=True
        )

    if self.off_policy_mask_threshold is None:
        # OPSM disabled: the stock Liger path is sufficient.
        return super().compute_liger_loss(unwrapped_model, inputs)

    # OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = (
        inputs["completion_ids"],
        inputs["completion_mask"],
    )
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1)

    last_hidden_state = self._get_last_hidden_state(
        unwrapped_model,
        input_ids,
        attention_mask,
        logits_to_keep,
        inputs.get("pixel_values"),
        inputs.get("image_grid_thw"),
        inputs.get("pixel_attention_mask"),
        inputs.get("image_sizes"),
    )

    loss_mask = (
        completion_mask
        if "tool_mask" not in inputs
        else completion_mask * inputs["tool_mask"]
    )

    lm_weight = unwrapped_model.lm_head.weight
    lm_bias = unwrapped_model.lm_head.bias

    # The fused kernel and the stored old/sampling log-probs use
    # temperature-scaled logits, so the log-probs fed to the off-policy mask
    # must be scaled the same way or the KL compares different distributions.
    # NOTE(review): assumes ``self.temperature`` is the sampling temperature
    # set by the base GRPO trainer; falls back to 1.0 when absent.
    temperature = getattr(self, "temperature", None) or 1.0

    # One sequence at a time so only a single (1, T, vocab) logits tensor is
    # alive at any moment.
    with torch.no_grad():
        per_token_logps_chunks = []
        for i in range(last_hidden_state.size(0)):
            chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
            if lm_bias is not None:
                chunk_logits = chunk_logits + lm_bias
            if temperature != 1.0:
                chunk_logits = chunk_logits / temperature
            chunk_lps = (
                chunk_logits.float()
                .log_softmax(-1)
                .gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
                .squeeze(-1)
            )
            per_token_logps_chunks.append(chunk_lps)
            del chunk_logits
        per_token_logps = torch.cat(per_token_logps_chunks, dim=0)

    # get_off_policy_mask is given a 2-D advantages tensor.
    advantages = inputs["advantages"]
    if advantages.dim() == 1:
        advantages_2d = advantages.unsqueeze(1)
    else:
        advantages_2d = advantages

    # Reference log-probs for the mask, in order of preference: sampler,
    # then old policy, then the current policy itself.
    sampling_per_token_logps = inputs.get("sampling_per_token_logps")
    if sampling_per_token_logps is None:
        sampling_per_token_logps = inputs.get("old_per_token_logps")
    if sampling_per_token_logps is None:
        sampling_per_token_logps = per_token_logps

    off_policy_mask = GRPOTrainer.get_off_policy_mask(
        advantages=advantages_2d,
        per_token_logps=per_token_logps,
        sampling_per_token_logps=sampling_per_token_logps,
        mask=loss_mask,
        off_policy_threshold=self.off_policy_mask_threshold,
    )
    loss_mask = loss_mask * off_policy_mask

    # Call the Liger fused kernel with OPSM-modified mask
    loss, metrics = self.liger_grpo_loss(
        _input=last_hidden_state,
        lin_weight=lm_weight,
        selected_token_ids=completion_ids,
        attention_mask=loss_mask,
        advantages=inputs["advantages"],
        bias=lm_bias,
        old_per_token_logps=inputs.get("old_per_token_logps"),
        ref_per_token_logps=inputs.get("ref_per_token_logps"),
        vllm_is_ratio=inputs.get("importance_sampling_ratio"),
    )

    # metrics[0] is read as the mean KL only when beta != 0; the clip ratio
    # is always the last entry.
    mean_kl = metrics[0] if self.beta != 0.0 else None
    clip_ratio = metrics[-1]

    mode = "train" if self.model.training else "eval"
    if self.beta != 0.0:
        self._metrics[mode]["kl"].append(
            self.accelerator.gather(mean_kl).mean().item()
        )
    self._metrics[mode]["clip_ratio"].append(
        self.accelerator.gather(clip_ratio).mean().item()
    )
    normalizer = (
        self.current_gradient_accumulation_steps if mode == "train" else 1.0
    )
    return loss / normalizer
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0) + return torch.tensor( + 0.0, device=inputs["advantages"].device, requires_grad=True + ) + return super()._compute_loss(model, inputs) diff --git a/src/axolotl/core/trainers/grpo/replay_buffer.py b/src/axolotl/core/trainers/grpo/replay_buffer.py new file mode 100644 index 000000000..9220fb6ff --- /dev/null +++ b/src/axolotl/core/trainers/grpo/replay_buffer.py @@ -0,0 +1,42 @@ +"""Simple replay buffer for storing and sampling high-signal rollout groups.""" + +import heapq + +import torch + + +class ReplayBuffer: + """Min-heap replay buffer that keeps the highest-scoring rollout groups. + Groups are scored by signal quality (advantage magnitude * reward variance). + When sampling, groups are drawn proportional to their scores. + """ + + def __init__(self, max_size: int): + self.max_size = max_size + self._heap = [] # min-heap of (score, id, data) + self._counter = 0 # unique tiebreaker for heap + + def __len__(self): + return len(self._heap) + + def add(self, score: float, data: dict): + """Add a group to the buffer. If full, replaces lowest-scoring entry.""" + self._counter += 1 + if len(self._heap) < self.max_size: + heapq.heappush(self._heap, (score, self._counter, data)) + elif score > self._heap[0][0]: + heapq.heapreplace(self._heap, (score, self._counter, data)) + + def sample(self, num_samples: int) -> list[dict] | None: + """Sample groups weighted by their scores. 
Returns None if buffer is empty.""" + if not self._heap: + return None + + scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32) + scores = scores.clamp(min=1e-8) # avoid zero probabilities + probs = scores / scores.sum() + replacement = num_samples > len(self._heap) + indices = torch.multinomial( + probs, num_samples, replacement=replacement + ).tolist() + return [self._heap[i][2] for i in indices] diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 219169816..3a95ad439 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -34,14 +34,13 @@ from trl.data_utils import ( is_conversational, maybe_apply_chat_template, ) - -from axolotl.monkeypatch.trainer.async_grpo import AsyncGRPOTrainer from trl.extras.profiling import profiling_context from trl.models import unwrap_model_for_generation from trl.trainer.grpo_config import GRPOConfig from trl.trainer.grpo_trainer import RewardFunc, nanstd from trl.trainer.utils import pad +from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import ( DistributedParallelMixin, @@ -74,11 +73,11 @@ class AxolotlAsyncGRPOTrainer( OptimizerMixin, OptimizerInitMixin, DistributedParallelMixin, - AsyncGRPOTrainer, + FastAsyncGRPOTrainer, ): - """Extend AsyncGRPOTrainer with axolotl helpers (async prefetch, streaming, IS correction).""" + """Extend AsyncGRPOTrainer with axolotl helpers""" - _tag_names = ["trl", "grpo", "async-grpo", "axolotl"] + _tag_names = ["trl", "grpo", "async", "axolotl"] class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 99cf6019d..ac377360a 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -234,7 +234,8 @@ class 
TRLConfig(BaseModel): }, ) vllm_importance_sampling_mode: ( - Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] | None + Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] + | None ) = Field( default=None, json_schema_extra={ @@ -243,9 +244,7 @@ class TRLConfig(BaseModel): ) vllm_importance_sampling_cap: float | None = Field( default=None, - json_schema_extra={ - "description": "Cap C for IS ratio clipping/masking." - }, + json_schema_extra={"description": "Cap C for IS ratio clipping/masking."}, ) off_policy_mask_threshold: float | None = Field( default=None, @@ -255,7 +254,53 @@ class TRLConfig(BaseModel): ) use_bias_correction_kl: bool | None = Field( default=None, + json_schema_extra={"description": "Apply IS correction to KL divergence term."}, + ) + + reward_num_workers: int = Field( + default=1, json_schema_extra={ - "description": "Apply IS correction to KL divergence term." + "description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its " + "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across " + "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions." + }, + ) + replay_buffer_size: int = Field( + default=0, + json_schema_extra={ + "description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout " + "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups " + "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True." + }, + ) + replay_recompute_logps: bool = Field( + default=True, + json_schema_extra={ + "description": "When True (default), recompute old_per_token_logps for replayed groups using the current " + "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. 
" + "Only relevant when replay_buffer_size > 0." + }, + ) + reroll_start_fraction: float = Field( + default=0.5, + json_schema_extra={ + "description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts " + "(where all rewards in a group are identical) are buffered and re-injected into later batches when the " + "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True." + }, + ) + reroll_max_groups: int = Field( + default=1, + json_schema_extra={ + "description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values " + "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True." + }, + ) + skip_zero_advantage_batches: bool = Field( + default=True, + json_schema_extra={ + "description": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning " + "signal). This saves ~0.6s per skipped batch by avoiding the forward/backward pass entirely. The step is " + "logged with skipped_zero_adv_batches=1 for monitoring." }, )