From 5ef3f283409d66762b45c8d638515823e4adf715 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Mar 2026 11:42:47 -0400 Subject: [PATCH] Support for Async GRPO (#3486) * async grpo support * implement data producer * use fast async * handle call to create data producer * fix liger kernel setup * fix replay buffer * chore: lint * make gpus go brrr * chore: lint * inplace div_, unwrap model for logits in bf16 * fuse selective softmax and empty cuda cache on each scoring step * remove waiting for synch time and fix race * make fp8 work and allow lora kernels w rl * grpo with lora vllm sync and fixes for sharded distributed * update docs * more patches so it works against trl main * address PR feedback for corerabbit --- docs/rlhf.qmd | 207 ++ src/axolotl/cli/vllm_serve.py | 33 +- src/axolotl/core/builders/rl.py | 58 +- src/axolotl/core/trainers/grpo/__init__.py | 78 +- src/axolotl/core/trainers/grpo/args.py | 8 + .../core/trainers/grpo/async_trainer.py | 2657 +++++++++++++++++ .../core/trainers/grpo/fast_async_trainer.py | 768 +++++ .../core/trainers/grpo/replay_buffer.py | 44 + src/axolotl/core/trainers/grpo/trainer.py | 14 + src/axolotl/kernels/lora.py | 8 +- src/axolotl/kernels/quantize.py | 54 +- src/axolotl/loaders/adapter.py | 12 + src/axolotl/loaders/model.py | 2 + src/axolotl/loaders/patch_manager.py | 12 + src/axolotl/monkeypatch/trainer/trl_vllm.py | 245 ++ src/axolotl/scripts/__init__.py | 0 src/axolotl/scripts/vllm_serve_lora.py | 503 ++++ src/axolotl/scripts/vllm_worker_ext.py | 158 + src/axolotl/utils/schemas/trl.py | 122 + src/axolotl/utils/schemas/validation.py | 14 - src/axolotl/utils/schemas/vllm.py | 7 + tests/core/test_async_grpo.py | 220 ++ tests/monkeypatch/test_trl_vllm.py | 286 ++ 23 files changed, 5474 insertions(+), 36 deletions(-) create mode 100644 src/axolotl/core/trainers/grpo/async_trainer.py create mode 100644 src/axolotl/core/trainers/grpo/fast_async_trainer.py create mode 100644 src/axolotl/core/trainers/grpo/replay_buffer.py create mode 100644 src/axolotl/monkeypatch/trainer/trl_vllm.py create mode 100644 src/axolotl/scripts/__init__.py create mode 100644 src/axolotl/scripts/vllm_serve_lora.py create mode 100644 src/axolotl/scripts/vllm_worker_ext.py create mode 100644 tests/core/test_async_grpo.py create mode 100644 tests/monkeypatch/test_trl_vllm.py diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 135b3038c..60c34933d 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -721,6 +721,213 @@ trl: For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types). +#### Async GRPO + +Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step. + +```yaml +trl: + use_data_producer: true # Enable data producer protocol + use_vllm: true + async_prefetch: true # Generate rollouts in background thread + prefetch_depth: 1 # Number of rollouts to prefetch + vllm_sync_interval: 2 # Sync weights to vLLM every N steps +``` + +::: {.callout-note} +Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled). 
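+
+A minimal sketch of the token-level correction (names are illustrative; `C` stands for `vllm_importance_sampling_cap`):
+
+```python
+# per-token importance ratio between the current policy and the stale vLLM sampler
+ratio = (policy_per_token_logps - vllm_per_token_logps).exp()
+weight = ratio.clamp(max=C).detach()  # "token_truncate" mode: cap the ratio at C
+per_token_loss = -weight * advantages.unsqueeze(-1) * policy_per_token_logps
+```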
+::: + +##### vLLM LoRA Sync + +By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels. + +```yaml +adapter: lora +lora_r: 32 +lora_alpha: 64 +lora_target_linear: true + +trl: + vllm_lora_sync: true # Enable native LoRA sync +``` + +When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual: + +```bash +CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml +``` + +Then start training on a separate GPU: + +```bash +CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml +``` + +::: {.callout-tip} +LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation. +::: + +##### Streaming Partial Batch + +Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring. + +```yaml +trl: + streaming_partial_batch: true +``` + +##### Importance Sampling Correction + +When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift. + +```yaml +trl: + vllm_importance_sampling_correction: true # Enable IS correction + importance_sampling_level: token # 'token' or 'sequence' + off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this +``` + +- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel) +- `importance_sampling_level: sequence` applies per-sequence IS ratios +- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy + +##### Replay Buffer + +The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches. + +```yaml +trl: + replay_buffer_size: 100 # Max cached groups (0 = disabled) + replay_recompute_logps: true # Recompute log-probs for replayed data (recommended) +``` + +::: {.callout-note} +When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data. +::: + +##### Deferred Re-rolling + +Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them. + +```yaml +trl: + reroll_start_fraction: 0.5 # Start re-rolling after 50% of training + reroll_max_groups: 1 # Max groups to replace per batch +``` + +##### Zero-Advantage Batch Skipping + +When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`. + +```yaml +trl: + skip_zero_advantage_batches: true # default +``` + +##### Parallel Reward Workers + +Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation. 
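+
+A rough sketch of the pattern, assuming a picklable reward function (`reward_fn` and `completions` are illustrative names):
+
+```python
+from concurrent.futures import ProcessPoolExecutor
+
+# Each worker subprocess has its own main thread, so signal.alarm()-based
+# timeouts inside the reward function are safe there.
+with ProcessPoolExecutor(max_workers=4) as pool:
+    rewards = list(pool.map(reward_fn, completions))
+```
+
+The worker count is configured with: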
+
+```yaml
+trl:
+  reward_num_workers: 4  # Number of subprocess workers (1 = no parallelism)
+```
+
+##### Full Async GRPO Example
+
+```yaml
+base_model: Qwen/Qwen2.5-1.5B-Instruct
+
+vllm:
+  host: 0.0.0.0
+  port: 8000
+  gpu_memory_utilization: 0.35
+  dtype: auto
+
+adapter: lora
+lora_r: 32
+lora_alpha: 64
+lora_target_linear: true
+
+rl: grpo
+trl:
+  use_data_producer: true
+  use_vllm: true
+  async_prefetch: true
+  prefetch_depth: 1
+  vllm_sync_interval: 2
+  vllm_lora_sync: true
+  streaming_partial_batch: true
+  vllm_importance_sampling_correction: true
+  off_policy_mask_threshold: 0.5
+  importance_sampling_level: token
+  num_generations: 8
+  max_completion_length: 512
+  reward_funcs:
+    - rewards.accuracy_reward
+  reroll_start_fraction: 0.5
+  replay_buffer_size: 100
+  reward_num_workers: 4
+  skip_zero_advantage_batches: true
+
+datasets:
+  - path: AI-MO/NuminaMath-TIR
+    type: rewards.prompt_transform
+    split: train
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+max_steps: 500
+learning_rate: 1e-5
+bf16: true
+gradient_checkpointing: true
+```
+
+```bash
+# Terminal 1: Start vLLM on GPU 0
+CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
+
+# Terminal 2: Train on GPU 1
+CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
+```
+
+##### Multi-GPU Async GRPO
+
+Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the other GPUs; with `vllm_lora_sync`, a trainer rank may also share the vLLM GPU, as in the launch example below.
+
+**FSDP:**
+
+```yaml
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+```
+
+**DeepSpeed ZeRO-3:**
+
+```yaml
+deepspeed: deepspeed_configs/zero3_bf16.json
+gradient_checkpointing_kwargs:
+  use_reentrant: true  # Required for ZeRO-3
+```
+
+```bash
+# Terminal 1: Start vLLM on GPU 0
+CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
+
+# Terminal 2: Train on GPUs 0,1
+CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
+```
+
+::: {.callout-important}
+With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
+:::
+
 ### GDPO
 
 GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index ea454fc96..10db23878 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -38,7 +38,18 @@ def do_vllm_serve( cfg = load_cfg(config) model = cfg.base_model - serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") + # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default + serve_module = cli_args.get("serve_module") or getattr( + cfg.vllm, "serve_module", None + ) + if ( + serve_module is None + and getattr(cfg, "trl", None) + and getattr(cfg.trl, "vllm_lora_sync", False) + ): + serve_module = "axolotl.scripts.vllm_serve_lora" + if serve_module is None: + serve_module = "trl.scripts.vllm_serve" vllm_serve_main = __import__(serve_module, fromlist=["main"]).main tensor_parallel_size = 1 data_parallel_size = 1 @@ -68,7 +79,7 @@ def do_vllm_serve( cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False ) - vllm_script_args = AxolotlScriptArguments( + base_kwargs = dict( model=model, tensor_parallel_size=tensor_parallel_size, data_parallel_size=data_parallel_size, @@ -78,7 +89,21 @@ def do_vllm_serve( dtype=dtype, max_model_len=max_model_len, enable_prefix_caching=enable_prefix_caching, - reasoning_parser=reasoning_parser, - enable_reasoning=enable_reasoning, ) + + # Use LoRAScriptArguments when serving with native LoRA support + if serve_module == "axolotl.scripts.vllm_serve_lora": + from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments + + lora_kwargs = {} + if hasattr(cfg, "lora_r") and cfg.lora_r: + lora_kwargs["max_lora_rank"] = cfg.lora_r + vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs) + else: + vllm_script_args = AxolotlScriptArguments( + **base_kwargs, + reasoning_parser=reasoning_parser, + enable_reasoning=enable_reasoning, + ) + vllm_serve_main(vllm_script_args) diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index bb67aef6d..43ef133ff 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: from axolotl.core.trainers.grpo import GRPOStrategy + async_grpo = bool( + self.cfg.trl + and ( + getattr(self.cfg.trl, "async_prefetch", False) + or getattr(self.cfg.trl, "use_data_producer", False) + ) + ) trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.context_parallel_size > 1 + sequence_parallel=self.cfg.context_parallel_size > 1, + async_grpo=async_grpo, ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) @@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: from axolotl.core.trainers.grpo import GRPOStrategy - training_args_cls = GRPOStrategy.get_training_args_class() + async_grpo = bool( + self.cfg.trl + and ( + getattr(self.cfg.trl, "async_prefetch", False) + or getattr(self.cfg.trl, "use_data_producer", False) + ) + ) + training_args_cls = GRPOStrategy.get_training_args_class( + async_grpo=async_grpo + ) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() if self.cfg.rl is RLType.GDPO: @@ -217,13 +234,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer_kwargs, trainer_cls ) - trainer = trainer_cls( - *trainer_cls_args, - args=training_args, - 
train_dataset=self.train_dataset, - callbacks=self.get_callbacks(), - **trainer_kwargs, - ) + # Allow FP8-quantized models to be fine-tuned with LoRA adapters. + # transformers' validate_quantization_for_training blocks FP8 because + # hf_quantizer.is_trainable is False, but LoRA only trains the adapters + # (base weights stay frozen in FP8). + _orig_validate_quant = None + if ( + self.cfg.adapter + and hasattr(self.model, "is_quantized") + and self.model.is_quantized + ): + import transformers.trainer as _trainer_module + + _orig_validate_quant = _trainer_module.validate_quantization_for_training + _trainer_module.validate_quantization_for_training = lambda model: None + + try: + trainer = trainer_cls( + *trainer_cls_args, + args=training_args, + train_dataset=self.train_dataset, + callbacks=self.get_callbacks(), + **trainer_kwargs, + ) + finally: + if _orig_validate_quant is not None: + import transformers.trainer as _trainer_module + + _trainer_module.validate_quantization_for_training = ( + _orig_validate_quant + ) if self.cfg.fsdp_config or self.cfg.fsdp: ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype) if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model: diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 0d2615aec..5c057cc40 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download from requests import HTTPError from trl.trainer.grpo_trainer import RewardFunc -from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig +from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig from axolotl.core.trainers.grpo.trainer import ( + AxolotlAsyncGRPOTrainer, AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer, ) @@ -27,14 +28,31 @@ class GRPOStrategy: @classmethod def get_trainer_class( - cls, sequence_parallel: bool - ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]: + cls, + sequence_parallel: bool, + async_grpo: bool = False, + ) -> ( + type[AxolotlGRPOTrainer] + | type[AxolotlGRPOSequenceParallelTrainer] + | type[AxolotlAsyncGRPOTrainer] + ): + if sequence_parallel and async_grpo: + raise ValueError( + "sequence_parallel and async_grpo cannot both be enabled. " + "Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer." 
+ ) if sequence_parallel: return AxolotlGRPOSequenceParallelTrainer + if async_grpo: + return AxolotlAsyncGRPOTrainer return AxolotlGRPOTrainer @classmethod - def get_training_args_class(cls) -> type[AxolotlGRPOConfig]: + def get_training_args_class( + cls, async_grpo: bool = False + ) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]: + if async_grpo: + return AxolotlAsyncGRPOConfig return AxolotlGRPOConfig @classmethod @@ -124,13 +142,63 @@ class GRPOStrategy: grpo_args_kwargs["epsilon_high"] = trl.epsilon_high if trl.use_liger_loss is not None: - grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss + grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss if trl.multi_objective_aggregation is not None: grpo_args_kwargs["multi_objective_aggregation"] = ( trl.multi_objective_aggregation ) + # Async GRPO fields + if getattr(trl, "use_data_producer", None) is not None: + grpo_args_kwargs["use_data_producer"] = trl.use_data_producer + if getattr(trl, "async_prefetch", None) is not None: + grpo_args_kwargs["async_prefetch"] = trl.async_prefetch + if getattr(trl, "prefetch_depth", None) is not None: + grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth + if getattr(trl, "vllm_sync_interval", None) is not None: + grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval + if getattr(trl, "streaming_partial_batch", None) is not None: + grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch + if getattr(trl, "streaming_min_groups", None) is not None: + grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups + if getattr(trl, "vllm_importance_sampling_correction", None) is not None: + grpo_args_kwargs["vllm_importance_sampling_correction"] = ( + trl.vllm_importance_sampling_correction + ) + if getattr(trl, "vllm_importance_sampling_mode", None) is not None: + grpo_args_kwargs["vllm_importance_sampling_mode"] = ( + trl.vllm_importance_sampling_mode + ) + if getattr(trl, "vllm_importance_sampling_cap", None) is not None: + grpo_args_kwargs["vllm_importance_sampling_cap"] = ( + trl.vllm_importance_sampling_cap + ) + if getattr(trl, "off_policy_mask_threshold", None) is not None: + grpo_args_kwargs["off_policy_mask_threshold"] = ( + trl.off_policy_mask_threshold + ) + if getattr(trl, "use_bias_correction_kl", None) is not None: + grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl + + # Fast Async GRPO fields + if getattr(trl, "reward_num_workers", None) is not None: + grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers + if getattr(trl, "replay_buffer_size", None) is not None: + grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size + if getattr(trl, "replay_recompute_logps", None) is not None: + grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps + if getattr(trl, "reroll_start_fraction", None) is not None: + grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction + if getattr(trl, "reroll_max_groups", None) is not None: + grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups + if getattr(trl, "skip_zero_advantage_batches", None) is not None: + grpo_args_kwargs["skip_zero_advantage_batches"] = ( + trl.skip_zero_advantage_batches + ) + if getattr(trl, "vllm_lora_sync", None) is not None: + grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync + return grpo_args_kwargs @classmethod diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 2ea52998e..f1dd5a6e7 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ 
b/src/axolotl/core/trainers/grpo/args.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from trl import GRPOConfig +from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig from axolotl.core.training_args import AxolotlTrainingMixins @@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """Axolotl GRPO Config for GRPO training""" context_parallel_size: int | None = None + + +@dataclass +class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig): + """Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction.""" + + context_parallel_size: int | None = None diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py new file mode 100644 index 000000000..acfd02909 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -0,0 +1,2657 @@ +""" +Async GRPO training with streaming scoring and IS correction. + +Works on stock TRL v0.29.0 and transformers v5.3.0 — no custom branches needed. + +Features: + - Async prefetch: background thread generates completions via vLLM while the main + thread trains on the previous rollout. + - Deferred scoring: rewards, advantages, and policy logprobs computed on the main + thread (thread-safe with GPU forward passes). + - Streaming group scoring: scores prompt groups incrementally so that reward + computation overlaps with the next group's logprob computation. + - Importance sampling (IS) correction: corrects for stale vLLM weights. + - Off-Policy Sequence Mask (OPSM): drops sequences with high KL + negative advantage. + - Configurable vLLM weight sync interval. + +Classes exported: + - AsyncGRPOConfig: GRPOConfig extended with async/streaming/IS fields + - AsyncGRPOTrainer: GRPOTrainer with async prefetch and IS correction + - ProducerConfig, DataProducer, BaseDataProducer, AsyncDataProducer: data producer protocol +""" + +import atexit +import concurrent.futures +import logging +import queue +import threading +from abc import ABC, abstractmethod +from collections import deque +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Any + +import torch +from torch.utils.data import DataLoader, Dataset +from trl.extras.profiling import profiling_decorator +from trl.trainer import GRPOConfig, GRPOTrainer +from trl.trainer.utils import ( + RepeatSampler, + entropy_from_logits, + nanmax, + nanmin, + nanstd, + pad, + selective_log_softmax, + shuffle_sequence_dict, + split_pixel_values_by_grid, + split_tensor_dict, + unsplit_pixel_values_by_grid, +) + +try: + from trl.data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, + ) +except ImportError: + from trl.chat_template_utils import apply_chat_template + from trl.data_utils import is_conversational, prepare_multimodal_messages + +try: + from trl.models.utils import disable_gradient_checkpointing +except ImportError: + from contextlib import contextmanager + + @contextmanager + def disable_gradient_checkpointing(model, kwargs): + yield + + +try: + from accelerate.utils import gather_object +except ImportError: + gather_object = None + +try: + from peft import PeftModel + from trl.trainer.utils import use_adapter +except ImportError: + PeftModel = None + use_adapter = nullcontext + +try: + from liger_kernel.ops.grpo_loss import ( + fused_selective_log_softmax as _fused_selective_log_softmax, + ) +except ImportError: + _fused_selective_log_softmax = None + + +# 
--------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class AsyncGRPOConfig(GRPOConfig): + """GRPOConfig extended with async prefetch, streaming scoring, and IS correction fields. + + Fields already present in stock GRPOConfig (e.g. ``importance_sampling_level``, + ``multi_objective_aggregation``) are listed here for safety: if the stock version + does not define them, the defaults below ensure everything works. + """ + + # --- Data producer --- + use_data_producer: bool = field( + default=False, + metadata={ + "help": "Use the GRPODataProducer protocol for online data generation." + }, + ) + + # --- Async data production --- + async_prefetch: bool = field( + default=False, + metadata={ + "help": "Generate rollouts in a background thread while training on the previous rollout." + }, + ) + prefetch_depth: int = field( + default=1, + metadata={"help": "Number of rollouts to prefetch ahead of training."}, + ) + vllm_sync_interval: int = field( + default=1, + metadata={ + "help": "Sync model weights to vLLM every N optimizer steps (async mode only)." + }, + ) + + # --- Streaming scoring --- + streaming_partial_batch: bool = field( + default=False, + metadata={ + "help": "Score prompt groups incrementally instead of the full batch at once." + }, + ) + streaming_min_groups: int = field( + default=1, + metadata={"help": "Minimum prompt groups to score per streaming chunk."}, + ) + + # --- vLLM importance sampling correction --- + vllm_importance_sampling_correction: bool = field( + default=True, + metadata={ + "help": "Apply IS correction for distribution mismatch between vLLM and training model." + }, + ) + vllm_importance_sampling_mode: str = field( + default="token_truncate", + metadata={ + "help": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask." + }, + ) + vllm_importance_sampling_cap: float = field( + default=3.0, + metadata={"help": "Cap C for IS ratio clipping/masking."}, + ) + + # --- Off-policy sequence mask (OPSM) --- + off_policy_mask_threshold: float | None = field( + default=None, + metadata={"help": "KL threshold for OPSM (DeepSeek-V3.2). None = disabled."}, + ) + + # --- Bias-corrected KL --- + use_bias_correction_kl: bool = field( + default=False, + metadata={"help": "Apply IS correction to KL divergence term."}, + ) + + +# --------------------------------------------------------------------------- +# Data Producer Protocol (standalone — no transformers branch needed) +# --------------------------------------------------------------------------- + +logger = logging.getLogger(__name__) +_dp_logger = logging.getLogger(__name__ + ".data_producer") + + +@dataclass +class ProducerConfig: + """Configuration for a :class:`DataProducer`. + + Args: + mini_epochs: Number of training passes over each produced dataset. + max_rollouts: Maximum number of produce-then-train rounds (None = unlimited). + steps_per_generation: Optimisation steps per produced dataset before regenerating. + num_iterations: Number of times to reuse each generation across optimisation steps. + async_prefetch: Produce the next dataset in a background thread. + prefetch_depth: How many rollouts to queue ahead when async. + sync_warmup_rollouts: Initial on-policy rollouts before switching to async. + eval_during_produce: Switch model to eval() during produce(). + empty_cache_before_produce: torch.cuda.empty_cache() before produce(). 
+ empty_cache_after_produce: torch.cuda.empty_cache() after produce(). + """ + + mini_epochs: int = 1 + max_rollouts: int | None = None + steps_per_generation: int | None = None + num_iterations: int = 1 + async_prefetch: bool = False + prefetch_depth: int = 1 + sync_warmup_rollouts: int = 0 + eval_during_produce: bool = True + empty_cache_before_produce: bool = False + empty_cache_after_produce: bool = False + + def __post_init__(self): + if self.mini_epochs < 1: + raise ValueError(f"mini_epochs must be >= 1, got {self.mini_epochs}") + if self.max_rollouts is not None and self.max_rollouts < 1: + raise ValueError( + f"max_rollouts must be >= 1 or None, got {self.max_rollouts}" + ) + if self.num_iterations < 1: + raise ValueError(f"num_iterations must be >= 1, got {self.num_iterations}") + if self.steps_per_generation is not None and self.steps_per_generation < 1: + raise ValueError( + f"steps_per_generation must be >= 1 or None, got {self.steps_per_generation}" + ) + if self.prefetch_depth < 1: + raise ValueError(f"prefetch_depth must be >= 1, got {self.prefetch_depth}") + if self.sync_warmup_rollouts < 0: + raise ValueError( + f"sync_warmup_rollouts must be >= 0, got {self.sync_warmup_rollouts}" + ) + + +class DataProducer(ABC): + """Abstract base class for online data producers. + + Subclass this and implement :meth:`produce` to supply fresh training data + each rollout round. + """ + + config: ProducerConfig + + @abstractmethod + def produce( + self, + model: Any, + global_step: int, + *, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + **kwargs, + ) -> Dataset: + """Generate a fresh training dataset.""" + ... + + +class BaseDataProducer(DataProducer): + """Convenience base class with a default :class:`ProducerConfig` and lifecycle hooks.""" + + def __init__(self, config: ProducerConfig | None = None): + self.config = config or ProducerConfig() + + def on_rollout_begin(self, global_step: int) -> None: + """Called before each produce() invocation.""" + + def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: + """Called after each produce() invocation with the produced dataset.""" + + +class AsyncDataProducer: + """Wraps a synchronous :class:`DataProducer` for background-thread data generation. + + While the Trainer trains on the current rollout, this wrapper produces upcoming + datasets in a background thread. + + FSDP compatibility: Background threads must NOT call cross-rank collectives + (gather_object, broadcast_object_list, FSDP all-gather) because the main thread + may be doing FSDP forward/backward concurrently, causing deadlocks. When + ``num_processes > 1``, only rank 0 runs BG generation; results are broadcast + to other ranks on the main thread when ``produce()`` is next called. + """ + + def __init__( + self, inner: DataProducer, background_produce_kwargs: dict | None = None + ): + self._inner = inner + self._depth = inner.config.prefetch_depth + self._warmup_remaining = inner.config.sync_warmup_rollouts + self._background_kwargs = background_produce_kwargs or {} + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="async-producer" + ) + self._queue: deque[concurrent.futures.Future] = deque() + self._initialized = False + # Lock held by the background thread during vLLM generation. + # The main thread acquires this lock for weight sync to ensure + # merge_adapter/unmerge_adapter don't overlap with generation. 
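+        # Rough timeline (illustrative):
+        #   BG thread:    |== generate rollout N+1 (holds lock) ==|
+        #   main thread:  |== train on rollout N ==|  weight sync waits on lock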
+ self._generate_lock = threading.Lock() + # Detected at first produce() call + self._num_processes: int | None = None + self._is_main: bool | None = None + + @property + def config(self) -> ProducerConfig: + return self._inner.config + + def produce(self, model: Any, global_step: int, **kwargs) -> Dataset: + """Return the next dataset, blocking if the prefetch hasn't finished.""" + # Detect multi-process on first call + if self._num_processes is None: + accelerator = kwargs.get("accelerator") + if accelerator is not None: + self._num_processes = accelerator.num_processes + self._is_main = accelerator.is_main_process + else: + self._num_processes = 1 + self._is_main = True + + # During warmup, produce synchronously (on-policy) + if self._warmup_remaining > 0: + self._warmup_remaining -= 1 + _dp_logger.info( + f"AsyncDataProducer: sync warmup rollout (remaining={self._warmup_remaining})" + ) + return self._inner.produce(model, global_step, **kwargs) + + if not self._initialized: + dataset = self._inner.produce(model, global_step, **kwargs) + bg_kwargs = {**kwargs, **self._background_kwargs} + # With FSDP (multi-process), only submit BG tasks on rank 0. + # Non-rank-0 processes will receive data via broadcast. + if self._num_processes > 1: + bg_kwargs["_rank0_only"] = True + for i in range(1, self._depth + 1): + self._queue.append( + self._executor.submit( + self._locked_produce, model, global_step + i, **bg_kwargs + ) + ) + self._initialized = True + return dataset + + # Get the pre-generated dataset from the BG thread + dataset = self._queue.popleft().result() + + # With FSDP: BG thread only ran on rank 0. Broadcast to all ranks. + if self._num_processes > 1: + dataset = self._broadcast_dataset(dataset) + + bg_kwargs = {**kwargs, **self._background_kwargs} + if self._num_processes > 1: + bg_kwargs["_rank0_only"] = True + next_step = global_step + self._depth + self._queue.append( + self._executor.submit(self._locked_produce, model, next_step, **bg_kwargs) + ) + return dataset + + def _broadcast_dataset(self, dataset) -> Dataset: + """Broadcast a prefetched dataset from rank 0 to all ranks (main thread). + + Rank 0 has a full RolloutDataset from BG generation; other ranks have None. + After broadcast, tensors are moved to each rank's local device. 
+ """ + import torch.distributed as dist + + if not dist.is_initialized(): + return dataset + + # Rank 0 sends _data dict; others receive it + obj_list = [dataset._data if self._is_main else None] + dist.broadcast_object_list(obj_list, src=0) + + data: dict[str, Any] = obj_list[0] # type: ignore[assignment] + + # Move tensors to local device (broadcast_object_list deserializes to CPU) + accelerator = self._inner._trainer.accelerator # type: ignore[attr-defined] + device = accelerator.device + for key, val in data.items(): + if isinstance(val, torch.Tensor) and val.device != device: + data[key] = val.to(device) + + if not self._is_main: + from axolotl.core.trainers.grpo.async_trainer import RolloutDataset + + dataset = RolloutDataset(data) + else: + # Rank 0 already has the dataset, but update _data with device-moved tensors + dataset._data = data + return dataset + + def _locked_produce(self, model: Any, global_step: int, **kwargs) -> Dataset: + """Run produce while holding the generate lock.""" + with self._generate_lock: + return self._inner.produce(model, global_step, **kwargs) + + def on_rollout_begin(self, global_step: int) -> None: + if hasattr(self._inner, "on_rollout_begin"): + self._inner.on_rollout_begin(global_step) + + def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: + if hasattr(self._inner, "on_rollout_end"): + self._inner.on_rollout_end(dataset, global_step) + + def shutdown(self) -> None: + """Shut down the background thread pool and cancel pending futures.""" + for future in self._queue: + future.cancel() + self._queue.clear() + self._executor.shutdown(wait=False) + + +class DataProducerCallback: + """Marker class: if a DataProducer also inherits from this, the Trainer will + automatically register it as a callback.""" + + pass + + +# --------------------------------------------------------------------------- +# RolloutDataset + GRPODataProducer +# --------------------------------------------------------------------------- + + +class RolloutDataset(Dataset): + """A Dataset wrapping the output dict from _generate_and_score_completions. + + Per-sample tensors are sliced by index; shared metadata is passed through. 
+ """ + + _ALWAYS_SHARED = frozenset( + {"num_items_in_batch", "_pending_policy_logps", "_rank0_only"} + ) + + def __init__(self, data: dict[str, Any]): + self._data = data + self._shared_keys: set[str] = set() + self._sample_keys: set[str] = set() + + for key, val in data.items(): + if key in self._ALWAYS_SHARED: + self._shared_keys.add(key) + elif not isinstance(val, torch.Tensor): + self._shared_keys.add(key) + elif val.dim() == 0: + self._shared_keys.add(key) + else: + self._sample_keys.add(key) + + self._num_samples = 0 + for key in self._sample_keys: + n = data[key].size(0) + if self._num_samples == 0: + self._num_samples = n + elif n != self._num_samples: + raise ValueError( + f"Inconsistent sample count: key '{key}' has {n}, expected {self._num_samples}" + ) + if self._num_samples == 0: + raise ValueError("No per-sample tensors found in rollout data") + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, idx: int) -> dict[str, Any]: + item: dict[str, Any] = {} + for key in self._sample_keys: + item[key] = self._data[key][idx] + for key in self._shared_keys: + item[key] = self._data[key] + return item + + +def make_rollout_collator(shared_keys: set[str]): + """Return a collator that stacks per-sample tensors and passes shared keys through.""" + + def _collate(batch: list[dict[str, Any]]) -> dict[str, Any]: + result: dict[str, Any] = {} + for key in batch[0]: + if key in shared_keys: + result[key] = batch[0][key] + else: + values = [item[key] for item in batch] + if isinstance(values[0], torch.Tensor): + result[key] = torch.stack(values) + else: + result[key] = values + return result + + return _collate + + +class GRPODataProducer(BaseDataProducer): + """Produces GRPO training rollouts using the trainer's generation pipeline. + + Created before Trainer.__init__ completes; the trainer reference is injected + later via set_trainer(). 
+ """ + + def __init__( + self, + config: ProducerConfig, + prompt_dataset, + *, + num_generations: int, + generation_batch_size: int, + train_batch_size: int, + steps_per_generation: int, + shuffle_dataset: bool, + seed: int, + ): + super().__init__(config) + self._dataset = prompt_dataset + self._num_generations = num_generations + self._generation_batch_size = generation_batch_size + self._train_batch_size = train_batch_size + self._steps_per_generation = steps_per_generation + self._shuffle_dataset = shuffle_dataset + self._seed = seed + self._trainer: Any = None + self._prompt_dl: Any = None + self._prompt_iter: Any = None + + def set_trainer(self, trainer) -> None: + """Inject the live trainer reference and create the prompt DataLoader.""" + self._trainer = trainer + self._init_prompt_dataloader() + + def _init_prompt_dataloader(self) -> None: + from functools import partial + + from transformers.trainer_utils import seed_worker + + trainer = self._trainer + sampler = RepeatSampler( + data_source=self._dataset, + mini_repeat_count=self._num_generations, + batch_size=self._generation_batch_size // self._num_generations, + repeat_count=1, + shuffle=self._shuffle_dataset, + seed=self._seed, + ) + + # Use identity collator (same as stock GRPOTrainer) + def _identity(x): + return x + + dl = DataLoader( + self._dataset, + batch_size=self._train_batch_size * self._steps_per_generation, + sampler=sampler, + collate_fn=_identity, + num_workers=trainer.args.dataloader_num_workers, + pin_memory=trainer.args.dataloader_pin_memory, + persistent_workers=trainer.args.dataloader_persistent_workers, + worker_init_fn=partial( + seed_worker, + num_workers=trainer.args.dataloader_num_workers, + rank=trainer.args.process_index, + ), + ) + self._prompt_dl = trainer.accelerator.prepare(dl) + + # Don't let accelerator track this dataloader + acc_dls = trainer.accelerator._dataloaders + if self._prompt_dl in acc_dls: + acc_dls.remove(self._prompt_dl) + + self._prompt_iter = iter(self._prompt_dl) + + def produce( + self, + model: Any, + global_step: int, + *, + skip_policy_logps: bool = False, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + _rank0_only: bool = False, + **kwargs, + ) -> RolloutDataset | None: + """Generate a fresh GRPO training rollout.""" + is_main = self._trainer.accelerator.is_main_process + + # FSDP rank0-only mode: non-rank-0 returns None (broadcast fills it later) + if _rank0_only and not is_main: + return None + + try: + inputs = next(self._prompt_iter) + except StopIteration: + self._prompt_iter = iter(self._prompt_dl) + inputs = next(self._prompt_iter) + + if skip_policy_logps: + # Async path: use _generate_only (generation without scoring) which + # works on stock TRL (no skip_policy_logps parameter needed). 
+ output = self._trainer._generate_only(inputs, rank0_only=_rank0_only) + else: + # Sync path: full generation + scoring + output = self._trainer._generate_and_score_completions(inputs) + + # Strip non-sequence metadata before shuffling + metadata = {} + for key in list(output.keys()): + val = output[key] + if not isinstance(val, (torch.Tensor, list)): + metadata[key] = output.pop(key) + elif isinstance(val, torch.Tensor) and val.dim() == 0: + metadata[key] = output.pop(key) + + output = shuffle_sequence_dict(output) + output.update(metadata) + + return RolloutDataset(output) + + +# --------------------------------------------------------------------------- +# Trainer +# --------------------------------------------------------------------------- + + +class AsyncGRPOTrainer(GRPOTrainer): + """GRPOTrainer with async prefetch, streaming scoring, and IS correction. + + Drop-in replacement: pass ``AsyncGRPOConfig`` as ``args`` and use this trainer + instead of ``GRPOTrainer``. + """ + + def __init__(self, *args, **kwargs): + # When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration. + # The communicator is not needed because weight sync happens via filesystem + HTTP, + # and it fails when vLLM and a trainer rank share the same CUDA device. + training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None) + if training_args is not None and getattr( + training_args, "vllm_lora_sync", False + ): + from trl.generation.vllm_generation import VLLMGeneration + + _orig_init_vllm = VLLMGeneration._init_vllm + + def _init_vllm_no_communicator(self_vllm): + """Init vLLM client without NCCL communicator (LoRA sync uses filesystem).""" + if self_vllm.mode == "server" and self_vllm.accelerator.is_main_process: + from trl.generation.vllm_client import VLLMClient + + if self_vllm.server_base_url is not None: + base_url = self_vllm.server_base_url + else: + base_url = ( + f"http://{self_vllm.server_host}:{self_vllm.server_port}" + ) + self_vllm.vllm_client = VLLMClient( + base_url=base_url, + group_port=self_vllm.group_port, + connection_timeout=self_vllm.server_timeout, + ) + # Deliberately skip init_communicator — no NCCL needed + elif self_vllm.mode != "server": + _orig_init_vllm(self_vllm) + + VLLMGeneration._init_vllm = _init_vllm_no_communicator + + super().__init__(*args, **kwargs) + + # FP8 models: zero out the pad token embedding so that padding + # positions have zero hidden states throughout the network. + # FP8 linear layers produce NaN on non-zero inputs at masked + # positions (the Triton fp8 matmul kernel can't handle the + # extreme values that accumulate at unattended positions). + self._zero_pad_embedding_for_fp8() + + # Ensure custom attributes exist (stock GRPOTrainer.__init__ may not set them). 
+ for attr, cfg_key, default in [ + ( + "vllm_importance_sampling_correction", + "vllm_importance_sampling_correction", + True, + ), + ( + "vllm_importance_sampling_mode", + "vllm_importance_sampling_mode", + "token_truncate", + ), + ("vllm_importance_sampling_cap", "vllm_importance_sampling_cap", 3.0), + ("off_policy_mask_threshold", "off_policy_mask_threshold", None), + ]: + if not hasattr(self, attr): + setattr(self, attr, getattr(self.args, cfg_key, default)) + + # Async state + self._async_queue: queue.Queue | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None + self._prompt_iter = None + self._last_synced_step = -1 + self._buffered_inputs: list | None = None # override stock attr + + # Data producer (the proper architecture for async generation) + self.data_producer = None + if getattr(self.args, "use_data_producer", False): + self.data_producer = self._create_data_producer( + kwargs["args"], kwargs["train_dataset"] + ) + + if self.args.async_prefetch and self.data_producer is None: + # Legacy path: direct _prepare_inputs override without data producer + self._setup_async() + + def _create_data_producer(self, args, train_dataset): + """Create and return the GRPODataProducer (possibly wrapped in AsyncDataProducer).""" + producer_config = ProducerConfig( + mini_epochs=args.num_iterations, + max_rollouts=None, + eval_during_produce=False, + empty_cache_before_produce=True, + empty_cache_after_produce=True, + async_prefetch=args.async_prefetch, + prefetch_depth=args.prefetch_depth, + ) + data_producer = GRPODataProducer( + config=producer_config, + prompt_dataset=train_dataset, + num_generations=self.num_generations, + generation_batch_size=args.generation_batch_size, + train_batch_size=args.per_device_train_batch_size, + steps_per_generation=args.steps_per_generation, + shuffle_dataset=getattr(self, "shuffle_dataset", True), + seed=args.seed, + ) + data_producer.set_trainer(self) + + if args.async_prefetch: + data_producer = AsyncDataProducer( + data_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + return data_producer + + # ------------------------------------------------------------------ + # Async setup / teardown + # ------------------------------------------------------------------ + + def _setup_async(self): + """Create background thread pool, prompt iterator, and pre-fill the async queue.""" + gen_batch_size = getattr( + self.args, + "generation_batch_size", + self._train_batch_size * self.args.gradient_accumulation_steps, + ) + # RepeatSampler groups prompts with num_generations repetitions each. + # DataLoader batches the yielded indices into generation-sized batches. 
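+        # Illustrative sizes: with gen_batch_size=8 and num_generations=4, the
+        # sampler yields indices p0,p0,p0,p0,p1,p1,p1,p1 and the DataLoader
+        # emits them as a single generation batch.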
+ sampler = RepeatSampler( + data_source=self.train_dataset, + mini_repeat_count=self.num_generations, + batch_size=gen_batch_size // self.num_generations, + repeat_count=10_000, # effectively infinite + shuffle=True, + seed=self.args.seed, + ) + self._prompt_dataloader = DataLoader( + self.train_dataset, + batch_size=gen_batch_size, + sampler=sampler, + collate_fn=self.data_collator, + num_workers=0, + ) + self._prompt_iter = iter(self._prompt_dataloader) + self._async_queue = queue.Queue(maxsize=self.args.prefetch_depth) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + # Pre-submit generations to fill the queue + for _ in range(self.args.prefetch_depth): + self._submit_generation() + + atexit.register(self._shutdown_async) + + def _shutdown_async(self): + if self._executor is not None: + self._executor.shutdown(wait=False, cancel_futures=True) + self._executor = None + + def _submit_generation(self): + """Submit the next background generation job.""" + batch = next(self._prompt_iter) + future = self._executor.submit(self._generate_only, batch) + self._async_queue.put(future) + + # ------------------------------------------------------------------ + # Weight sync + # ------------------------------------------------------------------ + + def _sync_peft_weights_no_merge(self): + """Thread-safe weight sync: compute merged LoRA weights without in-place modification. + + Required for FP8 models where merge_adapter() fails (addmm not implemented + for Float8), and also safe for concurrent use since it never modifies base + weights in-place. + """ + model = self.vllm_generation.model + accelerator = self.vllm_generation.accelerator + vllm_client = self.vllm_generation.vllm_client + fix_name = self.vllm_generation._fix_param_name_to_vllm + + if not (self.vllm_generation.mode == "server" and accelerator.is_main_process): + return + + # Build lookup: module_path -> (A, B, scaling) for all active LoRA layers + lora_info = {} + for mod_name, module in model.base_model.model.named_modules(): + if not hasattr(module, "lora_A") or not hasattr(module, "active_adapters"): + continue + active = module.active_adapters[0] + if active not in module.lora_A: + continue + lora_info[mod_name] = ( + module.lora_A[active].weight.data, + module.lora_B[active].weight.data, + module.scaling[active], + ) + + # Build lookup for FP8 scale_inv parameters (needed for dequantization) + scale_inv_lookup = {} + for pname, pparam in model.named_parameters(): + if "weight_scale_inv" in pname: + # Map weight name -> scale_inv tensor + weight_name = pname.replace(".weight_scale_inv", ".weight") + scale_inv_lookup[weight_name] = pparam.data + + # Iterate all parameters, computing merged weights for LoRA layers. + # Skip LoRA-specific params and FP8 scale params (scales will be + # recomputed by vLLM when it receives the merged bf16 weight). 
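+        # For each base weight W (dequantized to bf16 if stored in FP8), the
+        # weight shipped to vLLM is:  W_merged = W + scaling * (B @ A)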
+ params_to_sync = [] + for name, param in model.named_parameters(): + vllm_name = name.removeprefix("base_model.model.").replace( + ".base_layer", "" + ) + if model.prefix in vllm_name: + continue + if "original_module" in vllm_name: + continue + # Skip FP8 quantization scale parameters - they are recomputed + # on the vLLM side when we update the weight itself + if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name: + continue + vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."]) + + data = param.data + compute_dtype = torch.bfloat16 + + if vllm_name.endswith(".weight"): + # Dequantize FP8 weights before merging + if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup: + scale_inv = scale_inv_lookup[name] + # Block dequantization: weight * scale_inv (with broadcasting) + fp8_bf16 = data.to(compute_dtype) + if scale_inv.dim() == 2 and fp8_bf16.dim() == 2: + # Block-quantized: scale_inv shape (rows/block, cols/block) + sr, sc = scale_inv.shape + br = fp8_bf16.shape[0] // sr # block height + bc = fp8_bf16.shape[1] // sc # block width + # Reshape → multiply by block scale → reshape back + data = ( + fp8_bf16.reshape(sr, br, sc, bc) + * scale_inv[:, None, :, None].to(compute_dtype) + ).reshape(fp8_bf16.shape) + elif scale_inv.dim() <= 1: + # Per-tensor or per-channel scale + data = fp8_bf16 * scale_inv.to(compute_dtype) + else: + data = fp8_bf16 + elif data.dtype == torch.float8_e4m3fn: + # FP8 but no scale found - just cast (lossy) + data = data.to(compute_dtype) + + mod_path = vllm_name[: -len(".weight")] + if mod_path in lora_info: + A, B, s = lora_info[mod_path] + merged = data.to(compute_dtype) + s * ( + B.to(compute_dtype) @ A.to(compute_dtype) + ) + data = merged + + params_to_sync.append((vllm_name, data)) + + # Batch sync all params in one HTTP+NCCL call (vs individual calls) + if params_to_sync: + vllm_client.batch_update_named_params(params_to_sync) + + # Reset prefix cache after weight update + vllm_client.reset_prefix_cache() + + def _sync_lora_adapter(self): + """Sync LoRA adapter to vLLM via filesystem (native LoRA mode). + + Saves the PEFT adapter to a temp directory and POSTs the path to vLLM's + /set_lora_adapter/ endpoint. vLLM loads the adapter natively using Punica + kernels, avoiding the need to merge weights and NCCL-broadcast the full model. + + Syncs only the LoRA adapter weights via filesystem instead of the full merged model via NCCL. + + FSDP/DeepSpeed: All ranks must participate in the state_dict gather. + accelerator.get_state_dict() handles this (FSDP uses FullStateDictConfig + with rank0_only=True). Only rank 0 gets the full dict, writes files, and + does the HTTP POST. 
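+
+        Illustrative request (mirrors the POST issued below)::
+
+            POST {base_url}/set_lora_adapter/
+            {"lora_name": "active_lora",
+             "lora_int_id": <version>,
+             "lora_path": "<sync_dir>/v<version>"}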
+ """ + import os + import tempfile + + accelerator = self.vllm_generation.accelerator + model = self.vllm_generation.model + + if self.vllm_generation.mode != "server": + return + + is_main = accelerator.is_main_process + + # Increment adapter version (all ranks, kept in sync) + if not hasattr(self, "_lora_sync_version"): + self._lora_sync_version = 0 + if is_main: + self._lora_sync_dir = tempfile.mkdtemp(prefix="lora_sync_") + else: + self._lora_sync_dir = None + # Broadcast sync dir from rank 0 to all ranks + if accelerator.num_processes > 1: + import torch.distributed as dist + + if dist.is_initialized(): + obj_list = [self._lora_sync_dir] + dist.broadcast_object_list(obj_list, src=0) + self._lora_sync_dir = obj_list[0] + self._lora_sync_version += 1 + + adapter_path = os.path.join(self._lora_sync_dir, f"v{self._lora_sync_version}") + + # Gather state dict from all ranks (FSDP/DeepSpeed gather, rank0_only) + # All ranks must participate even though only rank 0 gets the result. + # Use self.model_wrapped (the DeepSpeed/FSDP engine) for get_state_dict, + # since it has the necessary hooks (e.g. zero_gather_16bit_weights_on_model_save). + # self.vllm_generation.model is the unwrapped PEFT model which lacks these. + wrapped_model = getattr(self, "model_wrapped", model) + state_dict = accelerator.get_state_dict(wrapped_model) + + if is_main: + # Unwrap to access PEFT's save_pretrained + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(adapter_path, state_dict=state_dict) + + import requests + + vllm_client = self.vllm_generation.vllm_client + url = f"{vllm_client.base_url}/set_lora_adapter/" + response = requests.post( + url, + json={ + "lora_name": "active_lora", + "lora_int_id": self._lora_sync_version, + "lora_path": adapter_path, + }, + timeout=30, + ) + if response.status_code != 200: + logger.warning( + "Failed to set LoRA adapter: %s %s", + response.status_code, + response.text, + ) + return + + # Reset prefix cache after adapter update + vllm_client.reset_prefix_cache() + + # Clean up old adapter versions (keep only current) + if self._lora_sync_version > 1: + old_path = os.path.join( + self._lora_sync_dir, f"v{self._lora_sync_version - 1}" + ) + if os.path.exists(old_path): + import shutil + + shutil.rmtree(old_path, ignore_errors=True) + + logger.info( + "Synced LoRA adapter v%d to vLLM (%s)", + self._lora_sync_version, + adapter_path, + ) + + # Barrier to ensure all ranks complete before resuming forward passes. + # Without this, rank 1 may start a forward pass (triggering FSDP unshard) + # while rank 0 is still doing save_pretrained, causing FSDP all-gather deadlock. + if accelerator.num_processes > 1: + import torch.distributed as dist + + if dist.is_initialized(): + dist.barrier() + + def _maybe_sync_vllm_weights(self): + """Sync model weights to vLLM if the interval has elapsed. 
+ + Dispatches to one of three strategies: + - vllm_lora_sync: saves adapter to filesystem, vLLM loads natively + - PEFT no-merge: computes merged weights as new tensors, NCCL broadcast + - Non-PEFT: stock sync_weights via merge_adapter + NCCL + """ + if not (self.use_vllm and self.args.async_prefetch): + return + step = self.state.global_step + interval = self.args.vllm_sync_interval + if step != self._last_synced_step and step % interval == 0: + if getattr(self.args, "vllm_lora_sync", False): + if step == 0: + logger.info("Skipping LoRA sync at step 0 (no training yet)") + self._last_synced_step = step + return + # Native LoRA sync: save adapter to filesystem, vLLM loads it directly + self._sync_lora_adapter() + else: + from accelerate.utils import is_peft_model + + use_no_merge = is_peft_model(self.vllm_generation.model) + + if use_no_merge: + # No-merge sync: computes merged weights as new tensors + # (doesn't modify base weights in-place), so it's safe to + # run concurrently with BG generation — no lock needed. + self._sync_peft_weights_no_merge() + else: + # Non-PEFT: use stock sync (acquires lock to avoid overlap) + if self.data_producer is not None and hasattr( + self.data_producer, "_generate_lock" + ): + with self.data_producer._generate_lock: + self.vllm_generation.sync_weights() + elif self._async_queue is not None: + pending = list(self._async_queue.queue) + for f in pending: + if isinstance(f, concurrent.futures.Future): + f.result() + self.vllm_generation.sync_weights() + else: + self.vllm_generation.sync_weights() + self._last_synced_step = step + + def _zero_pad_embedding_for_fp8(self): + """Zero out the pad token embedding for FP8 models. + + FP8 linear layers produce NaN when processing positions with + attention_mask=0 (the hidden states at those positions have + unconstrained values that overflow FP8 range during + quantization). By setting the pad token embedding to zeros, + padding positions start with zero hidden states and stay zero + through masked attention, preventing NaN from FP8 matmul. 
+ """ + model = self.accelerator.unwrap_model(self.model) + # Check if model has FP8 weights + has_fp8 = any( + p.dtype == torch.float8_e4m3fn + for p in model.parameters() + if not p.requires_grad + ) + if not has_fp8: + return + + # Find the embedding layer + if hasattr(model, "model") and hasattr(model.model, "embed_tokens"): + embed = model.model.embed_tokens + elif hasattr(model, "base_model") and hasattr(model.base_model, "model"): + m = model.base_model.model + if hasattr(m, "model") and hasattr(m.model, "embed_tokens"): + embed = m.model.embed_tokens + else: + return + else: + return + + pad_id = self.processing_class.pad_token_id + if pad_id is not None and pad_id < embed.weight.shape[0]: + with torch.no_grad(): + embed.weight.data[pad_id].zero_() + import logging + + logging.getLogger("async_grpo").info( + f"Zeroed pad token embedding (id={pad_id}) for FP8 NaN prevention" + ) + + # ------------------------------------------------------------------ + # Background-thread generation (no scoring) + # ------------------------------------------------------------------ + + def _generate_single_turn(self, prompts, **kwargs): + """Override to prevent weight sync from background thread and to use + no-merge sync for PEFT models (FP8 models can't merge_adapter).""" + is_bg = threading.current_thread() is not threading.main_thread() + saved_step = None + + if is_bg and self.use_vllm: + # Trick: match _last_loaded_step so the stock sync check is a no-op + saved_step = getattr(self, "_last_loaded_step", None) + self._last_loaded_step = self.state.global_step + + # Permanently replace vllm_generation.sync_weights with our custom + # sync to avoid merge_adapter (fails on FP8 / races with training). + # For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights + # handles the sync with proper interval tracking. + if not getattr(self, "_patched_sync_weights", False): + if self.use_vllm and hasattr(self, "vllm_generation"): + if getattr(self.args, "vllm_lora_sync", False): + # No-op: LoRA sync is driven by _maybe_sync_vllm_weights + self.vllm_generation.sync_weights = lambda: None + self._patched_sync_weights = True + else: + from accelerate.utils import is_peft_model + + if is_peft_model(self.vllm_generation.model): + + def _no_merge_sync(): + self._sync_peft_weights_no_merge() + + self.vllm_generation.sync_weights = _no_merge_sync + self._patched_sync_weights = True + + try: + return super()._generate_single_turn(prompts, **kwargs) + finally: + if saved_step is not None: + self._last_loaded_step = saved_step + + def _generate_rank0_only(self, prompts): + """Generate using vLLM directly on rank 0 without cross-rank collectives. + + Called from BG thread in FSDP mode. Bypasses ``gather_object`` / + ``broadcast_object_list`` since the main thread may be running FSDP + collectives concurrently. + + Returns the same tuple as ``_generate``. 
+ """ + import copy + + prompts = copy.deepcopy(prompts) + + # Duplicate prompts for num_generations (same as TRL's gather+unique pattern) + num_generations = self.num_generations + unique_prompts = prompts[::num_generations] + + # Build sampling params + vg = self.vllm_generation + sampling_params = { + "n": num_generations, + "repetition_penalty": vg.repetition_penalty, + "temperature": vg.temperature, + "top_p": vg.top_p, + "top_k": vg.top_k, + "min_p": 0.0 if vg.min_p is None else vg.min_p, + "max_tokens": vg.max_completion_length, + "logprobs": vg.logprobs, + "structured_outputs_regex": vg.structured_outputs_regex, + "generation_kwargs": vg.generation_kwargs, + } + + # Call vLLM directly (no collectives) + from trl.data_utils import is_conversational + + if is_conversational({"prompt": unique_prompts[0]}): + output = vg.vllm_client.chat( + messages=unique_prompts, + **sampling_params, + chat_template_kwargs=vg.chat_template_kwargs, + tools=vg.tools, + chat_template=vg.chat_template, + ) + else: + output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params) + + # vLLM returns 1 prompt_ids per unique prompt, but num_generations completion_ids. + # Duplicate prompt_ids to match completions (one per generation). + raw_prompt_ids = output["prompt_ids"] + prompt_ids = [pid for pid in raw_prompt_ids for _ in range(num_generations)] + completion_ids = output["completion_ids"] + logprobs_raw = output["logprobs"] + extra_fields = { + k: v + for k, v in output.items() + if k + not in {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"} + } + + # Extract top-1 logprob per token + logprobs = [[lp[0] for lp in seq] for seq in logprobs_raw] + + # Decode completions + if is_conversational({"prompt": prompts[0]}): + contents = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + completions = [[{"role": "assistant", "content": c}] for c in contents] + else: + completions = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + + tool_mask = extra_fields.pop("env_mask", None) + + # Compute total completion tokens locally (no gather) + total_completion_tokens = sum(len(ids) for ids in completion_ids) + + return ( + prompt_ids, + completion_ids, + tool_mask, + completions, + total_completion_tokens, + logprobs, + extra_fields, + ) + + def _generate_only(self, inputs, rank0_only=False): + """Generate completions without scoring. Runs on background thread. + + Mirrors the first half of ``_generate_and_score_completions`` (prompt + extraction → vLLM generation → tensor padding) and returns a deferred + output dict for main-thread scoring. + + When ``rank0_only=True`` (FSDP mode), bypasses ``gather_object`` / + ``broadcast_object_list`` collectives and calls vLLM directly on rank 0. + Results are broadcast to other ranks on the main thread later. + + Args: + inputs: list of dicts (one per sample), as yielded by the DataLoader + with ``identity`` collate_fn. 
+ """ + device = self.accelerator.device + + prompts = [x["prompt"] for x in inputs] + + # --- Handle images (multimodal) --- + if "images" in inputs[0]: + images = [ex.get("images") for ex in inputs] + elif "image" in inputs[0]: + images = [ + [ex.get("image")] if ex.get("image") is not None else None + for ex in inputs + ] + else: + images = None + if images is not None and all(img == [] for img in images): + images = None + + if images is not None: + if not is_conversational(inputs[0]): + raise ValueError("Multimodal training requires conversational prompts.") + prompts = [ + prepare_multimodal_messages(p, il) + for p, il in zip(prompts, images, strict=True) + ] + + # --- Generate completions --- + if rank0_only: + # FSDP mode: call vLLM directly without cross-rank collectives + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + ) = self._generate_rank0_only(prompts) + else: + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + ) = self._generate(prompts) + # _generate gathers prompts from all ranks internally. Gather inputs + # to match the full-batch output size. + if self.accelerator.num_processes > 1: + from accelerate.utils import gather_object + + inputs = gather_object(inputs) + prompts = [x["prompt"] for x in inputs] + + # --- Pad to tensors --- + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad( + prompt_ids, padding_value=self.pad_token_id, padding_side="left" + ) + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + + completion_ids = [ + torch.tensor(ids, device=device) for ids in completion_ids_list + ] + completion_mask = [ + torch.ones_like(ids, dtype=torch.long) for ids in completion_ids + ] + completion_ids = pad( + completion_ids, padding_value=self.pad_token_id, padding_side="right" + ) + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + + if sampling_per_token_logps_list is not None: + sampling_logps = [ + torch.tensor(lp, device=device) for lp in sampling_per_token_logps_list + ] + sampling_per_token_logps = pad( + sampling_logps, padding_value=0.0, padding_side="right" + ) + else: + sampling_per_token_logps = None + + if tool_mask_list is not None: + tool_mask = [torch.tensor(m, device=device) for m in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") + else: + tool_mask = None + + # --- Mask truncated completions --- + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_trunc = torch.tensor( + [ids[-1] not in eos_and_pad for ids in completion_ids_list], + device=device, + ) + completion_mask = completion_mask * (~is_trunc).unsqueeze(1).int() + if tool_mask is not None: + tool_mask = tool_mask * (~is_trunc).unsqueeze(1).int() + + # --- Multimodal forward kwargs --- + num_images = [len(il) for il in images] if images is not None else None + if images is not None: + prompts_text = [ + apply_chat_template( + {"prompt": p}, + self.processing_class, + tools=self.tools, + **self.chat_template_kwargs, + )["prompt"] + for p in prompts + ] + prompt_inputs = self.processing_class( + images=images, text=prompts_text, padding=True, return_tensors="pt" + ) + forward_kwargs = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for 
k, v in prompt_inputs.items() + if k not in ("input_ids", "attention_mask") + } + else: + forward_kwargs = {} + + # Extend token_type_ids / mm_token_type_ids for completion tokens + for ttid_key in ("token_type_ids", "mm_token_type_ids"): + if ttid_key in forward_kwargs: + tt = forward_kwargs[ttid_key] + forward_kwargs[ttid_key] = torch.cat( + [tt, tt.new_zeros(completion_ids.shape)], dim=1 + ) + + # Merge extra_fields from rollout_func into inputs + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + + # No explicit CUDA sync needed here — both threads share the + # default stream, so operations are naturally ordered. + + # --- Construct deferred output --- + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "num_items_in_batch": num_items_in_batch, + "advantages": torch.zeros(completion_ids.size(0), device=device), + # Sentinels for deferred scoring + "_pending_policy_logps": True, + "_deferred_inputs": inputs, + "_deferred_prompts": prompts, + "_deferred_completions": completions, + "_deferred_completion_ids_list": completion_ids_list, + "_rank0_only": rank0_only, + } + if sampling_per_token_logps is not None: + output["sampling_per_token_logps"] = sampling_per_token_logps + if tool_mask is not None: + output["tool_mask"] = tool_mask + if images is not None: + output["num_images"] = num_images + for k in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): + if k in forward_kwargs: + output[k] = forward_kwargs[k] + return output + + # ------------------------------------------------------------------ + # Hooks (overridden by subclasses like FastAsyncGRPOTrainer) + # ------------------------------------------------------------------ + + def _compute_rewards_for_batch( + self, inputs, prompts, completions, completion_ids_list + ): + """Compute rewards for a batch. Override for parallel workers, caching, etc.""" + return self._calculate_rewards( + inputs, prompts, completions, completion_ids_list + ) + + def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list): + """Launch reward computation in background. Override for parallel dispatch. + + Default: no-op (rewards computed synchronously in _collect_reward_workers). + """ + self._pending_reward_args = (inputs, prompts, completions, completion_ids_list) + + def _collect_reward_workers( + self, inputs, prompts, completions, completion_ids_list + ): + """Collect reward results. Override to collect from parallel workers. + + Default: compute rewards synchronously now. + """ + args = getattr(self, "_pending_reward_args", None) + if args is not None: + self._pending_reward_args = None + return self._compute_rewards_for_batch(*args) + return self._compute_rewards_for_batch( + inputs, prompts, completions, completion_ids_list + ) + + def _post_advantage_hook( + self, + data: dict, + rewards_per_func, + advantages, + inputs: list, + num_generations: int, + mode: str, + s_start: int | None = None, + s_end: int | None = None, + is_last_chunk: bool = True, + ) -> None: + """Called after advantages are computed. 
Override for replay buffer, re-roll, etc.""" + + # ------------------------------------------------------------------ + # Main-thread scoring + # ------------------------------------------------------------------ + + @torch.no_grad() + def _compute_deferred_scores(self, rollout: dict) -> dict: + """Compute rewards, advantages, policy logprobs, and IS ratio on the main thread. + + Takes the deferred output from ``_generate_only`` and produces a fully + scored dict ready for ``split_tensor_dict`` → micro-batches. + """ + device = self.accelerator.device + batch_size = self.args.per_device_train_batch_size + num_generations = self.num_generations + mode = "train" + + # --- Extract deferred data --- + data = rollout + inputs = data.pop("_deferred_inputs") + prompts = data.pop("_deferred_prompts") + completions = data.pop("_deferred_completions") + completion_ids_list = data.pop("_deferred_completion_ids_list") + rank0_only = data.pop("_rank0_only", False) + del data["_pending_policy_logps"] + + prompt_ids = data["prompt_ids"] + completion_ids = data["completion_ids"] + prompt_mask = data["prompt_mask"] + completion_mask = data["completion_mask"] + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + + # Multimodal forward kwargs + forward_kwargs = {} + for key in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): + if key in data: + forward_kwargs[key] = data[key] + num_images = data.get("num_images") + + # --- Launch rewards in parallel with logprobs --- + self._launch_reward_workers(inputs, prompts, completions, completion_ids_list) + + # --- Policy logprobs --- + logprob_batch_size = min(batch_size * 4, len(prompt_ids)) + with disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ): + generate_every = self.args.steps_per_generation * self.num_iterations + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + logprob_batch_size, + num_images=num_images, + **forward_kwargs, + ) + data["old_per_token_logps"] = old_per_token_logps + else: + old_per_token_logps = None + + # Reference model logprobs + if self.beta != 0.0: + if self.ref_model is not None: + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, + ) + else: + unwrapped = self.accelerator.unwrap_model(self.model) + adapter_name = ( + "ref" + if hasattr(unwrapped, "peft_config") + and "ref" in unwrapped.peft_config + else None + ) + with use_adapter(unwrapped, adapter_name=adapter_name): + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, + ) + data["ref_per_token_logps"] = ref_logps + + # --- IS ratio --- + if ( + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) + and old_per_token_logps is not None + and "sampling_per_token_logps" in data + ): + sampling_logps = data["sampling_per_token_logps"] + is_mask = ( + completion_mask + if "tool_mask" not 
in data + else completion_mask * data["tool_mask"] + ) + per_token_logps_diff = (old_per_token_logps - sampling_logps) * is_mask + + is_mode = getattr(self, "vllm_importance_sampling_mode", "token_truncate") + is_cap = getattr(self, "vllm_importance_sampling_cap", 3.0) + sequence_level_is = is_mode in ("sequence_mask", "sequence_truncate") + if sequence_level_is: + logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + else: + logps_diff = per_token_logps_diff + + is_ratio = torch.exp(logps_diff) + if is_mode in ("sequence_truncate", "token_truncate"): + is_ratio = torch.clamp(is_ratio, max=is_cap) + elif is_mode in ("sequence_mask", "token_mask"): + is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0) + data["importance_sampling_ratio"] = is_ratio + + # --- Collect rewards (launched before logprobs, should be done) --- + rewards_per_func = self._collect_reward_workers( + inputs, prompts, completions, completion_ids_list + ) + # In rank0_only mode, all ranks compute the same rewards on identical data. + # _calculate_rewards / _collect_reward_workers always `gather()` across ranks, + # which duplicates the rows (N_local * num_processes). De-duplicate so that + # rewards_per_func matches the data dict (which has N_local rows). + if rank0_only and rewards_per_func.size(0) > len(prompts): + rewards_per_func = rewards_per_func[: len(prompts)] + + # --- Advantages --- + if self.multi_objective_aggregation == "sum_then_normalize": + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + mean_grouped = ( + rewards.view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations) + ) + if self.scale_rewards in ("group", "none"): + if num_generations > 1: + std_rewards = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) + else: + std_rewards = torch.zeros_like(rewards) + elif self.scale_rewards == "batch": + std_rewards = ( + rewards.std().expand_as(rewards) + if rewards.numel() > 1 + else torch.zeros_like(rewards) + ) + else: + raise ValueError(f"Invalid scale_rewards: {self.scale_rewards}") + advantages = rewards - mean_grouped + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = torch.nanmean(grouped, dim=1, keepdim=True) + std_k = ( + nanstd(grouped, dim=1, keepdim=True) + if num_generations > 1 + else torch.zeros_like(mean_k) + ) + reward_k = (grouped - mean_k) / (std_k + 1e-4) + reward_k = reward_k.view(-1, len(self.reward_funcs)) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum( + dim=1 + ) + std_rewards = ( + rewards.std().expand_as(rewards) + if rewards.numel() > 1 + else torch.zeros_like(rewards) + ) + advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + else: + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}" + ) + + # Slice for local process + # In rank0_only mode, all ranks already have identical data from broadcast, + # so no slicing needed. Otherwise, each rank takes its portion. 
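+ # Illustrative: with num_processes=2 and len(prompts)=8, rank 0
+ # takes advantages[0:8] and rank 1 takes advantages[8:16].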
+ if rank0_only: + process_slice = slice(0, len(prompts)) + else: + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_advantages = advantages.clone() + advantages = advantages[process_slice] + data["advantages"] = advantages + + # --- Post-advantage hook (for replay buffer, re-roll, etc.) --- + self._post_advantage_hook( + data, + rewards_per_func, + advantages, + inputs, + num_generations, + mode, + ) + + # --- Metrics --- + for i, name in enumerate(self.reward_func_names): + self._metrics[mode][f"rewards/{name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) + agg_rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(agg_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append( + is_std_zero.float().mean().item() + ) + + # Token counting + total_prompt = self.accelerator.gather(prompt_mask.sum()).sum() + total_completion = self.accelerator.gather(completion_mask.sum()).sum() + self.state.num_input_tokens_seen += (total_prompt + total_completion).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Completion length metrics + comp_lengths = completion_mask.sum(dim=1) + agg_lengths = self.accelerator.gather(comp_lengths) + self._metrics[mode]["completions/mean_length"].append( + agg_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_length"].append( + agg_lengths.float().min().item() + ) + self._metrics[mode]["completions/max_length"].append( + agg_lengths.float().max().item() + ) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_trunc = torch.tensor( + [ids[-1].item() not in eos_and_pad for ids in completion_ids], device=device + ) + agg_trunc = self.accelerator.gather(is_trunc) + self._metrics[mode]["completions/clipped_ratio"].append( + agg_trunc.float().mean().item() + ) + term_lengths = agg_lengths[~agg_trunc] + if len(term_lengths) == 0: + term_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append( + term_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term_lengths.float().min().item() + ) + self._metrics[mode]["completions/max_terminated_length"].append( + term_lengths.float().max().item() + ) + + # IS metrics + if "importance_sampling_ratio" in data and "sampling_per_token_logps" in data: + old_lp = data["old_per_token_logps"] + samp_lp = data["sampling_per_token_logps"] + mask = completion_mask.bool() + delta = torch.abs(old_lp - samp_lp) + delta_m = delta[mask] + md = ( + torch.mean(delta_m) + if delta_m.numel() > 0 + else torch.tensor(0.0, device=device) + ) + xd = ( + torch.max(delta_m) + if delta_m.numel() > 0 + else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(md).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(xd).max().item() + ) + isr = data["importance_sampling_ratio"] + is_mode = getattr(self, "vllm_importance_sampling_mode", "token_truncate") + if is_mode in ("sequence_mask", "sequence_truncate"): + flat_isr = isr.flatten() + else: + flat_isr = isr[mask] + if flat_isr.numel() > 0: + 
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(torch.min(flat_isr))).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(torch.mean(flat_isr)).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(torch.max(flat_isr))).item() + ) + + # Log prompt/completion texts + prompts_text = self.processing_class.batch_decode( + prompt_ids, skip_special_tokens=True + ) + completions_text = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + if gather_object is not None: + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_advantages.tolist()) + + # Remove deferred keys + for k in list(data.keys()): + if k.startswith("_deferred") or k == "_pending_policy_logps": + data.pop(k, None) + + return data + + @torch.no_grad() + def _compute_streaming_group_scores( + self, + data, + s_start, + s_end, + inputs, + prompts, + completions, + completion_ids_list, + is_last_chunk, + rank0_only=False, + ): + """Score a chunk of prompt groups: rewards, policy logprobs, advantages. + + Called during streaming scoring to incrementally score groups. + Writes results directly into ``data`` at positions ``s_start:s_end``. + """ + device = self.accelerator.device + batch_size = self.args.per_device_train_batch_size + num_generations = self.num_generations + mode = "train" + chunk_size = s_end - s_start + + # --- Policy logprobs for this chunk --- + chunk_prompt_ids = data["prompt_ids"][s_start:s_end] + chunk_completion_ids = data["completion_ids"][s_start:s_end] + chunk_prompt_mask = data["prompt_mask"][s_start:s_end] + chunk_completion_mask = data["completion_mask"][s_start:s_end] + prompt_completion_ids = torch.cat( + [chunk_prompt_ids, chunk_completion_ids], dim=1 + ) + attention_mask = torch.cat([chunk_prompt_mask, chunk_completion_mask], dim=1) + logits_to_keep = chunk_completion_ids.size(1) + + # Slice multimodal forward kwargs for this chunk + forward_kwargs = {} + for key in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): + if key in data: + val = data[key] + if ( + isinstance(val, torch.Tensor) + and val.dim() > 0 + and val.size(0) == len(data["prompt_ids"]) + ): + forward_kwargs[key] = val[s_start:s_end] + else: + forward_kwargs[key] = val + num_images = data.get("num_images") + if ( + num_images is not None + and hasattr(num_images, "__getitem__") + and len(num_images) == len(data["prompt_ids"]) + ): + num_images = num_images[s_start:s_end] + + # --- Launch rewards in parallel with logprobs --- + self._launch_reward_workers(inputs, prompts, completions, completion_ids_list) + + # --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) --- + logprob_batch_size = min(batch_size * 2, chunk_size) + with disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ): + generate_every = self.args.steps_per_generation * self.num_iterations + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) + ): + old_logps, _ = self._get_per_token_logps_and_entropies( + 
self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + logprob_batch_size, + num_images=num_images, + **forward_kwargs, + ) + if "old_per_token_logps" not in data: + total = len(data["prompt_ids"]) + data["old_per_token_logps"] = torch.zeros( + total, old_logps.size(1), device=device, dtype=old_logps.dtype + ) + data["old_per_token_logps"][s_start:s_end] = old_logps + + # Compute IS ratio for this chunk + if "sampling_per_token_logps" in data: + samp_chunk = data["sampling_per_token_logps"][s_start:s_end] + is_mask = ( + chunk_completion_mask + if "tool_mask" not in data + else (chunk_completion_mask * data["tool_mask"][s_start:s_end]) + ) + diff = (old_logps - samp_chunk) * is_mask + is_mode = getattr( + self, "vllm_importance_sampling_mode", "token_truncate" + ) + is_cap = getattr(self, "vllm_importance_sampling_cap", 3.0) + seq_is = is_mode in ("sequence_mask", "sequence_truncate") + logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff + is_ratio = torch.exp(logps_diff) + if is_mode in ("sequence_truncate", "token_truncate"): + is_ratio = torch.clamp(is_ratio, max=is_cap) + elif is_mode in ("sequence_mask", "token_mask"): + is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0) + if "importance_sampling_ratio" not in data: + total = len(data["prompt_ids"]) + shape = (total, 1) if seq_is else (total, is_ratio.size(1)) + data["importance_sampling_ratio"] = torch.ones( + *shape, device=device, dtype=is_ratio.dtype + ) + data["importance_sampling_ratio"][s_start:s_end] = is_ratio + + # Reference logprobs + if self.beta != 0.0: + if self.ref_model is not None: + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, + ) + else: + unwrapped = self.accelerator.unwrap_model(self.model) + adapter_name = ( + "ref" + if hasattr(unwrapped, "peft_config") + and "ref" in unwrapped.peft_config + else None + ) + with use_adapter(unwrapped, adapter_name=adapter_name): + ref_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, + ) + if "ref_per_token_logps" not in data: + total = len(data["prompt_ids"]) + data["ref_per_token_logps"] = torch.zeros( + total, ref_logps.size(1), device=device, dtype=ref_logps.dtype + ) + data["ref_per_token_logps"][s_start:s_end] = ref_logps + + # --- Collect rewards (should already be done, ran in parallel with logprobs) --- + rewards_per_func = self._collect_reward_workers( + inputs, prompts, completions, completion_ids_list + ) + # De-duplicate gathered rewards when all ranks computed the same data. + # _calculate_rewards always gather()s, which duplicates rows in rank0_only mode. 
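+ # Illustrative: with chunk_size=8 on 2 ranks, gather() returns 16
+ # identical rows; keeping the first 8 realigns rewards_per_func
+ # with the chunk being scored.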
+ if rewards_per_func.size(0) > chunk_size: + rewards_per_func = rewards_per_func[:chunk_size] + + # --- Advantages (group-level normalization) --- + if self.multi_objective_aggregation == "sum_then_normalize": + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + mean_g = ( + rewards.view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations) + ) + if num_generations > 1: + std_r = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) + else: + std_r = torch.zeros_like(rewards) + advantages = rewards - mean_g + if self.scale_rewards != "none": + advantages = advantages / (std_r + 1e-4) + is_std_zero = torch.isclose(std_r, torch.zeros_like(std_r)) + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = torch.nanmean(grouped, dim=1, keepdim=True) + std_k = ( + nanstd(grouped, dim=1, keepdim=True) + if num_generations > 1 + else torch.zeros_like(mean_k) + ) + reward_k = ((grouped - mean_k) / (std_k + 1e-4)).view( + -1, len(self.reward_funcs) + ) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum( + dim=1 + ) + std_r = ( + rewards.view(-1, num_generations) + .std(dim=1) + .repeat_interleave(num_generations) + ) + mean_r = ( + rewards.view(-1, num_generations) + .mean(dim=1) + .repeat_interleave(num_generations) + ) + advantages = (rewards - mean_r) / (std_r + 1e-4) + is_std_zero = torch.isclose(std_r, torch.zeros_like(std_r)) + else: + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}" + ) + + if rank0_only: + process_slice = slice(0, len(prompts)) + else: + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + if "advantages" not in data or not isinstance(data["advantages"], torch.Tensor): + data["advantages"] = torch.zeros(len(data["prompt_ids"]), device=device) + data["advantages"][s_start:s_end] = advantages + + # --- Post-advantage hook (for replay buffer, re-roll, etc.) 
--- + self._post_advantage_hook( + data, + rewards_per_func, + advantages, + inputs, + num_generations, + mode, + s_start=s_start, + s_end=s_end, + is_last_chunk=is_last_chunk, + ) + + # --- Chunk metrics --- + for i, name in enumerate(self.reward_func_names): + self._metrics[mode][f"rewards/{name}/mean"].append( + torch.nanmean(rewards_per_func[:, i]).item() + ) + self._metrics[mode][f"rewards/{name}/std"].append( + nanstd(rewards_per_func[:, i]).item() + ) + agg_rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(agg_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(agg_rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append( + is_std_zero.float().mean().item() + ) + + # --- Full-batch metrics on last chunk --- + if is_last_chunk: + all_prompt_mask = data["prompt_mask"] + all_completion_mask = data["completion_mask"] + all_completion_ids = data["completion_ids"] + total_p = self.accelerator.gather(all_prompt_mask.sum()).sum() + total_c = self.accelerator.gather(all_completion_mask.sum()).sum() + self.state.num_input_tokens_seen += (total_p + total_c).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + comp_lengths = all_completion_mask.sum(dim=1) + agg_lengths = self.accelerator.gather(comp_lengths) + self._metrics[mode]["completions/mean_length"].append( + agg_lengths.float().mean().item() + ) + self._metrics[mode]["completions/min_length"].append( + agg_lengths.float().min().item() + ) + self._metrics[mode]["completions/max_length"].append( + agg_lengths.float().max().item() + ) + + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_trunc = torch.tensor( + [ids[-1].item() not in eos_and_pad for ids in all_completion_ids], + device=device, + ) + agg_trunc = self.accelerator.gather(is_trunc) + self._metrics[mode]["completions/clipped_ratio"].append( + agg_trunc.float().mean().item() + ) + term = agg_lengths[~agg_trunc] + if len(term) == 0: + term = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append( + term.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term.float().min().item() + ) + self._metrics[mode]["completions/max_terminated_length"].append( + term.float().max().item() + ) + + # IS metrics + if ( + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) + and "sampling_per_token_logps" in data + and "old_per_token_logps" in data + ): + old_lp = data["old_per_token_logps"] + samp_lp = data["sampling_per_token_logps"] + mask = all_completion_mask.bool() + delta = torch.abs(old_lp - samp_lp)[mask] + md = ( + torch.mean(delta) + if delta.numel() > 0 + else torch.tensor(0.0, device=device) + ) + xd = ( + torch.max(delta) + if delta.numel() > 0 + else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(md).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(xd).max().item() + ) + is_mode = getattr( + self, "vllm_importance_sampling_mode", "token_truncate" + ) + isr = data["importance_sampling_ratio"] + flat = ( + isr.flatten() + if is_mode in ("sequence_mask", "sequence_truncate") + else isr[mask] + ) + if flat.numel() > 0: + self._metrics[mode][ + "sampling/importance_sampling_ratio/min" + ].append(nanmin(self.accelerator.gather(torch.min(flat))).item()) + self._metrics[mode][ + "sampling/importance_sampling_ratio/mean" + 
].append(self.accelerator.gather(torch.mean(flat)).nanmean().item()) + self._metrics[mode][ + "sampling/importance_sampling_ratio/max" + ].append(nanmax(self.accelerator.gather(torch.max(flat))).item()) + + def _score_streaming(self, rollout: dict) -> list[dict]: + """Score a rollout using streaming group scoring. Returns list of micro-batches.""" + data = rollout + num_gen = self.num_generations + n_groups = len(data["prompt_ids"]) // num_gen + batch_size = self.args.per_device_train_batch_size + min_groups = max(1, self.args.streaming_min_groups) + + # Extract deferred data + inputs = data.pop("_deferred_inputs") + prompts = data.pop("_deferred_prompts") + completions = data.pop("_deferred_completions") + completion_ids_list = data.pop("_deferred_completion_ids_list") + rank0_only = data.pop("_rank0_only", False) + del data["_pending_policy_logps"] + + all_micro_batches = [] + shared_keys = {"num_items_in_batch"} + + for chunk_start_g in range(0, n_groups, min_groups): + chunk_end_g = min(chunk_start_g + min_groups, n_groups) + s_start = chunk_start_g * num_gen + s_end = chunk_end_g * num_gen + + self._compute_streaming_group_scores( + data=data, + s_start=s_start, + s_end=s_end, + inputs=inputs[s_start:s_end], + prompts=prompts[s_start:s_end], + completions=completions[s_start:s_end], + completion_ids_list=completion_ids_list[s_start:s_end], + is_last_chunk=(chunk_end_g == n_groups), + rank0_only=rank0_only, + ) + + # Yield micro-batches from this scored chunk + chunk_size = s_end - s_start + perm = torch.randperm(chunk_size) + for mb_off in range(0, chunk_size, batch_size): + mb_idx = perm[mb_off : mb_off + batch_size] + abs_idx = mb_idx + s_start + mb = {} + for key in data: + if key.startswith("_"): + continue + val = data[key] + if key in shared_keys: + mb[key] = val + elif isinstance(val, torch.Tensor) and val.dim() > 0: + mb[key] = val[abs_idx] + else: + mb[key] = val + all_micro_batches.append(mb) + + # Repeat for num_iterations + return all_micro_batches * self.num_iterations + + # ------------------------------------------------------------------ + # _prepare_inputs override + # ------------------------------------------------------------------ + + def _prepare_inputs(self, generation_batch): + """Override to support data producer and async prefetch paths.""" + mode = "train" if self.model.training else "eval" + + # --- Data producer path --- + if mode == "train" and self.data_producer is not None: + return self._prepare_inputs_data_producer(generation_batch) + + # --- Legacy async prefetch path (no data producer) --- + if mode == "train" and self.args.async_prefetch: + return self._prepare_inputs_legacy_async(generation_batch) + + # --- Stock path --- + return super()._prepare_inputs(generation_batch) + + def _prepare_inputs_data_producer(self, generation_batch): + """Data producer path: produce rollout, score deferred logps, split into micro-batches.""" + # Return from buffer if available + if self._buffered_inputs: + return self._buffered_inputs.pop(0) + + # Produce a new rollout + self._maybe_sync_vllm_weights() + + rollout_dataset = self.data_producer.produce( + self.model, + self.state.global_step, + processing_class=self.processing_class, + accelerator=self.accelerator, + args=self.args, + ) + + # Convert RolloutDataset back to a dict for scoring/splitting + rollout = rollout_dataset._data + + # If async (skip_policy_logps=True), score deferred logps on main thread + if rollout.get("_pending_policy_logps"): + if self.args.streaming_partial_batch: + micro_batches = 
self._score_streaming(rollout) + else: + scored = self._compute_deferred_scores(rollout) + scored = split_pixel_values_by_grid(scored) + scored = shuffle_sequence_dict(scored) + batches = split_tensor_dict(scored, self.args.steps_per_generation) + micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches] + micro_batches = micro_batches * self.num_iterations + else: + # Sync path: data is already fully scored + rollout = split_pixel_values_by_grid(rollout) + batches = split_tensor_dict(rollout, self.args.steps_per_generation) + micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches] + micro_batches = micro_batches * self.num_iterations + + self._buffered_inputs = micro_batches[1:] + return micro_batches[0] + + def _prepare_inputs_legacy_async(self, generation_batch): + """Legacy async path: direct queue-based prefetch without data producer.""" + # Return from buffer if available + if self._buffered_inputs: + return self._buffered_inputs.pop(0) + + # Need a new rollout + self._maybe_sync_vllm_weights() + future = self._async_queue.get() + rollout = future.result() + self._submit_generation() + + if self.args.streaming_partial_batch: + micro_batches = self._score_streaming(rollout) + else: + scored = self._compute_deferred_scores(rollout) + scored = split_pixel_values_by_grid(scored) + scored = shuffle_sequence_dict(scored) + batches = split_tensor_dict(scored, self.args.steps_per_generation) + micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches] + micro_batches = micro_batches * self.num_iterations + + self._buffered_inputs = micro_batches[1:] + + # Release cached CUDA memory from scoring + # before training allocations begin, reducing peak reserved memory. + torch.cuda.empty_cache() + + return micro_batches[0] + + @profiling_decorator + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + mm_token_type_ids=None, + ) -> tuple[Any, torch.Tensor | None]: + """Compute log-probs and (optionally) entropies for each token. + + When running under no_grad (scoring path), bypasses accelerate's + ConvertOutputsToFp32 wrapper to avoid a fp32 copy of the + logits tensor. + """ + # Bypass accelerate's ConvertOutputsToFp32 wrapper which converts the + # entire (B, L, V) logits tensor from bf16 to fp32 — unnecessary and + # extremely wasteful for large vocabularies. + # Skip unwrapping for FSDP — parameters are only valid inside FSDP's + # forward context; unwrapping exposes flattened/sharded tensors. + if not self.is_fsdp_enabled: + model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False) + autocast_ctx = torch.autocast( + device_type=input_ids.device.type, dtype=torch.bfloat16 + ) + + # Use Liger's Triton kernel in scoring path (no grad): fuses + # temperature + log_softmax + gather into a single kernel pass. 
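+ # For reference, the unfused equivalent is roughly
+ # (logits / T).log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1),
+ # without materializing the full (B, L, V) log-softmax tensor.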
+ use_fused = ( + self.use_liger_kernel + and _fused_selective_log_softmax is not None + and not torch.is_grad_enabled() + ) + + batch_size = batch_size or input_ids.size(0) + all_logps = [] + all_entropies = [] + with autocast_ctx: + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs + model_inputs = { + "input_ids": input_ids_batch, + "attention_mask": attention_mask_batch, + } + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat( + [ + torch.tensor([0], device=rows_per_sample.device), + rows_per_sample.cumsum(0), + ] + ) + row_start, row_end = ( + cum_rows[start].item(), + cum_rows[start + batch_size].item(), + ) + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[ + start : start + batch_size + ] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[ + start : start + batch_size + ] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[ + start : start + batch_size + ] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[ + start : start + batch_size + ] + if mm_token_type_ids is not None: + model_inputs["mm_token_type_ids"] = mm_token_type_ids[ + start : start + batch_size + ] + + if "logits_to_keep" in self.model_kwarg_keys: + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False + + logits = model(**model_inputs).logits + completion_ids = input_ids_batch[:, -logits_to_keep:] + # FP8 models produce NaN logits at positions where + # attention_mask=0 (padding). Replace NaN with 0 so + # log_softmax yields uniform distribution for those positions. + # The completion_mask ensures these don't affect the loss. 
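+ # An all-zero logits row log-softmaxes to the uniform distribution
+ # (every entry -log V), so the gathered logps at padding positions
+ # stay finite; completion_mask keeps those tokens out of the loss.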
+ logits = torch.nan_to_num(logits, nan=0.0) + + if use_fused: + logits = logits[:, -(logits_to_keep + 1) :, :] + if not logits.is_contiguous(): + logits = logits.contiguous() + logps = _fused_selective_log_softmax( + logits, completion_ids, self.temperature + ) + all_logps.append(logps) + else: + logits = logits[:, :-1, :] + logits = logits[:, -logits_to_keep:, :] + logits.div_(self.temperature) + logps = selective_log_softmax(logits, completion_ids) + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + logps = torch.cat(all_logps, dim=0) + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return logps, entropies + + # ------------------------------------------------------------------ + # Loss override (adds IS ratio + OPSM) + # ------------------------------------------------------------------ + + @staticmethod + def get_off_policy_mask( + advantages, + per_token_logps, + sampling_per_token_logps, + mask, + off_policy_threshold, + ): + """OPSM from DeepSeek-V3.2: drop sequences with negative advantage + high KL.""" + kl_div = sampling_per_token_logps - per_token_logps.detach() + seq_kl = (kl_div * mask).sum(dim=1, keepdim=True) / mask.sum( + dim=1, keepdim=True + ).clamp(min=1.0) + is_pos_adv = advantages >= 0 + is_low_kl = seq_kl <= off_policy_threshold + return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) + + def _compute_loss(self, model, inputs): + """Override to add IS ratio correction and off-policy sequence masking.""" + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = ( + inputs["completion_ids"], + inputs["completion_mask"], + ) + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + mask = ( + completion_mask + if "tool_mask" not in inputs + else completion_mask * inputs["tool_mask"] + ) + + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + mm_token_type_ids=inputs.get("mm_token_type_ids"), + ) + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask( + entropies, mask, 1 - self.top_entropy_quantile + ) + else: + entropy_mask = None + + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = ( + per_token_logps.detach() + if old_per_token_logps is None + else old_per_token_logps + ) + + # --- OPSM (off-policy sequence mask) --- + off_policy_mask = None + if getattr(self, "off_policy_mask_threshold", None) is not None: + sampling_per_token_logps = inputs.get( + "sampling_per_token_logps", old_per_token_logps + ) + off_policy_mask = self.get_off_policy_mask( + advantages=advantages, + per_token_logps=per_token_logps, + sampling_per_token_logps=sampling_per_token_logps, + mask=mask, + off_policy_threshold=self.off_policy_mask_threshold, + ) + + # --- Importance weights --- + log_ratio = per_token_logps - old_per_token_logps + is_level = getattr( + self, + 
"importance_sampling_level", + getattr(self.args, "importance_sampling_level", "token"), + ) + if is_level == "token": + log_importance_weights = log_ratio + elif is_level == "sequence": + log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp( + min=1.0 + ) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError(f"Unknown importance sampling level: {is_level}") + + coef_1 = torch.exp(log_importance_weights) + + # --- KL divergence --- + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) + - (ref_per_token_logps - per_token_logps) + - 1 + ) + if getattr(self.args, "use_bias_correction_kl", False): + per_token_kl = per_token_kl * coef_1 + + # --- Per-token loss --- + if self.loss_type == "cispo": + clamped = torch.clamp(coef_1, max=self.epsilon_high).detach() + per_token_loss = -clamped * advantages * per_token_logps + elif self.loss_type in ("grpo", "bnpo", "dr_grpo", "dapo", "luspo"): + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1_c = torch.clamp(coef_1, max=self.args.delta) + else: + coef_1_c = coef_1 + per_token_loss = -torch.min(coef_1_c * advantages, coef_2 * advantages) + elif self.loss_type == "sapo": + temps = torch.where( + advantages > 0, + self.args.sapo_temperature_pos, + self.args.sapo_temperature_neg, + ) + soft = torch.sigmoid(temps * (coef_1 - 1)) * 4 / temps + per_token_loss = -soft * advantages + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # --- Apply masks --- + if off_policy_mask is not None: + per_token_loss = per_token_loss * off_policy_mask + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + # --- IS ratio correction (vLLM distribution mismatch) --- + if ( + self.use_vllm + and getattr(self, "vllm_importance_sampling_correction", False) + and "importance_sampling_ratio" in inputs + ): + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + # --- Aggregate loss --- + mode = "train" if self.model.training else "eval" + normalizer = ( + self.current_gradient_accumulation_steps if mode == "train" else 1.0 + ) + + if self.loss_type in ("grpo", "sapo"): + loss = ( + (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) + ).mean() / normalizer + elif self.loss_type == "bnpo": + loss = ( + (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) / normalizer + ) + elif self.loss_type == "dr_grpo": + loss = ( + (per_token_loss * mask).sum() + / (per_token_loss.size(0) * self.max_completion_length) + / normalizer + ) + elif self.loss_type in ("cispo", "dapo"): + norm = inputs["num_items_in_batch"] / self.accelerator.num_processes + loss = (per_token_loss * mask).sum() / norm + elif self.loss_type == "luspo": + loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # --- Metrics --- + completion_token_count = mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + return ( + x.mean() + if x.shape[1] == 1 + else (x * mask).sum() / completion_token_count + ) + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl"].append( + self.accelerator.gather(mean_kl).nanmean().item() + ) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append( + 
self.accelerator.gather(mean_entropy).nanmean().item() + ) + + if self.loss_type in ("grpo", "bnpo", "dr_grpo", "dapo", "luspo"): + is_low = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) + is_high = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) + is_region = is_low | is_high + low_clip = masked_batch_mean(is_low.float()) + high_clip = masked_batch_mean(is_high.float()) + clip_ratio = masked_batch_mean(is_region.float()) + g_low = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(g_low.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(g_low).item()) + g_high = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(g_high.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(g_high).item()) + g_clip = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append( + g_clip.nanmean().item() + ) + elif self.loss_type == "cispo": + is_cispo = (coef_1 > self.epsilon_high) & (advantages > 0) + cr = masked_batch_mean(is_cispo.float()) + self._metrics[mode]["cispo_clip_ratio"].append( + self.accelerator.gather(cr).nanmean().item() + ) + + return loss diff --git a/src/axolotl/core/trainers/grpo/fast_async_trainer.py b/src/axolotl/core/trainers/grpo/fast_async_trainer.py new file mode 100644 index 000000000..9d1128b97 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/fast_async_trainer.py @@ -0,0 +1,768 @@ +# Copyright 2020-2026 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Experimental GRPO extensions: parallel reward workers, replay buffer, +deferred re-roll, and zero-advantage skipping. + +These features are built as subclasses of GRPOTrainer and GRPODataProducer, +using the hook system (_compute_rewards_for_batch, _post_advantage_hook, +_pre_produce_hook) defined in the base classes. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from dataclasses import dataclass, field + +import torch +from torch import nn +from trl import GRPOTrainer + +from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOConfig, + AsyncGRPOTrainer, + GRPODataProducer, +) +from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Extended config +# --------------------------------------------------------------------------- + + +@dataclass +class FastAsyncGRPOConfig(AsyncGRPOConfig): + """GRPOConfig with additional experimental parameters.""" + + reward_num_workers: int = field( + default=1, + metadata={ + "help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its " + "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across " + "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions." 
+ }, + ) + replay_buffer_size: int = field( + default=0, + metadata={ + "help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout " + "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups " + "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True." + }, + ) + replay_recompute_logps: bool = field( + default=True, + metadata={ + "help": "When True (default), recompute old_per_token_logps for replayed groups using the current " + "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. " + "Only relevant when replay_buffer_size > 0." + }, + ) + reroll_start_fraction: float = field( + default=0.5, + metadata={ + "help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts " + "(where all rewards in a group are identical) are buffered and re-injected into later batches when the " + "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True." + }, + ) + reroll_max_groups: int = field( + default=1, + metadata={ + "help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values " + "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True." + }, + ) + skip_zero_advantage_batches: bool = field( + default=True, + metadata={ + "help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning " + "signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is " + "logged with skipped_zero_adv_batches=1 for monitoring." + }, + ) + vllm_lora_sync: bool = field( + default=False, + metadata={ + "help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model " + "and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. " + "Requires vllm_serve_lora serve module (auto-selected when this is True). " + "Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False." + }, + ) + + +# --------------------------------------------------------------------------- +# Extended data producer with re-roll injection +# --------------------------------------------------------------------------- + + +class RerollDataProducer(GRPODataProducer): + """GRPODataProducer that injects re-roll candidates into prompt batches. + + Reads from the trainer's ``_reroll_buffer`` (populated by + ``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the + last N prompt groups with previously-failed prompts. 
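+
+ Illustrative example: with ``num_generations=4`` and one buffered
+ candidate, the final group of 4 slots in the prompt batch is
+ overwritten with 4 copies of the deferred prompt before generation.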
+ """ + + def _pre_produce_hook(self, inputs: list, global_step: int) -> list: + trainer = self._trainer + reroll_buf = getattr(trainer, "_reroll_buffer", None) + reroll_lock = getattr(trainer, "_reroll_lock", None) + if reroll_buf is None or reroll_lock is None: + return inputs + + max_steps = getattr(trainer.args, "max_steps", -1) + start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0) + max_groups = getattr(trainer.args, "reroll_max_groups", 1) + reroll_start_step = ( + max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf") + ) + + if global_step < reroll_start_step: + return inputs + + with reroll_lock: + n_to_take = min(max_groups, len(reroll_buf)) + reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)] + + if reroll_prompts: + num_gen = self._num_generations + n_groups = len(inputs) // num_gen + for i, reroll_prompt in enumerate(reroll_prompts): + group_idx = n_groups - 1 - i + if group_idx < 0: + break + start = group_idx * num_gen + for j in range(num_gen): + inputs[start + j] = reroll_prompt + logger.info( + f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups " + f"with deferred re-roll candidates ({len(reroll_buf)} remaining)" + ) + + return inputs + + +# --------------------------------------------------------------------------- +# Persistent reward subprocess pool +# --------------------------------------------------------------------------- + + +def _persistent_reward_worker(conn): + """Long-lived reward worker. Receives work items, returns results.""" + while True: + try: + msg = conn.recv() + except EOFError: + break + if msg is None: # Shutdown signal + break + ( + reward_funcs, + prompts, + completions, + completion_ids_list, + inputs, + reward_func_names, + ) = msg + try: + keys = [ + key + for key in inputs[0] + if key not in ["prompt", "completion", "completion_ids"] + ] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + results = [] + for reward_func, _reward_func_name in zip( + reward_funcs, reward_func_names, strict=True + ): + output = reward_func( + prompts=prompts, + completions=completions, + completion_ids=completion_ids_list, + **reward_kwargs, + ) + results.append( + [float(r) if r is not None else float("nan") for r in output] + ) + conn.send(results) + except Exception: + conn.send(None) + + +# --------------------------------------------------------------------------- +# Extended trainer +# --------------------------------------------------------------------------- + + +class FastAsyncGRPOTrainer(AsyncGRPOTrainer): + """GRPOTrainer with experimental extensions. + + Adds: + - Parallel reward subprocess workers (``reward_num_workers``) + - Replay buffer for high-signal group reuse (``replay_buffer_size``) + - Deferred re-roll of failed prompts (``reroll_start_fraction``) + - Zero-advantage micro-batch skipping + """ + + def __init__(self, *args, **kwargs): + # These must be initialized before super().__init__() because + # _create_data_producer (called during super().__init__) needs them. + self._reroll_buffer: list = [] + self._reroll_lock = threading.Lock() + + # Temporarily suppress the base class's Liger + OPSM validation check, + # since this subclass supports it via a custom compute_liger_loss override. 
+ grpo_args = kwargs.get("args") + if grpo_args is None: + for a in args: + if hasattr(a, "off_policy_mask_threshold"): + grpo_args = a + break + saved_threshold = None + if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False): + saved_threshold = grpo_args.off_policy_mask_threshold + grpo_args.off_policy_mask_threshold = None + + super().__init__(*args, **kwargs) + + if saved_threshold is not None: + grpo_args.off_policy_mask_threshold = saved_threshold + self.off_policy_mask_threshold = saved_threshold + + # Replay buffer + if getattr(self.args, "replay_buffer_size", 0) > 0: + self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size) + else: + self._replay_buffer = None + self._replay_recompute_logps = getattr( + self.args, "replay_recompute_logps", True + ) + + # Reward worker pool (lazy-initialized) + self._reward_workers = None + + # -- Factory override: use RerollDataProducer ---------------------------- + + def _create_data_producer(self, args, train_dataset): + """Override to use RerollDataProducer for re-roll prompt injection.""" + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncDataProducer, + ProducerConfig, + ) + + producer_config = ProducerConfig( + mini_epochs=args.num_iterations, + max_rollouts=None, + eval_during_produce=False, + empty_cache_before_produce=True, + empty_cache_after_produce=True, + async_prefetch=args.async_prefetch, + prefetch_depth=args.prefetch_depth, + ) + data_producer = RerollDataProducer( + config=producer_config, + prompt_dataset=train_dataset, + num_generations=self.num_generations, + generation_batch_size=args.generation_batch_size, + train_batch_size=args.per_device_train_batch_size, + steps_per_generation=args.steps_per_generation, + shuffle_dataset=self.shuffle_dataset, + seed=args.seed, + ) + data_producer.set_trainer(self) + if args.async_prefetch: + data_producer = AsyncDataProducer( + data_producer, + background_produce_kwargs={"skip_policy_logps": True}, + ) + return data_producer + + # -- Reward worker pool -------------------------------------------------- + + def _get_reward_workers(self): + """Return a list of persistent reward worker subprocesses (lazy-initialized).""" + import multiprocessing as _mp + + num_workers = getattr(self.args, "reward_num_workers", 1) + if num_workers < 1: + num_workers = 1 + + if self._reward_workers is not None: + alive = all(proc.is_alive() for conn, proc in self._reward_workers) + if alive and len(self._reward_workers) == num_workers: + return self._reward_workers + self._shutdown_reward_workers() + + workers = [] + for _ in range(num_workers): + parent_conn, child_conn = _mp.Pipe() + proc = _mp.Process( + target=_persistent_reward_worker, args=(child_conn,), daemon=True + ) + proc.start() + child_conn.close() + workers.append((parent_conn, proc)) + + self._reward_workers = workers + return workers + + def _shutdown_reward_workers(self): + """Shut down all persistent reward workers.""" + if self._reward_workers is None: + return + for conn, proc in self._reward_workers: + try: + conn.send(None) + proc.join(timeout=5) + except Exception: + pass + try: + conn.close() + except Exception: + pass + self._reward_workers = None + + # -- Hook overrides ------------------------------------------------------ + + def _compute_rewards_for_batch( + self, inputs, prompts, completions, completion_ids_list + ): + """Dispatch rewards to parallel subprocess workers (synchronous wrapper).""" + self._launch_reward_workers(inputs, prompts, completions, completion_ids_list) + 
return self._collect_reward_workers(
+ inputs, prompts, completions, completion_ids_list
+ )
+
+ def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
+ """Send reward work to subprocess workers (non-blocking).
+
+ Results are collected later by _collect_reward_workers, allowing GPU
+ logprob computation to overlap with CPU reward computation.
+ """
+ reward_can_bg = all(
+ callable(rf)
+ and not isinstance(rf, nn.Module)
+ and not asyncio.iscoroutinefunction(rf)
+ for rf in self.reward_funcs
+ )
+ num_workers = getattr(self.args, "reward_num_workers", 1)
+
+ if not reward_can_bg or num_workers <= 1:
+ # Can't parallelize — store args for sync fallback in collect
+ self._reward_workers_used = None
+ self._pending_reward_args = (
+ inputs,
+ prompts,
+ completions,
+ completion_ids_list,
+ )
+ return
+
+ workers = self._get_reward_workers()
+ num_generations = self.num_generations
+ num_prompts = len(prompts)
+ num_groups = num_prompts // num_generations
+
+ # Shard by prompt groups across workers
+ groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
+ workers_used = []
+ for w_idx, (conn, _proc) in enumerate(workers):
+ g_start = w_idx * groups_per_worker
+ g_end = min((w_idx + 1) * groups_per_worker, num_groups)
+ if g_start >= num_groups:
+ break
+ s_start = g_start * num_generations
+ s_end = g_end * num_generations
+ conn.send(
+ (
+ self.reward_funcs,
+ prompts[s_start:s_end],
+ completions[s_start:s_end],
+ completion_ids_list[s_start:s_end],
+ inputs[s_start:s_end],
+ self.reward_func_names,
+ )
+ )
+ workers_used.append(conn)
+
+ self._reward_workers_used = workers_used
+ self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
+
+ def _collect_reward_workers(
+ self, inputs, prompts, completions, completion_ids_list
+ ):
+ """Collect reward results from subprocess workers (blocks until done)."""
+ from accelerate.utils import gather
+
+ workers_used = getattr(self, "_reward_workers_used", None)
+ args = getattr(self, "_pending_reward_args", None)
+ self._reward_workers_used = None
+ self._pending_reward_args = None
+
+ if workers_used is None:
+ # Sync fallback — compute on main thread
+ if args is not None:
+ return self._calculate_rewards(*args)
+ return self._calculate_rewards(
+ inputs, prompts, completions, completion_ids_list
+ )
+
+ device = self.accelerator.device
+ num_prompts = len(args[1]) if args else len(prompts)
+
+ # Collect results from workers
+ all_worker_results = []
+ any_failed = False
+ for w_idx, conn in enumerate(workers_used):
+ result = conn.recv()
+ if result is None:
+ any_failed = True
+ # Drain only the workers that haven't been read yet; earlier
+ # entries in workers_used were already consumed above, and a
+ # second recv() on them would block forever.
+ for remaining_conn in workers_used[w_idx + 1 :]:
+ try:
+ remaining_conn.recv()
+ except Exception:
+ pass
+ break
+ all_worker_results.append(result)
+
+ if not any_failed:
+ rewards_per_func = torch.zeros(
+ num_prompts, len(self.reward_funcs), device=device
+ )
+ offset = 0
+ for worker_result in all_worker_results:
+ chunk_size = len(worker_result[0])
+ for i, result in enumerate(worker_result):
+ rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
+ result, dtype=torch.float32, device=device
+ )
+ offset += chunk_size
+ return gather(rewards_per_func)
+
+ # Fallback to main thread on failure
+ if args is not None:
+ return self._calculate_rewards(*args)
+ return self._calculate_rewards(
+ inputs, prompts, completions, completion_ids_list
+ )
+
+ def _post_advantage_hook(
+ self,
+ data: dict,
+
rewards_per_func, + advantages, + inputs: list, + num_generations: int, + mode: str, + s_start: int | None = None, + s_end: int | None = None, + is_last_chunk: bool = True, + ) -> None: + """Replay buffer store/replace + re-roll buffering.""" + from trl.models.utils import disable_gradient_checkpointing + + # -- Replay buffer: store high-signal groups -- + if self._replay_buffer is not None: + local_grouped = rewards_per_func.view( + -1, num_generations, len(self.reward_funcs) + ) + per_group_std = local_grouped.std(dim=1) + has_signal = (per_group_std > 0).any(dim=1) + offset = s_start or 0 + + if has_signal.any(): + grouped_adv = advantages.view(-1, num_generations) + replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1) + for group_idx in has_signal.nonzero(as_tuple=True)[0]: + gi = group_idx.item() + start = offset + gi * num_generations + end = start + num_generations + group_data = {} + for key in data: + val = data[key] + if ( + isinstance(val, torch.Tensor) + and val.dim() > 0 + and val.size(0) >= end + ): + group_data[key] = val[start:end].clone() + self._replay_buffer.add(replay_scores[gi].item(), group_data) + + # Replace zero-signal groups with high-signal replay buffer entries + # Only in non-streaming path (s_start is None) — streaming scores + # groups incrementally, so replacement + logprob recompute would be + # too expensive per chunk. + n_replaced = 0 + if s_start is None: + no_signal = ~has_signal + replaced_ranges = [] + if no_signal.any() and len(self._replay_buffer) > 0: + for group_idx in no_signal.nonzero(as_tuple=True)[0]: + sampled = self._replay_buffer.sample(1) + if sampled is None: + break + sampled_group = sampled[0] + gi = group_idx.item() + start = offset + gi * num_generations + end = start + num_generations + for key, val in sampled_group.items(): + if key in data and isinstance(data[key], torch.Tensor): + src = val.to(data[key].device) + tgt_seq_len = ( + data[key].size(1) if data[key].dim() > 1 else None + ) + if start >= data[key].size(0) or end > data[key].size( + 0 + ): + continue + if tgt_seq_len is not None: + if src.size(1) <= tgt_seq_len: + data[key][start:end] = 0 + data[key][start:end, : src.size(1)] = src + else: + data[key][start:end] = src[:, :tgt_seq_len] + else: + data[key][start:end] = src + replaced_ranges.append((start, end)) + n_replaced += 1 + + # Recompute old_per_token_logps for replayed groups + if ( + n_replaced > 0 + and self._replay_recompute_logps + and "old_per_token_logps" in data + ): + with ( + torch.no_grad(), + disable_gradient_checkpointing( + self.model, self.args.gradient_checkpointing_kwargs + ), + ): + for r_start, r_end in replaced_ranges: + r_ids = torch.cat( + [ + data["prompt_ids"][r_start:r_end], + data["completion_ids"][r_start:r_end], + ], + dim=1, + ) + r_mask = torch.cat( + [ + data["prompt_mask"][r_start:r_end], + data["completion_mask"][r_start:r_end], + ], + dim=1, + ) + r_logits_to_keep = data["completion_ids"].size(1) + r_fwd_kwargs = {} + for fk in ( + "pixel_values", + "image_grid_thw", + "pixel_attention_mask", + "image_sizes", + "token_type_ids", + "mm_token_type_ids", + ): + if fk in data: + r_fwd_kwargs[fk] = data[fk] + r_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + r_ids, + r_mask, + r_logits_to_keep, + r_end - r_start, + **r_fwd_kwargs, + ) + data["old_per_token_logps"][r_start:r_end] = r_logps + + if n_replaced > 0: + self._metrics[mode]["replay_buffer_replacements"].append( + float(n_replaced) + ) + + if is_last_chunk: + 
self._metrics[mode]["replay_buffer_size"].append( + float(len(self._replay_buffer)) + ) + + # -- Re-roll buffer: store failed prompts -- + if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0: + grouped_rewards = rewards_per_func.view( + -1, num_generations, len(self.reward_funcs) + ) + per_group_std = grouped_rewards.std(dim=1) + per_group_mean = grouped_rewards.mean(dim=1) + zero_signal = (per_group_std == 0).all(dim=1) + all_failed = (per_group_mean.abs() < 1e-6).all(dim=1) + should_reroll = zero_signal & all_failed + _n_buffered = 0 + with self._reroll_lock: + for group_idx in should_reroll.nonzero(as_tuple=True)[0]: + idx = group_idx.item() * num_generations + if idx >= len(inputs): + continue + prompt_input = inputs[idx] + self._reroll_buffer.append(prompt_input) + _n_buffered += 1 + if _n_buffered > 0: + self._metrics[mode]["reroll_buffered"].append(float(_n_buffered)) + if is_last_chunk: + self._metrics[mode]["reroll_buffer_size"].append( + float(len(self._reroll_buffer)) + ) + + # -- Zero-advantage skipping + Liger OPSM --------------------------------- + + def compute_liger_loss(self, unwrapped_model, inputs): + """Liger loss with zero-adv skipping and off-policy sequence masking (OPSM). + + The base class Liger path doesn't support OPSM because the fused kernel + doesn't expose per-token logprobs needed for the KL computation. This + override computes them via chunked lm_head matmul (no grad, low memory) + and applies the OPSM to the loss mask before calling the kernel. + """ + if self.args.skip_zero_advantage_batches and torch.all( + inputs["advantages"] == 0 + ): + mode = "train" if self.model.training else "eval" + self._metrics[mode]["skipped_zero_adv_batches"].append(1.0) + return torch.tensor( + 0.0, device=inputs["advantages"].device, requires_grad=True + ) + + if self.off_policy_mask_threshold is None: + return super().compute_liger_loss(unwrapped_model, inputs) + + # OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = ( + inputs["completion_ids"], + inputs["completion_mask"], + ) + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) + + last_hidden_state = self._get_last_hidden_state( + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + inputs.get("pixel_values"), + inputs.get("image_grid_thw"), + inputs.get("pixel_attention_mask"), + inputs.get("image_sizes"), + ) + + loss_mask = ( + completion_mask + if "tool_mask" not in inputs + else completion_mask * inputs["tool_mask"] + ) + + # Compute per_token_logps via chunked lm_head matmul (no grad, low memory) + lm_weight = unwrapped_model.lm_head.weight + lm_bias = unwrapped_model.lm_head.bias + with torch.no_grad(): + per_token_logps_chunks = [] + for i in range(last_hidden_state.size(0)): + chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t()) + if lm_bias is not None: + chunk_logits = chunk_logits + lm_bias + chunk_lps = ( + chunk_logits.float() + .log_softmax(-1) + .gather(-1, completion_ids[i : i + 1].unsqueeze(-1)) + .squeeze(-1) + ) + per_token_logps_chunks.append(chunk_lps) + del chunk_logits + per_token_logps = torch.cat(per_token_logps_chunks, dim=0) + + advantages = inputs["advantages"] + if advantages.dim() == 1: + advantages_2d = advantages.unsqueeze(1) + else: + advantages_2d = advantages + + sampling_per_token_logps = 
inputs.get("sampling_per_token_logps")
+ if sampling_per_token_logps is None:
+ sampling_per_token_logps = inputs.get("old_per_token_logps")
+ if sampling_per_token_logps is None:
+ sampling_per_token_logps = per_token_logps
+
+ off_policy_mask = GRPOTrainer.get_off_policy_mask(
+ advantages=advantages_2d,
+ per_token_logps=per_token_logps,
+ sampling_per_token_logps=sampling_per_token_logps,
+ mask=loss_mask,
+ off_policy_threshold=self.off_policy_mask_threshold,
+ )
+ loss_mask = loss_mask * off_policy_mask
+
+ # Call the Liger fused kernel with OPSM-modified mask
+ loss, metrics = self.liger_grpo_loss(
+ _input=last_hidden_state,
+ lin_weight=unwrapped_model.lm_head.weight,
+ selected_token_ids=completion_ids,
+ attention_mask=loss_mask,
+ advantages=inputs["advantages"],
+ bias=unwrapped_model.lm_head.bias,
+ old_per_token_logps=inputs.get("old_per_token_logps"),
+ ref_per_token_logps=inputs.get("ref_per_token_logps"),
+ vllm_is_ratio=inputs.get("importance_sampling_ratio"),
+ )
+
+ mean_kl = metrics[0] if self.beta != 0.0 else None
+ clip_ratio = metrics[-1]
+
+ mode = "train" if self.model.training else "eval"
+ if self.beta != 0.0:
+ self._metrics[mode]["kl"].append(
+ self.accelerator.gather(mean_kl).mean().item()
+ )
+ self._metrics[mode]["clip_ratio"].append(
+ self.accelerator.gather(clip_ratio).mean().item()
+ )
+ normalizer = (
+ self.current_gradient_accumulation_steps if mode == "train" else 1.0
+ )
+ return loss / normalizer
+
+ def _compute_loss(self, model, inputs):
+ if self.args.skip_zero_advantage_batches and torch.all(
+ inputs["advantages"] == 0
+ ):
+ mode = "train" if self.model.training else "eval"
+ self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
+ # Create a zero loss with a grad_fn. DeepSpeed requires a non-None grad_fn.
+ # With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
+ # so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
+ # with a single token to create a proper computation graph.
+ prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
+ attn = torch.ones_like(prompt_ids)
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ out = model(input_ids=prompt_ids, attention_mask=attn)
+ return out.logits.sum() * 0
+ return super()._compute_loss(model, inputs)
diff --git a/src/axolotl/core/trainers/grpo/replay_buffer.py b/src/axolotl/core/trainers/grpo/replay_buffer.py
new file mode 100644
index 000000000..1e35cb56a
--- /dev/null
+++ b/src/axolotl/core/trainers/grpo/replay_buffer.py
@@ -0,0 +1,44 @@
+"""Simple replay buffer for storing and sampling high-signal rollout groups."""
+
+import heapq
+
+import torch
+
+
+class ReplayBuffer:
+ """Min-heap replay buffer that keeps the highest-scoring rollout groups.
+ Groups are scored by signal quality (advantage magnitude * reward variance).
+ When sampling, groups are drawn with probability proportional to their scores.
+ """
+
+ def __init__(self, max_size: int):
+ self.max_size = max_size
+ self._heap: list[tuple[float, int, dict]] = [] # min-heap of (score, id, data)
+ self._counter = 0 # unique tiebreaker for heap
+
+ def __len__(self):
+ return len(self._heap)
+
+ def add(self, score: float, data: dict):
+ """Add a group to the buffer. 
If full, replaces lowest-scoring entry.""" + if self.max_size <= 0: + return + self._counter += 1 + if len(self._heap) < self.max_size: + heapq.heappush(self._heap, (score, self._counter, data)) + elif score > self._heap[0][0]: + heapq.heapreplace(self._heap, (score, self._counter, data)) + + def sample(self, num_samples: int) -> list[dict] | None: + """Sample groups weighted by their scores. Returns None if buffer is empty.""" + if self.max_size <= 0 or not self._heap: + return None + + scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32) + scores = scores.clamp(min=1e-8) # avoid zero probabilities + probs = scores / scores.sum() + replacement = num_samples > len(self._heap) + indices = torch.multinomial( + probs, num_samples, replacement=replacement + ).tolist() + return [self._heap[i][2] for i in indices] diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index f9f5a695b..3a95ad439 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig from trl.trainer.grpo_trainer import RewardFunc, nanstd from trl.trainer.utils import pad +from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import ( DistributedParallelMixin, @@ -66,6 +67,19 @@ class AxolotlGRPOTrainer( _tag_names = ["trl", "grpo", "axolotl"] +class AxolotlAsyncGRPOTrainer( + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + FastAsyncGRPOTrainer, +): + """Extend AsyncGRPOTrainer with axolotl helpers""" + + _tag_names = ["trl", "grpo", "async", "axolotl"] + + class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): """Extend the base GRPOTrainer for sequence parallelism handling""" diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index c3356fb90..28ef75acc 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -25,7 +25,7 @@ def get_lora_parameters( ) -> tuple[ torch.Tensor, torch.Tensor | None, - QuantState | None, + QuantState | torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, float | None, @@ -48,9 +48,13 @@ def get_lora_parameters( if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: quant_state = getattr(W, "quant_state", None) + if quant_state is None and W.dtype == torch.float8_e4m3fn: + quant_state = getattr(base_layer, "weight_scale_inv", None) return W, b, quant_state, None, None, None quant_state = getattr(W, "quant_state", None) + if quant_state is None and W.dtype == torch.float8_e4m3fn: + quant_state = getattr(base_layer, "weight_scale_inv", None) active_adapter = ( proj.active_adapters[0] @@ -81,7 +85,7 @@ def matmul_lora( X: torch.Tensor, W: torch.Tensor, b: torch.Tensor | None, - W_quant: QuantState | None, + W_quant: QuantState | torch.Tensor | None, A: torch.Tensor | None, B: torch.Tensor | None, s: float | None, diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index d094f2381..c9c0f59bd 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -1,4 +1,4 @@ -"""Dequantization utilities for `bitsandbytes` integration.""" +"""Dequantization utilities for `bitsandbytes` and FP8 integration.""" import ctypes @@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None HAS_CUDA_STREAM: bool = 
Version(bnb.__version__) > Version("0.43.3") +def dequantize_fp8( + W: torch.Tensor, + scale_inv: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv. + + Args: + W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn. + scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)] + or per-tensor scalar. + dtype: Output dtype (default bf16). + + Returns: + Dequantized tensor in the specified dtype. + """ + W_float = W.to(dtype) + if scale_inv.numel() == 1: + return W_float * scale_inv.to(dtype) + if scale_inv.dim() == 2 and W.dim() == 2: + sr, sc = scale_inv.shape + br = W.shape[0] // sr + bc = W.shape[1] // sc + # If dimensions are exactly divisible, use fast reshape path + if sr * br == W.shape[0] and sc * bc == W.shape[1]: + return ( + W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype) + ).reshape(W.shape) + # Tail-block handling: compute actual block size (ceil division), + # tile scale_inv to cover full shape, then crop to W's dimensions + br_ceil = -(-W.shape[0] // sr) # ceil(rows / scale_rows) = block_size + bc_ceil = -(-W.shape[1] // sc) + scale_expanded = ( + scale_inv.to(dtype) + .repeat_interleave(br_ceil, dim=0) + .repeat_interleave(bc_ceil, dim=1) + )[: W.shape[0], : W.shape[1]] + return W_float * scale_expanded + return W_float * scale_inv.to(dtype) + + def dequantize( W: torch.Tensor, - quant_state: QuantState | list | None = None, + quant_state: QuantState | list | torch.Tensor | None = None, out: torch.Tensor | None = None, ) -> torch.Tensor: """ @@ -49,6 +90,15 @@ def dequantize( if quant_state is None: return W + # FP8 path: quant_state is actually scale_inv tensor + if W.dtype == torch.float8_e4m3fn: + scale_inv = quant_state + # Caller may pass W.t() (non-contiguous) — dequantize in original + # layout then transpose back so the result shape matches the input. + if not W.is_contiguous() and W.dim() == 2: + return dequantize_fp8(W.t(), scale_inv).t() + return dequantize_fp8(W, scale_inv) + # Get the target device from input tensor W target_device = W.device diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index eb7203c01..2b53b7b2c 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -160,6 +160,18 @@ def load_lora( else: model = get_peft_model(model, lora_config, **model_kwargs) + # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training + # requires a compute dtype (bf16/fp16). Cast trainable LoRA params. 
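+ # Dtype preference: explicit cfg.torch_dtype, else bf16 when the GPU
+ # supports it, else fp16.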
+ if cfg.torch_dtype: + _fp8_cast_dtype = cfg.torch_dtype + elif torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + _fp8_cast_dtype = torch.bfloat16 + else: + _fp8_cast_dtype = torch.float16 + for _name, param in model.named_parameters(): + if param.requires_grad and param.dtype == torch.float8_e4m3fn: + param.data = param.data.to(_fp8_cast_dtype) + if rank == 0: try: model.print_trainable_parameters() diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 2662d0b86..37c112337 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -215,6 +215,8 @@ class ModelLoader: self.model_kwargs["revision"] = self.cfg.revision_of_model if self.cfg.use_kernels: self.model_kwargs["use_kernels"] = self.cfg.use_kernels + if "allow_all_kernels" not in self.model_kwargs: + self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels self._set_quantization_config() self._set_attention_config() self._check_model_requirements() diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 5874c940b..205e32e6f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -116,6 +116,7 @@ class PatchManager: self._apply_patch_deepspeed_zero3() self._apply_voxtral_patches() self._apply_apertus_patches() + self._apply_trl_vllm_patches() def apply_post_plugin_pre_model_load_patches(self): """Apply post plugin-pre_model_load load patches based on config.""" @@ -667,6 +668,17 @@ class PatchManager: patch_apertus_xielu_activation() + def _apply_trl_vllm_patches(self): + """Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling.""" + if ( + self.cfg.rl + and getattr(self.cfg, "trl", None) + and getattr(self.cfg.trl, "use_vllm", False) + ): + from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm + + patch_trl_vllm() + def _apply_scaling_softmax_patch(self, model: PreTrainedModel): """Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399""" if self.cfg.scaling_softmax: diff --git a/src/axolotl/monkeypatch/trainer/trl_vllm.py b/src/axolotl/monkeypatch/trainer/trl_vllm.py new file mode 100644 index 000000000..a3296df61 --- /dev/null +++ b/src/axolotl/monkeypatch/trainer/trl_vllm.py @@ -0,0 +1,245 @@ +"""Monkeypatches for TRL's vLLM integration and trainer utils. 
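+
+Applied by ``patch_trl_vllm()`` via ``PatchManager._apply_trl_vllm_patches``
+when the config enables ``trl.use_vllm``.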
+ +Adds: +- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips) +- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation) +- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO +- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough) +""" + +import logging +import math +from functools import wraps + +import torch +from torch import nn + +LOG = logging.getLogger(__name__) + + +def _batch_update_named_params( + self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None +): + """Batched weight sync — sends param metadata via HTTP, tensors via NCCL.""" + from transformers import is_torch_xpu_available + + if chunk_size is None: + chunks = [params] + else: + chunks = [] + current_chunk: list[tuple[str, torch.Tensor]] = [] + current_elements = 0 + for name, weights in params: + n_elem = weights.numel() + if current_chunk and current_elements + n_elem > chunk_size: + chunks.append(current_chunk) + current_chunk = [] + current_elements = 0 + current_chunk.append((name, weights)) + current_elements += n_elem + if current_chunk: + chunks.append(current_chunk) + + for chunk in chunks: + param_metadata = [ + {"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)} + for name, weights in chunk + ] + url = f"{self.base_url}/batch_update_named_params/" + response = self.session.post(url, json={"params": param_metadata}) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + for _name, weights in chunk: + if is_torch_xpu_available(): + self.communicator.broadcast(weights, root=self.rank) + else: + self.communicator.broadcast(weights, src=self.rank) + + if is_torch_xpu_available(): + self.communicator.barrier() + else: + self.communicator.group.barrier() + + +def _update_model_params(self, model: nn.Module, chunk_size: int | None = None): + """Updates all model params using batch_update_named_params.""" + params = [(name, param.data) for name, param in model.named_parameters()] + self.batch_update_named_params(params, chunk_size=chunk_size) + + +def _patched_extract_logprobs(all_outputs): + """extract_logprobs with NaN→0.0 fix (stock TRL uses None which causes downstream errors).""" + all_logprobs = [] + all_token_ids = [] + + for outputs in all_outputs: + for output in outputs.outputs: + if output.logprobs is None: + return None, None + seq_logprobs = [] + seq_token_ids = [] + for lp in output.logprobs: + sorted_items = sorted(lp.items(), key=lambda x: x[1].rank) + seq_token_ids.append([token_id for token_id, _ in sorted_items]) + seq_logprobs.append( + [ + 0.0 if math.isnan(item.logprob) else item.logprob + for _, item in sorted_items + ] + ) + all_logprobs.append(seq_logprobs) + all_token_ids.append(seq_token_ids) + + return all_logprobs, all_token_ids + + +def _patched_split_tensor_dict(tensor_dict, num_chunks): + """split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch.""" + first_tensor = next( + tensor + for tensor in tensor_dict.values() + if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0 + ) + chunk_size = first_tensor.shape[0] // num_chunks + chunks = [] + for i in range(num_chunks): + chunk_dict = {} + for key, tensor in tensor_dict.items(): + if isinstance(tensor, (int, float, bool)): + chunk_dict[key] = tensor + elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0): + chunk_dict[key] = tensor[i * chunk_size : (i 
+ 1) * chunk_size] + elif tensor is not None and tensor.ndim == 0: + chunk_dict[key] = tensor + else: + chunk_dict[key] = None + chunks.append(chunk_dict) + return chunks + + +def _patched_shuffle_sequence_dict(seq_dict): + """shuffle_sequence_dict that handles scalar types (int/float/bool).""" + first_seq = next( + v + for v in seq_dict.values() + if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0 + ) + perm = torch.randperm(len(first_seq)) + + def permute(v): + if v is None: + return None + if isinstance(v, (int, float, bool)): + return v + if isinstance(v, torch.Tensor) and v.ndim == 0: + return v + if isinstance(v, torch.Tensor) and v.ndim >= 1: + return v[perm] + if isinstance(v, list): + return [v[i] for i in perm.tolist()] + return v + + return {k: permute(v) for k, v in seq_dict.items()} + + +def _patch_sync_weights_batched(original_init): + """Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size.""" + + @wraps(original_init) + def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs): + original_init(self, *args, **kwargs) + self.weight_sync_chunk_size = weight_sync_chunk_size + + return patched_init + + +def _make_batched_sync_weights(original_sync_weights): + """Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths.""" + + @wraps(original_sync_weights) + def patched_sync_weights(self): + from accelerate.utils import is_peft_model + + # Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps + accelerator = self.accelerator + model = self.model + is_fsdp_enabled = self.is_fsdp_enabled + + deepspeed_plugin = accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + + is_peft = is_peft_model(model) + + # If PEFT, FSDP, or ZeRO-3, fall back to original (which handles those cases) + if is_peft or is_fsdp_enabled or zero_stage_3: + return original_sync_weights(self) + + # Non-PEFT, non-FSDP, non-ZeRO: use batched sync + if self.mode == "colocate" and getattr(self, "enable_sleep_mode", False): + from vllm.distributed.device_communicators.cuda_wrapper import ( + empty_cache, + ) + + empty_cache() + self.llm.wake_up(tags=["weights"]) + + if self.mode == "server" and accelerator.is_main_process: + params = [ + (self._fix_param_name_to_vllm(name), param.data) + for name, param in model.named_parameters() + ] + self.vllm_client.batch_update_named_params( + params, chunk_size=getattr(self, "weight_sync_chunk_size", None) + ) + elif self.mode == "colocate": + llm_model = ( + self.llm.llm_engine.model_executor.driver_worker.model_runner.model + ) + weights = [ + (self._fix_param_name_to_vllm(name), param.data) + for name, param in model.named_parameters() + ] + llm_model.load_weights(weights=weights) + + # Reset cache + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.mode == "colocate": + self.llm.reset_prefix_cache() + + return patched_sync_weights + + +def patch_trl_vllm(): + """Apply all TRL vLLM monkeypatches.""" + import trl.generation.vllm_client + import trl.generation.vllm_generation + import trl.trainer.utils + + VLLMClient = trl.generation.vllm_client.VLLMClient + VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration + + # 1. 
Add batch_update_named_params to VLLMClient + if not hasattr(VLLMClient, "batch_update_named_params"): + VLLMClient.batch_update_named_params = _batch_update_named_params + VLLMClient.update_model_params = _update_model_params + LOG.info("Patched VLLMClient with batch_update_named_params") + + # 2. Patch extract_logprobs (NaN→0.0) + trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs + LOG.info("Patched extract_logprobs with NaN→0.0 fix") + + # 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size + VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__) + + # 4. Patch sync_weights for batched non-FSDP/non-ZeRO path + VLLMGeneration.sync_weights = _make_batched_sync_weights( + VLLMGeneration.sync_weights + ) + LOG.info("Patched VLLMGeneration with batched sync_weights") + + # 5. Patch split_tensor_dict and shuffle_sequence_dict + trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict + trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict + LOG.info("Patched split_tensor_dict and shuffle_sequence_dict for scalar types") diff --git a/src/axolotl/scripts/__init__.py b/src/axolotl/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py new file mode 100644 index 000000000..9ce4d2771 --- /dev/null +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -0,0 +1,503 @@ +"""vLLM serve script with native LoRA adapter support. + +Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM, +instead of merging adapter weights into the base model before syncing. + +Usage: + Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config, + or ``trl.vllm_lora_sync: true`` to auto-select. 
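+
+ The server exposes ``/set_lora_adapter/`` and ``/clear_lora_adapter/``
+ endpoints for switching the active adapter at runtime.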
+ +Benefits over merge-sync: + - Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL + - vLLM handles LoRA application natively (Punica kernels) + - No NCCL communicator needed for weight sync +""" + +import logging +import os +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from itertools import chain +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection +from typing import Any + +from trl.scripts.vllm_serve import ( + ScriptArguments, + chunk_list, + extract_logprobs, + get_open_port, +) +from vllm import LLM, SamplingParams +from vllm.lora.request import LoRARequest + +logger = logging.getLogger(__name__) + + +@dataclass +class LoRAScriptArguments(ScriptArguments): + """Extended script arguments with LoRA support.""" + + enable_lora: bool = field( + default=True, + metadata={"help": "Enable LoRA adapter support in vLLM."}, + ) + max_lora_rank: int = field( + default=64, + metadata={"help": "Maximum LoRA rank supported."}, + ) + max_loras: int = field( + default=2, + metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."}, + ) + lora_dtype: str = field( + default="bfloat16", + metadata={"help": "Data type for LoRA weights."}, + ) + + +def llm_worker( + script_args: LoRAScriptArguments, + data_parallel_rank: int, + master_port: int, + connection: Connection, +) -> None: + """Worker process that creates a vLLM LLM with LoRA enabled.""" + os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) + os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + + llm = LLM( + model=script_args.model, + revision=script_args.revision, + tensor_parallel_size=script_args.tensor_parallel_size, + gpu_memory_utilization=script_args.gpu_memory_utilization, + enforce_eager=script_args.enforce_eager, + dtype=script_args.dtype, + enable_prefix_caching=script_args.enable_prefix_caching, + kv_cache_dtype=script_args.kv_cache_dtype, + max_model_len=script_args.max_model_len, + # Use batch-capable worker extension (adds batch_update_named_params + auto-close) + worker_extension_cls="axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension", + trust_remote_code=script_args.trust_remote_code, + model_impl=script_args.vllm_model_impl, + logprobs_mode="processed_logprobs", + # LoRA + enable_lora=script_args.enable_lora, + max_lora_rank=script_args.max_lora_rank, + max_loras=script_args.max_loras, + lora_dtype=script_args.lora_dtype, + ) + + connection.send({"status": "ready"}) + + while True: + try: + command = connection.recv() + except KeyboardInterrupt: + llm.collective_rpc(method="close_communicator") + break + + if command["type"] in ["call", "fire_and_forget"]: + method_name = command["method"] + args = command.get("args", ()) + kwargs = command.get("kwargs", {}) + + # Reconstruct LoRARequest from serialized dict (can't pickle across pipe) + if "lora_request" in kwargs and kwargs["lora_request"] is not None: + lr = kwargs["lora_request"] + kwargs["lora_request"] = LoRARequest( + lora_name=lr["lora_name"], + lora_int_id=lr["lora_int_id"], + lora_path=lr["lora_path"], + load_inplace=lr.get("load_inplace", False), + ) + + method = getattr(llm, method_name) + result = method(*args, **kwargs) + if command["type"] == "call": + connection.send(result) + elif command["type"] == "shutdown": + break + + +def main(script_args: ScriptArguments): + """Start vLLM workers 
with LoRA support and the HTTP server.""" + import asyncio + + import uvicorn + from fastapi import FastAPI + from pydantic import BaseModel, Field as PydanticField + + # Request/Response models (defined locally like TRL's vllm_serve.main) + class GenerateRequest(BaseModel): + prompts: list[str] + images: list[str] | None = None + n: int = 1 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + max_tokens: int = 16 + logprobs: int | None = 0 + truncate_prompt_tokens: int | None = None + structured_outputs_regex: str | None = None + generation_kwargs: dict = PydanticField(default_factory=dict) + + class GenerateResponse(BaseModel): + prompt_ids: list[list[int]] + completion_ids: list[list[int]] + logprobs: list[list[list[float]]] + logprob_token_ids: list[list[list[int]]] + + class ChatRequest(BaseModel): + messages: list[list[dict]] + n: int = 1 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + max_tokens: int = 16 + logprobs: int | None = 0 + truncate_prompt_tokens: int | None = None + structured_outputs_regex: str | None = None + generation_kwargs: dict = PydanticField(default_factory=dict) + chat_template_kwargs: dict = PydanticField(default_factory=dict) + + class ChatResponse(BaseModel): + prompt_ids: list[list[int]] + completion_ids: list[list[int]] + logprobs: list[list[list[float]]] + logprob_token_ids: list[list[list[int]]] + + class InitCommunicatorRequest(BaseModel): + host: str + port: int + world_size: int + client_device_uuid: str + + # Wrap plain ScriptArguments with LoRA defaults + if not isinstance(script_args, LoRAScriptArguments): + lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments) + for f in ScriptArguments.__dataclass_fields__: + setattr(lora_args, f, getattr(script_args, f)) + # Apply LoRA defaults + for f in LoRAScriptArguments.__dataclass_fields__: + if f not in ScriptArguments.__dataclass_fields__: + setattr( + lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default + ) + script_args = lora_args + + # Spawn workers + master_port = get_open_port() + connections: list[Connection] = [] + processes: list[Process] = [] + for dp_rank in range(script_args.data_parallel_size): + parent_conn, child_conn = Pipe() + process = Process( + target=llm_worker, + args=(script_args, dp_rank, master_port, child_conn), + ) + process.start() + connections.append(parent_conn) + processes.append(process) + + @asynccontextmanager + async def lifespan(app: FastAPI): + import time + + startup_timeout = 300 # 5 minutes + start_time = time.monotonic() + ready: set[int] = set() + while len(ready) < script_args.data_parallel_size: + elapsed = time.monotonic() - start_time + if elapsed > startup_timeout: + raise RuntimeError( + f"vLLM workers failed to start within {startup_timeout}s " + f"({len(ready)}/{script_args.data_parallel_size} ready)" + ) + for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)): + if id(conn) in ready: + continue + if not proc.is_alive(): + raise RuntimeError( + f"vLLM worker {i} exited unexpectedly during startup" + ) + if conn.poll(): + msg = conn.recv() + if isinstance(msg, dict) and msg.get("status") == "ready": + ready.add(id(conn)) + await asyncio.sleep(0.1) + yield + for p in processes: + p.join(timeout=10) + if p.is_alive(): + p.terminate() + p.join() + + app = FastAPI(lifespan=lifespan) + + # --- Active LoRA state (shared across endpoints via closure) --- + active_lora: dict = 
{"request": None} + + # ------------------------------------------------------------------ + # LoRA-specific endpoints + # ------------------------------------------------------------------ + + class SetLoRARequest(BaseModel): + lora_name: str + lora_int_id: int + lora_path: str + load_inplace: bool = False + + @app.post("/set_lora_adapter/") + async def set_lora_adapter(request: SetLoRARequest): + """Register a LoRA adapter for all subsequent generate/chat calls.""" + active_lora["request"] = { + "lora_name": request.lora_name, + "lora_int_id": request.lora_int_id, + "lora_path": request.lora_path, + "load_inplace": request.load_inplace, + } + logger.info( + "Set active LoRA: %s (id=%d, path=%s)", + request.lora_name, + request.lora_int_id, + request.lora_path, + ) + return {"status": "ok"} + + @app.post("/clear_lora_adapter/") + async def clear_lora_adapter(): + """Clear active LoRA adapter (revert to base model).""" + active_lora["request"] = None + return {"status": "ok"} + + # ------------------------------------------------------------------ + # Standard endpoints (mirrors TRL's vllm_serve) + # ------------------------------------------------------------------ + + @app.get("/health/") + async def health(): + return {"status": "ok"} + + @app.get("/get_world_size/") + async def get_world_size(): + return { + "world_size": script_args.tensor_parallel_size + * script_args.data_parallel_size + } + + @app.post("/generate/", response_model=GenerateResponse) + async def generate(request: GenerateRequest): + """Generate completions with optional LoRA adapter.""" + import base64 + from io import BytesIO + + import vllm + from packaging.version import Version + from vllm.sampling_params import GuidedDecodingParams + + images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item] + prompts: list[dict[str, Any]] = [] + for prompt, image in zip(request.prompts, images, strict=True): + row: dict[str, Any] = {"prompt": prompt} + if image is not None: + from PIL import Image + + row["multi_modal_data"] = { + "image": Image.open(BytesIO(base64.b64decode(image))) + } + prompts.append(row) + + generation_kwargs = { + "n": request.n, + "repetition_penalty": request.repetition_penalty, + "temperature": request.temperature, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "max_tokens": request.max_tokens, + "logprobs": request.logprobs, + } + generation_kwargs.update(request.generation_kwargs) + + if Version(vllm.__version__) <= Version("0.10.2"): + key = "guided_decoding" + if request.structured_outputs_regex is not None: + generation_kwargs[key] = GuidedDecodingParams( + regex=request.structured_outputs_regex + ) + else: + generation_kwargs.setdefault(key, None) + else: + from vllm.sampling_params import StructuredOutputsParams + + key = "structured_outputs" + if request.structured_outputs_regex is not None: + generation_kwargs[key] = StructuredOutputsParams( + regex=request.structured_outputs_regex + ) + elif isinstance(generation_kwargs.get(key), dict): + generation_kwargs[key] = StructuredOutputsParams( + **generation_kwargs[key] + ) + else: + generation_kwargs.setdefault(key, None) + + sampling_params = SamplingParams(**generation_kwargs) + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + + for conn, chunk in zip(connections, chunked_prompts, strict=True): + if not chunk: + chunk = [{"prompt": ""}] + kwargs = { + "prompts": chunk, + "sampling_params": sampling_params, + "lora_request": 
active_lora["request"], + } + conn.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + all_outputs = [conn.recv() for conn in connections] + all_outputs = [ + o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c + ] + all_outputs = list(chain.from_iterable(all_outputs)) + + return { + "prompt_ids": [o.prompt_token_ids for o in all_outputs], + "completion_ids": [ + list(out.token_ids) for o in all_outputs for out in o.outputs + ], + "logprobs": extract_logprobs(all_outputs)[0], + "logprob_token_ids": extract_logprobs(all_outputs)[1], + } + + @app.post("/chat/", response_model=ChatResponse) + async def chat(request: ChatRequest): + """Chat endpoint with optional LoRA adapter.""" + generation_kwargs = { + "n": request.n, + "repetition_penalty": request.repetition_penalty, + "temperature": request.temperature, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "max_tokens": request.max_tokens, + "logprobs": request.logprobs, + } + generation_kwargs.update(request.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + chunked = chunk_list(request.messages, script_args.data_parallel_size) + for conn, chunk in zip(connections, chunked, strict=True): + if not chunk: + chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": chunk, + "sampling_params": sampling_params, + "use_tqdm": False, + "lora_request": active_lora["request"], + } + conn.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + all_outputs = [conn.recv() for conn in connections] + all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] + all_outputs = list(chain.from_iterable(all_outputs)) + + return { + "prompt_ids": [o.prompt_token_ids for o in all_outputs], + "completion_ids": [ + list(out.token_ids) for o in all_outputs for out in o.outputs + ], + "logprobs": extract_logprobs(all_outputs)[0], + "logprob_token_ids": extract_logprobs(all_outputs)[1], + } + + # --- Weight sync endpoints (legacy fallback, same as TRL) --- + + @app.post("/init_communicator/") + async def init_communicator(request: InitCommunicatorRequest): + world_size = ( + script_args.tensor_parallel_size * script_args.data_parallel_size + 1 + ) + kwargs = { + "method": "init_communicator", + "args": ( + request.host, + request.port, + world_size, + request.client_device_uuid, + ), + } + msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs} + loop = asyncio.get_running_loop() + await asyncio.gather( + *(loop.run_in_executor(None, c.send, msg) for c in connections) + ) + return {"message": "Initializing communicator"} + + class UpdateWeightsRequest(BaseModel): + name: str + dtype: str + shape: list[int] + + @app.post("/update_named_param/") + async def update_named_param(request: UpdateWeightsRequest): + kwargs = { + "method": "update_named_param", + "args": (request.name, request.dtype, tuple(request.shape)), + } + msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs} + loop = asyncio.get_running_loop() + await asyncio.gather( + *(loop.run_in_executor(None, c.send, msg) for c in connections) + ) + return {"message": "Updating parameter"} + + class BatchUpdateWeightsRequest(BaseModel): + params: list[dict] + + @app.post("/batch_update_named_params/") + async def batch_update_named_params(request: BatchUpdateWeightsRequest): + params_list = [ + (p["name"], p["dtype"], tuple(p["shape"])) for p in request.params + ] + kwargs = {"method": "batch_update_named_params", "args": (params_list,)} + msg = 
{"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs} + loop = asyncio.get_running_loop() + await asyncio.gather( + *(loop.run_in_executor(None, c.send, msg) for c in connections) + ) + return {"message": f"Batch update for {len(params_list)} params"} + + @app.post("/reset_prefix_cache/") + async def reset_prefix_cache(): + for conn in connections: + conn.send({"type": "call", "method": "reset_prefix_cache"}) + results = [conn.recv() for conn in connections] + return {"message": f"Reset prefix cache: {all(results)}"} + + @app.post("/close_communicator/") + async def close_communicator(): + kwargs = {"method": "close_communicator"} + for conn in connections: + conn.send( + { + "type": "fire_and_forget", + "method": "collective_rpc", + "kwargs": kwargs, + } + ) + return {"message": "Closing communicator"} + + uvicorn.run( + app, + host=script_args.host, + port=script_args.port, + log_level=script_args.log_level, + access_log=True, + ) diff --git a/src/axolotl/scripts/vllm_worker_ext.py b/src/axolotl/scripts/vllm_worker_ext.py new file mode 100644 index 000000000..386460df1 --- /dev/null +++ b/src/axolotl/scripts/vllm_worker_ext.py @@ -0,0 +1,158 @@ +"""Extended vLLM worker extension with batch weight sync support. + +Subclasses TRL's WeightSyncWorkerExtension to add: +- batch_update_named_params: receives multiple params in one call +- Auto-close stale communicator on re-init +- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params, + including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy +""" + +import logging + +import torch + +try: + from transformers import is_torch_xpu_available +except ImportError: + is_torch_xpu_available = lambda: False # noqa: E731 + +from trl.scripts.vllm_serve import WeightSyncWorkerExtension + +logger = logging.getLogger(__name__) + +# Stacked param name mapping: shard_name -> (packed_name, shard_order) +_STACKED_PARAMS = { + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), +} + + +class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension): + """Worker extension that adds batch weight update and direct weight setting.""" + + def init_communicator(self, host, port, world_size, client_device_uuid): + """Auto-close stale communicator before re-initializing.""" + if self.communicator is not None: + self.close_communicator() + super().init_communicator(host, port, world_size, client_device_uuid) + + def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None: + """Directly copy weight data into the model, handling stacked params. + + Bypasses model.load_weights() which may fail on vLLM 0.17's new + module-tree weight loader for stacked params (qkv_proj, gate_up_proj). + + Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the + parameter hierarchy (e.g. ``qkv_proj.base_layer.weight``). 
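+
+ For example, ``model.layers.0.self_attn.q_proj.weight`` is copied into
+ shard 0 of ``model.layers.0.self_attn.qkv_proj.weight`` (or its
+ ``base_layer`` variant), using the module's ``output_sizes`` to locate
+ the shard offset.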
+ """ + model = self.model_runner.model + params_dict = dict(model.named_parameters()) + + # Check if this is a simple direct param (exists as-is) + if name in params_dict: + params_dict[name].data.copy_(weight.to(params_dict[name].dtype)) + return + + # Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight + parts_bl = name.rsplit(".", 1) + if len(parts_bl) == 2: + base_layer_name = f"{parts_bl[0]}.base_layer.{parts_bl[1]}" + if base_layer_name in params_dict: + params_dict[base_layer_name].data.copy_( + weight.to(params_dict[base_layer_name].dtype) + ) + return + + # Handle stacked params: e.g. "model.layers.0.self_attn.q_proj.weight" + # -> "model.layers.0.self_attn.qkv_proj.weight" with shard offset + parts = name.rsplit(".", 2) # [prefix, layer_name, suffix] + if len(parts) == 3: + prefix, layer_name, suffix = parts + if layer_name in _STACKED_PARAMS: + packed_name, shard_idx = _STACKED_PARAMS[layer_name] + for packed_full in [ + f"{prefix}.{packed_name}.{suffix}", + f"{prefix}.{packed_name}.base_layer.{suffix}", + ]: + if packed_full not in params_dict: + continue + param = params_dict[packed_full] + # Navigate to the packed module to find shard sizes + module_path = packed_full.rsplit(".", 1)[0] # strip .weight/.bias + if ".base_layer" in module_path: + module_path = module_path.replace(".base_layer", "") + module = model + for attr in module_path.split("."): + module = getattr(module, attr, None) + if module is None: + break + # LoRA wrappers don't have output_sizes directly; + # check base_layer for the underlying parallel linear + if module is not None and not hasattr(module, "output_sizes"): + base = getattr(module, "base_layer", None) + if base is not None and hasattr(base, "output_sizes"): + module = base + if module is not None and hasattr(module, "output_sizes"): + tp_size = getattr(module, "tp_size", 1) + sizes = [s // tp_size for s in module.output_sizes] + offset = sum(sizes[:shard_idx]) + shard_size = sizes[shard_idx] + param.data[offset : offset + shard_size].copy_( + weight.to(param.dtype) + ) + return + + # Fallback: try load_weights (may work for non-stacked params) + logger.warning("Falling back to load_weights for param: %s", name) + model.load_weights(weights=[(name, weight)]) + + def update_named_param(self, name, dtype, shape): + """Override to use _direct_set_weight instead of load_weights.""" + if self.communicator is None: + raise RuntimeError("Communicator not initialized.") + + dtype = getattr(torch, dtype.split(".")[-1]) + weight = torch.empty(shape, dtype=dtype, device=self.device) + + if is_torch_xpu_available(): + self.communicator.broadcast(weight, root=self.client_rank) + self.communicator.barrier() + else: + self.communicator.broadcast(weight, src=self.client_rank) + self.communicator.group.barrier() + + self._direct_set_weight(name, weight) + + def batch_update_named_params(self, params_list: list[tuple[str, str, tuple]]): + """Receive and apply multiple weight tensors in sequence. + + Args: + params_list: List of (name, dtype_str, shape) tuples. 
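+
+ All tensors are received via broadcast first, followed by a single
+ barrier, so one batched call incurs one synchronization point rather
+ than one per parameter.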
+ """ + if self.communicator is None: + raise RuntimeError("Communicator not initialized.") + + weights_to_load = [] + for name, dtype_str, shape in params_list: + dtype = getattr(torch, dtype_str.split(".")[-1]) + weight = torch.empty(shape, dtype=dtype, device=self.device) + + if is_torch_xpu_available(): + self.communicator.broadcast(weight, root=self.client_rank) + else: + self.communicator.broadcast(weight, src=self.client_rank) + + weights_to_load.append((name, weight)) + + # Single barrier after all broadcasts + if is_torch_xpu_available(): + self.communicator.barrier() + else: + self.communicator.group.barrier() + + # Load weights using direct set (handles stacked params) + for name, weight in weights_to_load: + self._direct_set_weight(name, weight) diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index ff96f44ce..2d7c36f96 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -189,3 +189,125 @@ class TRLConfig(BaseModel): "'normalize_then_sum' (GDPO): normalizes each reward independently, then sums." }, ) + + # Async GRPO fields + use_data_producer: bool = Field( + default=False, + json_schema_extra={ + "description": "Use the GRPODataProducer protocol for online data generation." + }, + ) + async_prefetch: bool = Field( + default=False, + json_schema_extra={ + "description": "Generate rollouts in a background thread while training on the previous rollout." + }, + ) + prefetch_depth: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of rollouts to prefetch ahead of training." + }, + ) + vllm_sync_interval: int | None = Field( + default=None, + json_schema_extra={ + "description": "Sync model weights to vLLM every N optimizer steps (async mode only)." + }, + ) + streaming_partial_batch: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Score prompt groups incrementally instead of the full batch at once." + }, + ) + streaming_min_groups: int | None = Field( + default=None, + json_schema_extra={ + "description": "Minimum prompt groups to score per streaming chunk." + }, + ) + vllm_importance_sampling_correction: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply IS correction for distribution mismatch between vLLM and training model." + }, + ) + vllm_importance_sampling_mode: ( + Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"] + | None + ) = Field( + default=None, + json_schema_extra={ + "description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask." + }, + ) + vllm_importance_sampling_cap: float | None = Field( + default=None, + json_schema_extra={"description": "Cap C for IS ratio clipping/masking."}, + ) + off_policy_mask_threshold: float | None = Field( + default=None, + json_schema_extra={ + "description": "KL threshold for off-policy sequence masking (OPSM). None = disabled." + }, + ) + use_bias_correction_kl: bool | None = Field( + default=None, + json_schema_extra={"description": "Apply IS correction to KL divergence term."}, + ) + + reward_num_workers: int = Field( + default=1, + json_schema_extra={ + "description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its " + "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across " + "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions." 
+        },
+    )
+    replay_buffer_size: int = Field(
+        default=0,
+        json_schema_extra={
+            "description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
+            "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
+            "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
+        },
+    )
+    replay_recompute_logps: bool = Field(
+        default=True,
+        json_schema_extra={
+            "description": "When True (default), recompute old_per_token_logps for replayed groups using the current "
+            "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
+            "Only relevant when replay_buffer_size > 0."
+        },
+    )
+    reroll_start_fraction: float = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
+            "(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
+            "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
+        },
+    )
+    reroll_max_groups: int = Field(
+        default=1,
+        json_schema_extra={
+            "description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
+            "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
+        },
+    )
+    skip_zero_advantage_batches: bool = Field(
+        default=True,
+        json_schema_extra={
+            "description": "When True, skip the forward/backward pass entirely for micro-batches where all advantages "
+            "are zero (no learning signal). The step is logged with skipped_zero_adv_batches=1 for monitoring."
+        },
+    )
+    vllm_lora_sync: bool = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. "
+            "Auto-selects the vllm_serve_lora serve module. Syncs only LoRA adapter weights vs the full merged model."
+        },
+    )
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 3ba484fec..50b30cd26 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -675,20 +675,6 @@ class LoRAValidationMixin:
             )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_lora_kernels_rl(cls, data):
-        if (
-            data.get("lora_mlp_kernel")
-            or data.get("lora_qkv_kernel")
-            or data.get("lora_o_kernel")
-        ) and data.get("rl"):
-            raise ValueError(
-                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
-                "compatible with RL at the moment."
-            )
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_lora_kernels_trust_remote_code(cls, data):
diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py
index 518b8f62d..c0aa48d66 100644
--- a/src/axolotl/utils/schemas/vllm.py
+++ b/src/axolotl/utils/schemas/vllm.py
@@ -57,3 +57,10 @@ class VllmConfig(BaseModel):
         default=None,
         json_schema_extra={"description": "Reasoning parser for VLLM"},
     )
+    serve_module: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Python module for the vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' "
+            "for native LoRA support, or leave None for the default TRL serve."
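+            # Illustrative YAML usage:
+            #   vllm:
+            #     serve_module: axolotl.scripts.vllm_serve_lora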
+ }, + ) diff --git a/tests/core/test_async_grpo.py b/tests/core/test_async_grpo.py new file mode 100644 index 000000000..eb83be1b6 --- /dev/null +++ b/tests/core/test_async_grpo.py @@ -0,0 +1,220 @@ +"""Unit tests for async GRPO""" + +import unittest +from unittest.mock import MagicMock + +import torch + + +class TestReplayBuffer(unittest.TestCase): + """Tests for ReplayBuffer edge cases.""" + + def test_add_noop_when_max_size_zero(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=0) + buf.add(1.0, {"data": "test"}) + self.assertEqual(len(buf), 0) + + def test_add_noop_when_max_size_negative(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=-1) + buf.add(1.0, {"data": "test"}) + self.assertEqual(len(buf), 0) + + def test_sample_returns_none_when_max_size_zero(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=0) + self.assertIsNone(buf.sample(1)) + + def test_sample_returns_none_when_empty(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=5) + self.assertIsNone(buf.sample(1)) + + def test_normal_add_and_sample(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=3) + buf.add(1.0, {"a": 1}) + buf.add(2.0, {"a": 2}) + buf.add(3.0, {"a": 3}) + self.assertEqual(len(buf), 3) + result = buf.sample(1) + self.assertIsNotNone(result) + self.assertEqual(len(result), 1) + + def test_replaces_lowest_when_full(self): + from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer + + buf = ReplayBuffer(max_size=2) + buf.add(1.0, {"a": 1}) + buf.add(2.0, {"a": 2}) + buf.add(3.0, {"a": 3}) # should replace score=1.0 + self.assertEqual(len(buf), 2) + scores = sorted(item[0] for item in buf._heap) + self.assertEqual(scores, [2.0, 3.0]) + + +class TestGRPOStrategyConflict(unittest.TestCase): + """Tests for sequence_parallel + async_grpo conflict detection.""" + + def test_raises_on_both_enabled(self): + from axolotl.core.trainers.grpo import GRPOStrategy + + with self.assertRaises(ValueError) as ctx: + GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True) + self.assertIn("sequence_parallel", str(ctx.exception)) + self.assertIn("async_grpo", str(ctx.exception)) + + def test_sequence_parallel_only(self): + from axolotl.core.trainers.grpo import GRPOStrategy + from axolotl.core.trainers.grpo.trainer import ( + AxolotlGRPOSequenceParallelTrainer, + ) + + cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False) + self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer) + + def test_async_only(self): + from axolotl.core.trainers.grpo import GRPOStrategy + from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer + + cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True) + self.assertIs(cls, AxolotlAsyncGRPOTrainer) + + def test_neither(self): + from axolotl.core.trainers.grpo import GRPOStrategy + from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer + + cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False) + self.assertIs(cls, AxolotlGRPOTrainer) + + +class TestDequantizeFP8TailBlocks(unittest.TestCase): + """Tests for FP8 dequantization with non-divisible dimensions.""" + + def test_exact_divisible_shape(self): + from axolotl.kernels.quantize import dequantize_fp8 + + W = torch.randn(256, 128, 
+            dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
+        result = dequantize_fp8(W, scale_inv)
+        self.assertEqual(result.shape, (256, 128))
+        self.assertEqual(result.dtype, torch.bfloat16)
+
+    def test_non_divisible_rows(self):
+        from axolotl.kernels.quantize import dequantize_fp8
+
+        # 131 rows with 2 row-scale blocks: the rows don't divide evenly, so
+        # the second block is a shorter tail block, exercising tail handling.
+        W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
+        result = dequantize_fp8(W, scale_inv)
+        self.assertEqual(result.shape, (131, 128))
+        self.assertEqual(result.dtype, torch.bfloat16)
+
+    def test_non_divisible_cols(self):
+        from axolotl.kernels.quantize import dequantize_fp8
+
+        W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
+        result = dequantize_fp8(W, scale_inv)
+        self.assertEqual(result.shape, (128, 200))
+
+    def test_scalar_scale(self):
+        from axolotl.kernels.quantize import dequantize_fp8
+
+        W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
+        result = dequantize_fp8(W, scale_inv)
+        self.assertEqual(result.shape, (64, 64))
+
+
+class TestLoraFP8Guard(unittest.TestCase):
+    """Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights."""
+
+    def test_non_fp8_weight_skips_scale_inv(self):
+        """Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
+        from axolotl.kernels.lora import get_lora_parameters
+
+        proj = MagicMock()
+        proj.disable_adapters = True
+        base_layer = MagicMock(spec=[])  # empty spec to control attrs precisely
+
+        # Use a real tensor for weight (bf16, no quant_state attr)
+        base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
+        base_layer.bias = None
+        base_layer.weight_scale_inv = torch.ones(1)  # should NOT be used for bf16
+
+        proj.base_layer = base_layer
+
+        W, b, quant_state, A, B, s = get_lora_parameters(proj)
+        # quant_state should be None since weight is bf16, not FP8
+        self.assertIsNone(quant_state)
+
+    def test_fp8_weight_uses_scale_inv(self):
+        """FP8 weight should pick up weight_scale_inv as quant_state."""
+        from axolotl.kernels.lora import get_lora_parameters
+
+        proj = MagicMock()
+        proj.disable_adapters = True
+        base_layer = MagicMock()
+        proj.base_layer = base_layer
+
+        # FP8 weight
+        base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
+            torch.float8_e4m3fn
+        )
+        base_layer.bias = None
+        scale_inv = torch.ones(1)
+        base_layer.weight_scale_inv = scale_inv
+
+        W, b, quant_state, A, B, s = get_lora_parameters(proj)
+        self.assertIs(quant_state, scale_inv)
+
+
+class TestValidateQuantPatchRestore(unittest.TestCase):
+    """Test that validate_quantization_for_training is restored after trainer creation."""
+
+    def test_patch_restored_on_success(self):
+        """Monkeypatch should be restored even after successful trainer creation."""
+        import transformers.trainer as _trainer_module
+
+        original = _trainer_module.validate_quantization_for_training
+
+        # After the build() method runs, the original should be restored.
+        # We can't easily test the full build(), but we can test the pattern.
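+        # Guard pattern under test: install a no-op, do the work, then restore
+        # the original in a `finally` block.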
+        _orig = _trainer_module.validate_quantization_for_training
+        _trainer_module.validate_quantization_for_training = lambda model: None
+        try:
+            pass  # simulate trainer_cls() succeeding
+        finally:
+            _trainer_module.validate_quantization_for_training = _orig
+
+        self.assertIs(_trainer_module.validate_quantization_for_training, original)
+
+    def test_patch_restored_on_error(self):
+        """Monkeypatch should be restored even if trainer creation raises."""
+        import transformers.trainer as _trainer_module
+
+        original = _trainer_module.validate_quantization_for_training
+
+        _orig = _trainer_module.validate_quantization_for_training
+        _trainer_module.validate_quantization_for_training = lambda model: None
+        try:
+            raise ValueError("test error")
+        except ValueError:
+            pass
+        finally:
+            _trainer_module.validate_quantization_for_training = _orig
+
+        self.assertIs(_trainer_module.validate_quantization_for_training, original)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/monkeypatch/test_trl_vllm.py b/tests/monkeypatch/test_trl_vllm.py
new file mode 100644
index 000000000..755ac2577
--- /dev/null
+++ b/tests/monkeypatch/test_trl_vllm.py
@@ -0,0 +1,286 @@
+"""Unit tests for TRL vLLM monkeypatches.
+
+Tests:
+- split_tensor_dict: scalar type preservation (int/float/bool)
+- shuffle_sequence_dict: scalar type preservation
+- extract_logprobs: NaN → 0.0 replacement
+- VLLMClient.batch_update_named_params: method exists after patch
+- VLLMGeneration: weight_sync_chunk_size attribute after patch
+- Patch idempotency: applying patch twice doesn't break anything
+"""
+
+import unittest
+from dataclasses import dataclass
+from unittest.mock import MagicMock
+
+import torch
+
+
+class TestSplitTensorDict(unittest.TestCase):
+    """Tests for patched split_tensor_dict."""
+
+    def setUp(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import _patched_split_tensor_dict
+
+        self.split = _patched_split_tensor_dict
+
+    def test_scalar_int_preserved(self):
+        d = {"a": torch.randn(4, 3), "count": 42}
+        chunks = self.split(d, 2)
+        self.assertEqual(len(chunks), 2)
+        self.assertEqual(chunks[0]["count"], 42)
+        self.assertEqual(chunks[1]["count"], 42)
+
+    def test_scalar_float_preserved(self):
+        d = {"a": torch.randn(6, 2), "lr": 1e-5}
+        chunks = self.split(d, 3)
+        for c in chunks:
+            self.assertEqual(c["lr"], 1e-5)
+
+    def test_scalar_bool_preserved(self):
+        d = {"a": torch.randn(4, 2), "flag": True}
+        chunks = self.split(d, 2)
+        for c in chunks:
+            self.assertTrue(c["flag"])
+
+    def test_none_preserved(self):
+        d = {"a": torch.randn(4, 2), "b": None}
+        chunks = self.split(d, 2)
+        for c in chunks:
+            self.assertIsNone(c["b"])
+
+    def test_tensor_split(self):
+        t = torch.arange(8).reshape(4, 2)
+        d = {"a": t, "n": 10}
+        chunks = self.split(d, 2)
+        self.assertEqual(chunks[0]["a"].shape, (2, 2))
+        self.assertEqual(chunks[1]["a"].shape, (2, 2))
+        torch.testing.assert_close(chunks[0]["a"], t[:2])
+        torch.testing.assert_close(chunks[1]["a"], t[2:])
+
+    def test_0d_tensor_preserved(self):
+        d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
+        chunks = self.split(d, 2)
+        for c in chunks:
+            self.assertAlmostEqual(c["scalar_t"].item(), 3.14, places=5)
+
+    def test_list_split(self):
+        d = {"a": torch.randn(4, 2), "names": ["a", "b", "c", "d"]}
+        chunks = self.split(d, 2)
+        self.assertEqual(chunks[0]["names"], ["a", "b"])
+        self.assertEqual(chunks[1]["names"], ["c", "d"])
+
+
+class TestShuffleSequenceDict(unittest.TestCase):
+    """Tests for patched shuffle_sequence_dict."""
+
+    def setUp(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import _patched_shuffle_sequence_dict
+
+        self.shuffle = _patched_shuffle_sequence_dict
+
+    def test_scalar_int_preserved(self):
+        d = {"a": torch.randn(4, 3), "count": 42}
+        result = self.shuffle(d)
+        self.assertEqual(result["count"], 42)
+
+    def test_scalar_float_preserved(self):
+        d = {"a": torch.randn(4, 3), "lr": 1e-5}
+        result = self.shuffle(d)
+        self.assertEqual(result["lr"], 1e-5)
+
+    def test_scalar_bool_preserved(self):
+        d = {"a": torch.randn(4, 3), "flag": False}
+        result = self.shuffle(d)
+        self.assertFalse(result["flag"])
+
+    def test_none_preserved(self):
+        d = {"a": torch.randn(4, 3), "b": None}
+        result = self.shuffle(d)
+        self.assertIsNone(result["b"])
+
+    def test_tensor_permuted(self):
+        torch.manual_seed(42)
+        t = torch.arange(4).float()
+        d = {"a": t}
+        result = self.shuffle(d)
+        # Same elements, possibly different order
+        self.assertEqual(sorted(result["a"].tolist()), sorted(t.tolist()))
+        self.assertEqual(result["a"].shape, t.shape)
+
+    def test_list_permuted(self):
+        torch.manual_seed(42)
+        d = {"a": torch.randn(3, 2), "names": ["x", "y", "z"]}
+        result = self.shuffle(d)
+        self.assertEqual(sorted(result["names"]), ["x", "y", "z"])
+        self.assertEqual(len(result["names"]), 3)
+
+    def test_0d_tensor_preserved(self):
+        d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
+        result = self.shuffle(d)
+        self.assertAlmostEqual(result["scalar_t"].item(), 3.14, places=5)
+
+
+class TestExtractLogprobs(unittest.TestCase):
+    """Tests for patched extract_logprobs (NaN → 0.0)."""
+
+    def setUp(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import _patched_extract_logprobs
+
+        self.extract = _patched_extract_logprobs
+
+    def _make_output(self, logprob_values):
+        """Create a mock vLLM RequestOutput with given logprob values."""
+
+        @dataclass
+        class LogprobItem:
+            logprob: float
+            rank: int
+
+        @dataclass
+        class SeqOutput:
+            logprobs: list[dict[int, LogprobItem]] | None
+
+        @dataclass
+        class RequestOutput:
+            outputs: list[SeqOutput]
+
+        logprobs_list = []
+        for vals in logprob_values:
+            lp_dict = {i: LogprobItem(logprob=v, rank=i) for i, v in enumerate(vals)}
+            logprobs_list.append(lp_dict)
+
+        return RequestOutput(outputs=[SeqOutput(logprobs=logprobs_list)])
+
+    def test_nan_replaced_with_zero(self):
+        output = self._make_output([[float("nan"), 0.5], [-0.3, float("nan")]])
+        logprobs, token_ids = self.extract([output])
+        self.assertEqual(logprobs[0][0][0], 0.0)  # NaN → 0.0
+        self.assertEqual(logprobs[0][0][1], 0.5)
+        self.assertEqual(logprobs[0][1][0], -0.3)
+        self.assertEqual(logprobs[0][1][1], 0.0)  # NaN → 0.0
+
+    def test_normal_values_preserved(self):
+        output = self._make_output([[-0.5, -1.2], [-0.1, -2.0]])
+        logprobs, token_ids = self.extract([output])
+        self.assertAlmostEqual(logprobs[0][0][0], -0.5)
+        self.assertAlmostEqual(logprobs[0][0][1], -1.2)
+
+    def test_none_logprobs_returns_none(self):
+        @dataclass
+        class SeqOutput:
+            logprobs: None = None
+
+        @dataclass
+        class RequestOutput:
+            outputs: list
+
+        output = RequestOutput(outputs=[SeqOutput()])
+        logprobs, token_ids = self.extract([output])
+        self.assertIsNone(logprobs)
+        self.assertIsNone(token_ids)
+
+    def test_token_ids_extracted(self):
+        output = self._make_output([[-0.5]])
+        logprobs, token_ids = self.extract([output])
+        self.assertEqual(token_ids[0][0], [0])  # token_id=0 from enumerate
+
+
+class TestPatchApplication(unittest.TestCase):
+    """Tests for patch_trl_vllm() application."""
+
+    def test_batch_update_added_to_client(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
+
+        patch_trl_vllm()
+        from trl.generation.vllm_client import VLLMClient
+
+        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
+
+    def test_extract_logprobs_patched(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import (
+            _patched_extract_logprobs,
+            patch_trl_vllm,
+        )
+
+        patch_trl_vllm()
+        from trl.generation import vllm_generation
+
+        self.assertIs(vllm_generation.extract_logprobs, _patched_extract_logprobs)
+
+    def test_utils_patched(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import (
+            _patched_shuffle_sequence_dict,
+            _patched_split_tensor_dict,
+            patch_trl_vllm,
+        )
+
+        patch_trl_vllm()
+        import trl.trainer.utils
+
+        self.assertIs(trl.trainer.utils.split_tensor_dict, _patched_split_tensor_dict)
+        self.assertIs(
+            trl.trainer.utils.shuffle_sequence_dict, _patched_shuffle_sequence_dict
+        )
+
+    def test_patch_idempotent(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
+
+        patch_trl_vllm()
+        patch_trl_vllm()  # second call should not error
+        from trl.generation.vllm_client import VLLMClient
+
+        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
+
+
+class TestBatchUpdateChunking(unittest.TestCase):
+    """Tests for batch_update_named_params chunking logic."""
+
+    def test_no_chunk_single_batch(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
+
+        # Test that with chunk_size=None, all params go in one chunk
+        client = MagicMock()
+        client.base_url = "http://localhost:8000"
+        client.session.post.return_value = MagicMock(status_code=200)
+        client.communicator = MagicMock()
+        client.communicator.group = MagicMock()
+        client.rank = 0
+
+        params = [
+            ("layer.0.weight", torch.randn(10, 10)),
+            ("layer.1.weight", torch.randn(10, 10)),
+        ]
+        _batch_update_named_params(client, params, chunk_size=None)
+
+        # Should make exactly 1 HTTP call
+        self.assertEqual(client.session.post.call_count, 1)
+
+    def test_chunk_splits_params(self):
+        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
+
+        client = MagicMock()
+        client.base_url = "http://localhost:8000"
+        client.session.post.return_value = MagicMock(status_code=200)
+        client.communicator = MagicMock()
+        client.communicator.group = MagicMock()
+        client.rank = 0
+
+        params = [
+            ("a", torch.randn(100)),  # 100 elements
+            ("b", torch.randn(100)),  # 100 elements
+            ("c", torch.randn(100)),  # 100 elements
+        ]
+        _batch_update_named_params(client, params, chunk_size=150)
+
+        # Greedy chunking: "a" opens a chunk (100 <= 150); adding "b" would
+        # exceed the budget (100 + 100 > 150), so [a] is flushed and "b" starts
+        # a new chunk; likewise "c". Expect 3 HTTP calls: [a], [b], [c].
+        self.assertEqual(client.session.post.call_count, 3)
+
+
+if __name__ == "__main__":
+    unittest.main()
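
For reference, the chunking behaviour asserted in `test_chunk_splits_params` corresponds to a greedy accumulator over tensor element counts. A minimal sketch of that logic (not the patched `_batch_update_named_params` itself; `chunk_params` is a hypothetical helper shown only to illustrate the expected call counts):

```python
import torch


def chunk_params(
    params: list[tuple[str, torch.Tensor]], chunk_size: int | None
) -> list[list[tuple[str, torch.Tensor]]]:
    """Greedily group (name, tensor) pairs so each chunk stays within chunk_size elements."""
    if chunk_size is None:
        return [params]  # no chunking: send everything in a single batch
    chunks: list[list[tuple[str, torch.Tensor]]] = []
    current: list[tuple[str, torch.Tensor]] = []
    current_numel = 0
    for name, tensor in params:
        numel = tensor.numel()
        # Flush the open chunk if adding this tensor would exceed the budget.
        if current and current_numel + numel > chunk_size:
            chunks.append(current)
            current, current_numel = [], 0
        current.append((name, tensor))
        current_numel += numel
    if current:
        chunks.append(current)
    return chunks


# Three 100-element tensors with chunk_size=150 yield three chunks, matching
# the three HTTP calls asserted above.
assert len(chunk_params([(n, torch.randn(100)) for n in "abc"], 150)) == 3
```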