Support for Async GRPO (#3486)

* async grpo support

* implement data producer

* use fast async

* handle call to create data producer

* fix liger kernel setup

* fix replay buffer

* chore: lint

* make gpus go brrr

* chore: lint

* inplace div_, unwrap model for logits in bf16

* fuse selective softmax and empty cuda cache on each scoring step

* remove waiting for synch time and fix race

* make fp8 work and allow lora kernels w rl

* grpo with lora vllm sync and fixes for sharded distributed

* update docs

* more patches so it works against trl main

* address PR feedback from coderabbit
Wing Lian
2026-03-17 11:42:47 -04:00
committed by GitHub
parent 999b3fec2e
commit 5ef3f28340
23 changed files with 5474 additions and 36 deletions

View File

@@ -721,6 +721,213 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
use_data_producer: true # Enable data producer protocol
use_vllm: true
async_prefetch: true # Generate rollouts in background thread
prefetch_depth: 1 # Number of rollouts to prefetch
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
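For intuition, async prefetch is a bounded producer/consumer queue: a background thread keeps up to `prefetch_depth` rollouts ready while the main thread trains. The following is a toy sketch of that pattern only (the function and callable names are illustrative, not the actual trainer code added in this PR):
```python
import queue
import threading

def prefetch_rollouts(generate_batch, train_on, num_steps, depth=1):
    """Toy producer/consumer loop illustrating async_prefetch with prefetch_depth."""
    buf: queue.Queue = queue.Queue(maxsize=depth)  # bounded by prefetch_depth

    def producer():
        for step in range(num_steps):
            buf.put(generate_batch(step))  # blocks once `depth` batches are waiting

    threading.Thread(target=producer, daemon=True).start()
    for _ in range(num_steps):
        batch = buf.get()   # next batch is usually already generated
        train_on(batch)     # training overlaps generation of the following batch

# Dummy callables just to show the call shape:
prefetch_rollouts(lambda s: {"step": s}, lambda b: None, num_steps=4, depth=1)
```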
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
##### vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enable native LoRA sync
```
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```
Then start training on a separate GPU:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
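Conceptually, LoRA sync writes the PEFT adapter to disk and points vLLM at it with a `LoRARequest` instead of broadcasting merged weights. A rough sketch of that idea using the public PEFT and vLLM APIs follows; the adapter path is hypothetical, and the real sync path lives in the `vllm_serve_lora` serve module and trainer added by this PR:
```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Assumes the trainer has already written the current adapter to this directory,
# e.g. via peft_model.save_pretrained(adapter_dir) at each sync interval.
adapter_dir = "/tmp/grpo_adapter/step_100"  # hypothetical shared path

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enable_lora=True, max_lora_rank=32)
outputs = llm.generate(
    ["What is 2 + 2?"],
    SamplingParams(max_tokens=32),
    # vLLM applies the adapter natively with Punica kernels; only the small
    # adapter weights ever leave the training process.
    lora_request=LoRARequest("grpo_adapter", 1, adapter_dir),
)
```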
##### Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
```yaml
trl:
streaming_partial_batch: true
```
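The difference from whole-batch scoring is easiest to see as a loop: each group of `num_generations` completions is scored and normalized as soon as it is available, so zero-signal groups can be skipped or replaced before the rest of the batch is processed. A simplified, self-contained sketch (the real implementation also handles chunked logprob computation and the trainer hooks):
```python
import torch

def score_streaming(rewards: torch.Tensor, num_generations: int):
    """Score one prompt group at a time instead of the whole batch at once.

    `rewards` has shape (num_prompts,), grouped so that each consecutive
    `num_generations` entries share the same prompt.
    """
    advantages = []
    for start in range(0, rewards.numel(), num_generations):
        group = rewards[start : start + num_generations]
        if group.std() == 0:  # zero-advantage group: nothing to learn from it
            advantages.append(torch.zeros_like(group))
            continue
        advantages.append((group - group.mean()) / (group.std() + 1e-4))
    return torch.cat(advantages)

print(score_streaming(torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]), num_generations=3))
```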
##### Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
vllm_importance_sampling_correction: true # Enable IS correction
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
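A minimal sketch of the two levels, assuming the usual ratio `exp(logp_train - logp_sampling)` per token; the sequence-level variant shown here averages the log-ratio over valid tokens before exponentiating, which is one common formulation (the function name, the exact aggregation, and the cap handling are illustrative, not the trainer's precise definition):
```python
import torch

def importance_ratios(train_logps, sampling_logps, mask, level="token", cap=2.0):
    """Per-token or per-sequence IS ratios between the current and sampling policies."""
    log_ratio = (train_logps - sampling_logps) * mask           # (batch, seq_len)
    if level == "token":
        ratio = log_ratio.exp()                                 # one ratio per token
    else:  # "sequence": aggregate the log-ratio over the completion first
        seq_log_ratio = log_ratio.sum(-1) / mask.sum(-1).clamp(min=1)
        ratio = seq_log_ratio.exp().unsqueeze(-1).expand_as(mask)
    return ratio.clamp(max=cap)                                 # cap to limit variance

logps_new = torch.tensor([[-0.9, -1.1, -2.0]])
logps_old = torch.tensor([[-1.0, -1.0, -2.5]])
mask = torch.ones_like(logps_new)
print(importance_ratios(logps_new, logps_old, mask, level="sequence"))
```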
##### Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
```yaml
trl:
replay_buffer_size: 100 # Max cached groups (0 = disabled)
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```
::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
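The buffer itself is the small `ReplayBuffer` class added in this PR: a score-ordered min-heap that evicts the lowest-scoring group when full and samples proportionally to score. Usage looks roughly like this (the score value and data dict here are placeholders):
```python
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(max_size=100)

# After scoring a rollout group: score = |advantages|.sum() * reward std,
# data = the tensors (prompt_ids, completion_ids, masks, ...) for that group.
buffer.add(score=1.7, data={"group": "high-signal rollouts"})
buffer.add(score=0.3, data={"group": "weaker but still useful"})

# Later, when a batch contains a zero-signal group, draw a replacement
# (sampling is weighted by score; returns None if the buffer is empty).
replacement = buffer.sample(1)
print(len(buffer), replacement)
```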
##### Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
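With the values above and `max_steps: 500`, re-rolling begins at step `int(500 * 0.5) = 250`, and at most one prompt group per batch is swapped for a buffered failure. A condensed, standalone version of the injection logic in the `RerollDataProducer` added by this PR (names simplified):
```python
def inject_rerolls(inputs, reroll_buffer, global_step, max_steps,
                   start_fraction=0.5, max_groups=1, num_generations=8):
    """Replace the last prompt groups of a batch with previously failed prompts."""
    start_step = max(1, int(max_steps * start_fraction)) if max_steps > 0 else float("inf")
    if global_step < start_step or not reroll_buffer:
        return inputs
    n_groups = len(inputs) // num_generations
    for i in range(min(max_groups, len(reroll_buffer), n_groups)):
        prompt = reroll_buffer.pop(0)
        start = (n_groups - 1 - i) * num_generations
        inputs[start : start + num_generations] = [prompt] * num_generations
    return inputs

batch = ["p1"] * 4 + ["p2"] * 4
print(inject_rerolls(batch, ["hard prompt"], global_step=300, max_steps=500,
                     num_generations=4))
```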
##### Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
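The check itself is trivial; the subtlety is that the returned zero loss must still carry a gradient graph so DeepSpeed and gradient accumulation bookkeeping stay consistent. A reduced sketch of the logic in `_compute_loss` (the real trainer runs a one-token forward pass to build that graph under ZeRO-3):
```python
import torch

def maybe_skip_zero_advantage(advantages: torch.Tensor, compute_loss):
    """Skip the forward/backward pass when no sample in the micro-batch has signal."""
    if torch.all(advantages == 0):
        # Logged as skipped_zero_adv_batches=1; return a zero loss that still has
        # requires_grad so the surrounding training loop can call backward() safely.
        return torch.zeros((), requires_grad=True)
    return compute_loss()

loss = maybe_skip_zero_advantage(torch.zeros(8), lambda: torch.tensor(1.23))
print(loss)  # tensor(0., requires_grad=True) -> effectively no update
```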
##### Parallel Reward Workers
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
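The workers are ordinary `multiprocessing` subprocesses connected by pipes; because each worker is its own process, `signal.alarm()` inside a reward function runs on that process's main thread. A stripped-down version of the worker loop added in this PR (the real one ships the reward functions, prompts, completions, and extra columns per shard):
```python
import multiprocessing as mp

def reward_worker(conn):
    """Long-lived worker: receives (reward_fn, completions), sends back scores."""
    while True:
        msg = conn.recv()
        if msg is None:          # shutdown signal
            break
        reward_fn, completions = msg
        # signal.alarm()-based timeouts (e.g. in math_verify) work here because
        # this is the subprocess's own main thread.
        conn.send([reward_fn(c) for c in completions])

if __name__ == "__main__":
    parent, child = mp.Pipe()
    proc = mp.Process(target=reward_worker, args=(child,), daemon=True)
    proc.start()
    parent.send((len, ["abc", "hello"]))   # trivial stand-in for a reward function
    print(parent.recv())                   # [3, 5]
    parent.send(None)
    proc.join()
```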
##### Full Async GRPO Example
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.35
dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
use_data_producer: true
use_vllm: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
vllm_lora_sync: true
streaming_partial_batch: true
vllm_importance_sampling_correction: true
off_policy_mask_threshold: 0.5
importance_sampling_level: token
num_generations: 8
max_completion_length: 512
reward_funcs:
- rewards.accuracy_reward
reroll_start_fraction: 0.5
replay_buffer_size: 100
reward_num_workers: 4
skip_zero_advantage_batches: true
datasets:
- path: AI-MO/NuminaMath-TIR
type: rewards.prompt_transform
split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
**FSDP:**
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
**DeepSpeed ZeRO-3:**
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 1,2
CUDA_VISIBLE_DEVICES=1,2 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```
::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
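Conceptually the pattern is: rank 0 owns the prefetch thread, and on every step all ranks meet on the main thread to receive rank 0's rollouts, so no collective ever runs off the main thread. A rough sketch with `torch.distributed` (assumes an initialized process group; the actual trainer uses the producer classes from this PR rather than this helper):
```python
import torch.distributed as dist

def get_rollouts_for_step(prefetch_queue, rank: int):
    """Rank 0 pulls the next prefetched rollouts; every rank receives them on the
    main thread via broadcast_object_list, avoiding background-thread collectives."""
    payload = [prefetch_queue.get()] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0)  # main-thread collective on all ranks
    return payload[0]
```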
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

View File

@@ -38,7 +38,18 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,7 +79,7 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
vllm_script_args = AxolotlScriptArguments(
base_kwargs = dict(
model=model,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
@@ -78,7 +89,21 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
# Use LoRAScriptArguments when serving with native LoRA support
if serve_module == "axolotl.scripts.vllm_serve_lora":
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(
**base_kwargs,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)

View File

@@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
training_args_cls = GRPOStrategy.get_training_args_class(
async_grpo=async_grpo
)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
@@ -217,13 +234,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
# transformers' validate_quantization_for_training blocks FP8 because
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
# (base weights stay frozen in FP8).
_orig_validate_quant = None
if (
self.cfg.adapter
and hasattr(self.model, "is_quantized")
and self.model.is_quantized
):
import transformers.trainer as _trainer_module
_orig_validate_quant = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
finally:
if _orig_validate_quant is not None:
import transformers.trainer as _trainer_module
_trainer_module.validate_quantization_for_training = (
_orig_validate_quant
)
if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:

View File

@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
@@ -27,14 +28,31 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
cls,
sequence_parallel: bool,
async_grpo: bool = False,
) -> (
type[AxolotlGRPOTrainer]
| type[AxolotlGRPOSequenceParallelTrainer]
| type[AxolotlAsyncGRPOTrainer]
):
if sequence_parallel and async_grpo:
raise ValueError(
"sequence_parallel and async_grpo cannot both be enabled. "
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
)
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
def get_training_args_class(
cls, async_grpo: bool = False
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
if async_grpo:
return AxolotlAsyncGRPOConfig
return AxolotlGRPOConfig
@classmethod
@@ -124,13 +142,63 @@ class GRPOStrategy:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
# Async GRPO fields
if getattr(trl, "use_data_producer", None) is not None:
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
if getattr(trl, "vllm_sync_interval", None) is not None:
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "streaming_partial_batch", None) is not None:
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
if getattr(trl, "streaming_min_groups", None) is not None:
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
trl.vllm_importance_sampling_correction
)
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
trl.vllm_importance_sampling_mode
)
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
trl.vllm_importance_sampling_cap
)
if getattr(trl, "off_policy_mask_threshold", None) is not None:
grpo_args_kwargs["off_policy_mask_threshold"] = (
trl.off_policy_mask_threshold
)
if getattr(trl, "use_bias_correction_kl", None) is not None:
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
# Fast Async GRPO fields
if getattr(trl, "reward_num_workers", None) is not None:
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
if getattr(trl, "replay_buffer_size", None) is not None:
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
if getattr(trl, "replay_recompute_logps", None) is not None:
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
if getattr(trl, "reroll_start_fraction", None) is not None:
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
if getattr(trl, "reroll_max_groups", None) is not None:
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
grpo_args_kwargs["skip_zero_advantage_batches"] = (
trl.skip_zero_advantage_batches
)
if getattr(trl, "vllm_lora_sync", None) is not None:
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return grpo_args_kwargs
@classmethod

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
context_parallel_size: int | None = None
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
context_parallel_size: int | None = None

File diff suppressed because it is too large

View File

@@ -0,0 +1,768 @@
# Copyright 2020-2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Experimental GRPO extensions: parallel reward workers, replay buffer,
deferred re-roll, and zero-advantage skipping.
These features are built as subclasses of GRPOTrainer and GRPODataProducer,
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
_pre_produce_hook) defined in the base classes.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from dataclasses import dataclass, field
import torch
from torch import nn
from trl import GRPOTrainer
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOConfig,
AsyncGRPOTrainer,
GRPODataProducer,
)
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extended config
# ---------------------------------------------------------------------------
@dataclass
class FastAsyncGRPOConfig(AsyncGRPOConfig):
"""GRPOConfig with additional experimental parameters."""
reward_num_workers: int = field(
default=1,
metadata={
"help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = field(
default=0,
metadata={
"help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = field(
default=True,
metadata={
"help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = field(
default=0.5,
metadata={
"help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = field(
default=1,
metadata={
"help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = field(
default=True,
metadata={
"help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
"Requires vllm_serve_lora serve module (auto-selected when this is True). "
"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
},
)
# ---------------------------------------------------------------------------
# Extended data producer with re-roll injection
# ---------------------------------------------------------------------------
class RerollDataProducer(GRPODataProducer):
"""GRPODataProducer that injects re-roll candidates into prompt batches.
Reads from the trainer's ``_reroll_buffer`` (populated by
``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
last N prompt groups with previously-failed prompts.
"""
def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
trainer = self._trainer
reroll_buf = getattr(trainer, "_reroll_buffer", None)
reroll_lock = getattr(trainer, "_reroll_lock", None)
if reroll_buf is None or reroll_lock is None:
return inputs
max_steps = getattr(trainer.args, "max_steps", -1)
start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
max_groups = getattr(trainer.args, "reroll_max_groups", 1)
reroll_start_step = (
max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
)
if global_step < reroll_start_step:
return inputs
with reroll_lock:
n_to_take = min(max_groups, len(reroll_buf))
reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]
if reroll_prompts:
num_gen = self._num_generations
n_groups = len(inputs) // num_gen
for i, reroll_prompt in enumerate(reroll_prompts):
group_idx = n_groups - 1 - i
if group_idx < 0:
break
start = group_idx * num_gen
for j in range(num_gen):
inputs[start + j] = reroll_prompt
logger.info(
f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
)
return inputs
# ---------------------------------------------------------------------------
# Persistent reward subprocess pool
# ---------------------------------------------------------------------------
def _persistent_reward_worker(conn):
"""Long-lived reward worker. Receives work items, returns results."""
while True:
try:
msg = conn.recv()
except EOFError:
break
if msg is None: # Shutdown signal
break
(
reward_funcs,
prompts,
completions,
completion_ids_list,
inputs,
reward_func_names,
) = msg
try:
keys = [
key
for key in inputs[0]
if key not in ["prompt", "completion", "completion_ids"]
]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
results = []
for reward_func, _reward_func_name in zip(
reward_funcs, reward_func_names, strict=True
):
output = reward_func(
prompts=prompts,
completions=completions,
completion_ids=completion_ids_list,
**reward_kwargs,
)
results.append(
[float(r) if r is not None else float("nan") for r in output]
)
conn.send(results)
except Exception:
conn.send(None)
# ---------------------------------------------------------------------------
# Extended trainer
# ---------------------------------------------------------------------------
class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
"""GRPOTrainer with experimental extensions.
Adds:
- Parallel reward subprocess workers (``reward_num_workers``)
- Replay buffer for high-signal group reuse (``replay_buffer_size``)
- Deferred re-roll of failed prompts (``reroll_start_fraction``)
- Zero-advantage micro-batch skipping
"""
def __init__(self, *args, **kwargs):
# These must be initialized before super().__init__() because
# _create_data_producer (called during super().__init__) needs them.
self._reroll_buffer: list = []
self._reroll_lock = threading.Lock()
# Temporarily suppress the base class's Liger + OPSM validation check,
# since this subclass supports it via a custom compute_liger_loss override.
grpo_args = kwargs.get("args")
if grpo_args is None:
for a in args:
if hasattr(a, "off_policy_mask_threshold"):
grpo_args = a
break
saved_threshold = None
if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
saved_threshold = grpo_args.off_policy_mask_threshold
grpo_args.off_policy_mask_threshold = None
super().__init__(*args, **kwargs)
if saved_threshold is not None:
grpo_args.off_policy_mask_threshold = saved_threshold
self.off_policy_mask_threshold = saved_threshold
# Replay buffer
if getattr(self.args, "replay_buffer_size", 0) > 0:
self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
else:
self._replay_buffer = None
self._replay_recompute_logps = getattr(
self.args, "replay_recompute_logps", True
)
# Reward worker pool (lazy-initialized)
self._reward_workers = None
# -- Factory override: use RerollDataProducer ----------------------------
def _create_data_producer(self, args, train_dataset):
"""Override to use RerollDataProducer for re-roll prompt injection."""
from axolotl.core.trainers.grpo.async_trainer import (
AsyncDataProducer,
ProducerConfig,
)
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
eval_during_produce=False,
empty_cache_before_produce=True,
empty_cache_after_produce=True,
async_prefetch=args.async_prefetch,
prefetch_depth=args.prefetch_depth,
)
data_producer = RerollDataProducer(
config=producer_config,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=args.generation_batch_size,
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=self.shuffle_dataset,
seed=args.seed,
)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
return data_producer
# -- Reward worker pool --------------------------------------------------
def _get_reward_workers(self):
"""Return a list of persistent reward worker subprocesses (lazy-initialized)."""
import multiprocessing as _mp
num_workers = getattr(self.args, "reward_num_workers", 1)
if num_workers < 1:
num_workers = 1
if self._reward_workers is not None:
alive = all(proc.is_alive() for conn, proc in self._reward_workers)
if alive and len(self._reward_workers) == num_workers:
return self._reward_workers
self._shutdown_reward_workers()
workers = []
for _ in range(num_workers):
parent_conn, child_conn = _mp.Pipe()
proc = _mp.Process(
target=_persistent_reward_worker, args=(child_conn,), daemon=True
)
proc.start()
child_conn.close()
workers.append((parent_conn, proc))
self._reward_workers = workers
return workers
def _shutdown_reward_workers(self):
"""Shut down all persistent reward workers."""
if self._reward_workers is None:
return
for conn, proc in self._reward_workers:
try:
conn.send(None)
proc.join(timeout=5)
except Exception:
pass
try:
conn.close()
except Exception:
pass
self._reward_workers = None
# -- Hook overrides ------------------------------------------------------
def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list
):
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all(
callable(rf)
and not isinstance(rf, nn.Module)
and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
)
num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1:
# Can't parallelize — store args for sync fallback in collect
self._reward_workers_used = None
self._pending_reward_args = (
inputs,
prompts,
completions,
completion_ids_list,
)
return
workers = self._get_reward_workers()
num_generations = self.num_generations
num_prompts = len(prompts)
num_groups = num_prompts // num_generations
# Shard by prompt groups across workers
groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
workers_used = []
for w_idx, (conn, _proc) in enumerate(workers):
g_start = w_idx * groups_per_worker
g_end = min((w_idx + 1) * groups_per_worker, num_groups)
if g_start >= num_groups:
break
s_start = g_start * num_generations
s_end = g_end * num_generations
conn.send(
(
self.reward_funcs,
prompts[s_start:s_end],
completions[s_start:s_end],
completion_ids_list[s_start:s_end],
inputs[s_start:s_end],
self.reward_func_names,
)
)
workers_used.append(conn)
self._reward_workers_used = workers_used
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
def _collect_reward_workers(
self, inputs, prompts, completions, completion_ids_list
):
"""Collect reward results from subprocess workers (blocks until done)."""
from accelerate.utils import gather
workers_used = getattr(self, "_reward_workers_used", None)
args = getattr(self, "_pending_reward_args", None)
self._reward_workers_used = None
self._pending_reward_args = None
if workers_used is None:
# Sync fallback — compute on main thread
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
device = self.accelerator.device
num_prompts = len(args[1]) if args else len(prompts)
# Collect results from workers
all_worker_results = []
any_failed = False
for conn in workers_used:
result = conn.recv()
if result is None:
any_failed = True
# Drain remaining workers to prevent stale results in pipes
for remaining_conn in workers_used:
if remaining_conn is not conn:
try:
remaining_conn.recv()
except Exception:
pass
break
all_worker_results.append(result)
if not any_failed:
rewards_per_func = torch.zeros(
num_prompts, len(self.reward_funcs), device=device
)
offset = 0
for worker_result in all_worker_results:
chunk_size = len(worker_result[0])
for i, result in enumerate(worker_result):
rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
result, dtype=torch.float32, device=device
)
offset += chunk_size
return gather(rewards_per_func)
# Fallback to main thread on failure
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
def _post_advantage_hook(
self,
data: dict,
rewards_per_func,
advantages,
inputs: list,
num_generations: int,
mode: str,
s_start: int | None = None,
s_end: int | None = None,
is_last_chunk: bool = True,
) -> None:
"""Replay buffer store/replace + re-roll buffering."""
from trl.models.utils import disable_gradient_checkpointing
# -- Replay buffer: store high-signal groups --
if self._replay_buffer is not None:
local_grouped = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = local_grouped.std(dim=1)
has_signal = (per_group_std > 0).any(dim=1)
offset = s_start or 0
if has_signal.any():
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
group_data = {}
for key in data:
val = data[key]
if (
isinstance(val, torch.Tensor)
and val.dim() > 0
and val.size(0) >= end
):
group_data[key] = val[start:end].clone()
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries
# Only in non-streaming path (s_start is None) — streaming scores
# groups incrementally, so replacement + logprob recompute would be
# too expensive per chunk.
n_replaced = 0
if s_start is None:
no_signal = ~has_signal
replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
sampled = self._replay_buffer.sample(1)
if sampled is None:
break
sampled_group = sampled[0]
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
for key, val in sampled_group.items():
if key in data and isinstance(data[key], torch.Tensor):
src = val.to(data[key].device)
tgt_seq_len = (
data[key].size(1) if data[key].dim() > 1 else None
)
if start >= data[key].size(0) or end > data[key].size(
0
):
continue
if tgt_seq_len is not None:
if src.size(1) <= tgt_seq_len:
data[key][start:end] = 0
data[key][start:end, : src.size(1)] = src
else:
data[key][start:end] = src[:, :tgt_seq_len]
else:
data[key][start:end] = src
replaced_ranges.append((start, end))
n_replaced += 1
# Recompute old_per_token_logps for replayed groups
if (
n_replaced > 0
and self._replay_recompute_logps
and "old_per_token_logps" in data
):
with (
torch.no_grad(),
disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs
),
):
for r_start, r_end in replaced_ranges:
r_ids = torch.cat(
[
data["prompt_ids"][r_start:r_end],
data["completion_ids"][r_start:r_end],
],
dim=1,
)
r_mask = torch.cat(
[
data["prompt_mask"][r_start:r_end],
data["completion_mask"][r_start:r_end],
],
dim=1,
)
r_logits_to_keep = data["completion_ids"].size(1)
r_fwd_kwargs = {}
for fk in (
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
if fk in data:
r_fwd_kwargs[fk] = data[fk]
r_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
r_ids,
r_mask,
r_logits_to_keep,
r_end - r_start,
**r_fwd_kwargs,
)
data["old_per_token_logps"][r_start:r_end] = r_logps
if n_replaced > 0:
self._metrics[mode]["replay_buffer_replacements"].append(
float(n_replaced)
)
if is_last_chunk:
self._metrics[mode]["replay_buffer_size"].append(
float(len(self._replay_buffer))
)
# -- Re-roll buffer: store failed prompts --
if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
grouped_rewards = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = grouped_rewards.std(dim=1)
per_group_mean = grouped_rewards.mean(dim=1)
zero_signal = (per_group_std == 0).all(dim=1)
all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
should_reroll = zero_signal & all_failed
_n_buffered = 0
with self._reroll_lock:
for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
idx = group_idx.item() * num_generations
if idx >= len(inputs):
continue
prompt_input = inputs[idx]
self._reroll_buffer.append(prompt_input)
_n_buffered += 1
if _n_buffered > 0:
self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
if is_last_chunk:
self._metrics[mode]["reroll_buffer_size"].append(
float(len(self._reroll_buffer))
)
# -- Zero-advantage skipping + Liger OPSM ---------------------------------
def compute_liger_loss(self, unwrapped_model, inputs):
"""Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).
The base class Liger path doesn't support OPSM because the fused kernel
doesn't expose per-token logprobs needed for the KL computation. This
override computes them via chunked lm_head matmul (no grad, low memory)
and applies the OPSM to the loss mask before calling the kernel.
"""
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
return torch.tensor(
0.0, device=inputs["advantages"].device, requires_grad=True
)
if self.off_policy_mask_threshold is None:
return super().compute_liger_loss(unwrapped_model, inputs)
# OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
last_hidden_state = self._get_last_hidden_state(
unwrapped_model,
input_ids,
attention_mask,
logits_to_keep,
inputs.get("pixel_values"),
inputs.get("image_grid_thw"),
inputs.get("pixel_attention_mask"),
inputs.get("image_sizes"),
)
loss_mask = (
completion_mask
if "tool_mask" not in inputs
else completion_mask * inputs["tool_mask"]
)
# Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
lm_weight = unwrapped_model.lm_head.weight
lm_bias = unwrapped_model.lm_head.bias
with torch.no_grad():
per_token_logps_chunks = []
for i in range(last_hidden_state.size(0)):
chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
if lm_bias is not None:
chunk_logits = chunk_logits + lm_bias
chunk_lps = (
chunk_logits.float()
.log_softmax(-1)
.gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
.squeeze(-1)
)
per_token_logps_chunks.append(chunk_lps)
del chunk_logits
per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
advantages = inputs["advantages"]
if advantages.dim() == 1:
advantages_2d = advantages.unsqueeze(1)
else:
advantages_2d = advantages
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = inputs.get("old_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = per_token_logps
off_policy_mask = GRPOTrainer.get_off_policy_mask(
advantages=advantages_2d,
per_token_logps=per_token_logps,
sampling_per_token_logps=sampling_per_token_logps,
mask=loss_mask,
off_policy_threshold=self.off_policy_mask_threshold,
)
loss_mask = loss_mask * off_policy_mask
# Call the Liger fused kernel with OPSM-modified mask
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=loss_mask,
advantages=inputs["advantages"],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs.get("old_per_token_logps"),
ref_per_token_logps=inputs.get("ref_per_token_logps"),
vllm_is_ratio=inputs.get("importance_sampling_ratio"),
)
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]
mode = "train" if self.model.training else "eval"
if self.beta != 0.0:
self._metrics[mode]["kl"].append(
self.accelerator.gather(mean_kl).mean().item()
)
self._metrics[mode]["clip_ratio"].append(
self.accelerator.gather(clip_ratio).mean().item()
)
normalizer = (
self.current_gradient_accumulation_steps if mode == "train" else 1.0
)
return loss / normalizer
def _compute_loss(self, model, inputs):
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
# Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
# With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
# so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
# with a single token to create a proper computation graph.
prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
attn = torch.ones_like(prompt_ids)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
out = model(input_ids=prompt_ids, attention_mask=attn)
return out.logits.sum() * 0
return super()._compute_loss(model, inputs)

View File

@@ -0,0 +1,44 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""
import heapq
import torch
class ReplayBuffer:
"""Min-heap replay buffer that keeps the highest-scoring rollout groups.
Groups are scored by signal quality (advantage magnitude * reward variance).
When sampling, groups are drawn proportional to their scores.
"""
def __init__(self, max_size: int):
self.max_size = max_size
self._heap: list[tuple[float, int, dict]] = [] # min-heap of (score, id, data)
self._counter = 0 # unique tiebreaker for heap
def __len__(self):
return len(self._heap)
def add(self, score: float, data: dict):
"""Add a group to the buffer. If full, replaces lowest-scoring entry."""
if self.max_size <= 0:
return
self._counter += 1
if len(self._heap) < self.max_size:
heapq.heappush(self._heap, (score, self._counter, data))
elif score > self._heap[0][0]:
heapq.heapreplace(self._heap, (score, self._counter, data))
def sample(self, num_samples: int) -> list[dict] | None:
"""Sample groups weighted by their scores. Returns None if buffer is empty."""
if self.max_size <= 0 or not self._heap:
return None
scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
scores = scores.clamp(min=1e-8) # avoid zero probabilities
probs = scores / scores.sum()
replacement = num_samples > len(self._heap)
indices = torch.multinomial(
probs, num_samples, replacement=replacement
).tolist()
return [self._heap[i][2] for i in indices]

View File

@@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
@@ -66,6 +67,19 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
class AxolotlAsyncGRPOTrainer(
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
FastAsyncGRPOTrainer,
):
"""Extend AsyncGRPOTrainer with axolotl helpers"""
_tag_names = ["trl", "grpo", "async", "axolotl"]
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -25,7 +25,7 @@ def get_lora_parameters(
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
QuantState | torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
@@ -48,9 +48,13 @@ def get_lora_parameters(
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
active_adapter = (
proj.active_adapters[0]
@@ -81,7 +85,7 @@ def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor | None,
W_quant: QuantState | None,
W_quant: QuantState | torch.Tensor | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,

View File

@@ -1,4 +1,4 @@
"""Dequantization utilities for `bitsandbytes` integration."""
"""Dequantization utilities for `bitsandbytes` and FP8 integration."""
import ctypes
@@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize_fp8(
W: torch.Tensor,
scale_inv: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
Args:
W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.
scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]
or per-tensor scalar.
dtype: Output dtype (default bf16).
Returns:
Dequantized tensor in the specified dtype.
"""
W_float = W.to(dtype)
if scale_inv.numel() == 1:
return W_float * scale_inv.to(dtype)
if scale_inv.dim() == 2 and W.dim() == 2:
sr, sc = scale_inv.shape
br = W.shape[0] // sr
bc = W.shape[1] // sc
# If dimensions are exactly divisible, use fast reshape path
if sr * br == W.shape[0] and sc * bc == W.shape[1]:
return (
W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)
).reshape(W.shape)
# Tail-block handling: compute actual block size (ceil division),
# tile scale_inv to cover full shape, then crop to W's dimensions
br_ceil = -(-W.shape[0] // sr) # ceil(rows / scale_rows) = block_size
bc_ceil = -(-W.shape[1] // sc)
scale_expanded = (
scale_inv.to(dtype)
.repeat_interleave(br_ceil, dim=0)
.repeat_interleave(bc_ceil, dim=1)
)[: W.shape[0], : W.shape[1]]
return W_float * scale_expanded
return W_float * scale_inv.to(dtype)
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
quant_state: QuantState | list | torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -49,6 +90,15 @@ def dequantize(
if quant_state is None:
return W
# FP8 path: quant_state is actually scale_inv tensor
if W.dtype == torch.float8_e4m3fn:
scale_inv = quant_state
# Caller may pass W.t() (non-contiguous) — dequantize in original
# layout then transpose back so the result shape matches the input.
if not W.is_contiguous() and W.dim() == 2:
return dequantize_fp8(W.t(), scale_inv).t()
return dequantize_fp8(W, scale_inv)
# Get the target device from input tensor W
target_device = W.device

View File

@@ -160,6 +160,18 @@ def load_lora(
else:
model = get_peft_model(model, lora_config, **model_kwargs)
# FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
# requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
if cfg.torch_dtype:
_fp8_cast_dtype = cfg.torch_dtype
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
_fp8_cast_dtype = torch.bfloat16
else:
_fp8_cast_dtype = torch.float16
for _name, param in model.named_parameters():
if param.requires_grad and param.dtype == torch.float8_e4m3fn:
param.data = param.data.to(_fp8_cast_dtype)
if rank == 0:
try:
model.print_trainable_parameters()

View File

@@ -215,6 +215,8 @@ class ModelLoader:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
if "allow_all_kernels" not in self.model_kwargs:
self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()

View File

@@ -116,6 +116,7 @@ class PatchManager:
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
self._apply_trl_vllm_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -667,6 +668,17 @@ class PatchManager:
patch_apertus_xielu_activation()
def _apply_trl_vllm_patches(self):
"""Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
if (
self.cfg.rl
and getattr(self.cfg, "trl", None)
and getattr(self.cfg.trl, "use_vllm", False)
):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:

View File

@@ -0,0 +1,245 @@
"""Monkeypatches for TRL's vLLM integration and trainer utils.
Adds:
- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips)
- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation)
- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO
- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough)
"""
import logging
import math
from functools import wraps
import torch
from torch import nn
LOG = logging.getLogger(__name__)
def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
from transformers import is_torch_xpu_available
if chunk_size is None:
chunks = [params]
else:
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
for chunk in chunks:
param_metadata = [
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
"""Updates all model params using batch_update_named_params."""
params = [(name, param.data) for name, param in model.named_parameters()]
self.batch_update_named_params(params, chunk_size=chunk_size)
def _patched_extract_logprobs(all_outputs):
"""extract_logprobs with NaN→0.0 fix (stock TRL uses None which causes downstream errors)."""
all_logprobs = []
all_token_ids = []
for outputs in all_outputs:
for output in outputs.outputs:
if output.logprobs is None:
return None, None
seq_logprobs = []
seq_token_ids = []
for lp in output.logprobs:
sorted_items = sorted(lp.items(), key=lambda x: x[1].rank)
seq_token_ids.append([token_id for token_id, _ in sorted_items])
seq_logprobs.append(
[
0.0 if math.isnan(item.logprob) else item.logprob
for _, item in sorted_items
]
)
all_logprobs.append(seq_logprobs)
all_token_ids.append(seq_token_ids)
return all_logprobs, all_token_ids
def _patched_split_tensor_dict(tensor_dict, num_chunks):
"""split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch."""
first_tensor = next(
tensor
for tensor in tensor_dict.values()
if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0
)
chunk_size = first_tensor.shape[0] // num_chunks
chunks = []
for i in range(num_chunks):
chunk_dict = {}
for key, tensor in tensor_dict.items():
if isinstance(tensor, (int, float, bool)):
chunk_dict[key] = tensor
elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):
chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
elif tensor is not None and tensor.ndim == 0:
chunk_dict[key] = tensor
else:
chunk_dict[key] = None
chunks.append(chunk_dict)
return chunks
def _patched_shuffle_sequence_dict(seq_dict):
"""shuffle_sequence_dict that handles scalar types (int/float/bool)."""
first_seq = next(
v
for v in seq_dict.values()
if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0
)
perm = torch.randperm(len(first_seq))
def permute(v):
if v is None:
return None
if isinstance(v, (int, float, bool)):
return v
if isinstance(v, torch.Tensor) and v.ndim == 0:
return v
if isinstance(v, torch.Tensor) and v.ndim >= 1:
return v[perm]
if isinstance(v, list):
return [v[i] for i in perm.tolist()]
return v
return {k: permute(v) for k, v in seq_dict.items()}
def _patch_sync_weights_batched(original_init):
"""Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size."""
@wraps(original_init)
def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs):
original_init(self, *args, **kwargs)
self.weight_sync_chunk_size = weight_sync_chunk_size
return patched_init
def _make_batched_sync_weights(original_sync_weights):
"""Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths."""
@wraps(original_sync_weights)
def patched_sync_weights(self):
from accelerate.utils import is_peft_model
# Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps
accelerator = self.accelerator
model = self.model
is_fsdp_enabled = self.is_fsdp_enabled
deepspeed_plugin = accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
is_peft = is_peft_model(model)
# If PEFT, FSDP, or ZeRO-3, fall back to original (which handles those cases)
if is_peft or is_fsdp_enabled or zero_stage_3:
return original_sync_weights(self)
# Non-PEFT, non-FSDP, non-ZeRO: use batched sync
if self.mode == "colocate" and getattr(self, "enable_sleep_mode", False):
from vllm.distributed.device_communicators.cuda_wrapper import (
empty_cache,
)
empty_cache()
self.llm.wake_up(tags=["weights"])
if self.mode == "server" and accelerator.is_main_process:
params = [
(self._fix_param_name_to_vllm(name), param.data)
for name, param in model.named_parameters()
]
self.vllm_client.batch_update_named_params(
params, chunk_size=getattr(self, "weight_sync_chunk_size", None)
)
elif self.mode == "colocate":
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
weights = [
(self._fix_param_name_to_vllm(name), param.data)
for name, param in model.named_parameters()
]
llm_model.load_weights(weights=weights)
# Reset cache
if self.mode == "server" and accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
elif self.mode == "colocate":
self.llm.reset_prefix_cache()
return patched_sync_weights
def patch_trl_vllm():
"""Apply all TRL vLLM monkeypatches."""
import trl.generation.vllm_client
import trl.generation.vllm_generation
import trl.trainer.utils
VLLMClient = trl.generation.vllm_client.VLLMClient
VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration
# 1. Add batch_update_named_params to VLLMClient
if not hasattr(VLLMClient, "batch_update_named_params"):
VLLMClient.batch_update_named_params = _batch_update_named_params
VLLMClient.update_model_params = _update_model_params
LOG.info("Patched VLLMClient with batch_update_named_params")
# 2. Patch extract_logprobs (NaN→0.0)
trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs
LOG.info("Patched extract_logprobs with NaN→0.0 fix")
# 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size
VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__)
# 4. Patch sync_weights for batched non-FSDP/non-ZeRO path
VLLMGeneration.sync_weights = _make_batched_sync_weights(
VLLMGeneration.sync_weights
)
LOG.info("Patched VLLMGeneration with batched sync_weights")
# 5. Patch split_tensor_dict and shuffle_sequence_dict
trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict
trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict
LOG.info("Patched split_tensor_dict and shuffle_sequence_dict for scalar types")

View File

View File

@@ -0,0 +1,503 @@
"""vLLM serve script with native LoRA adapter support.
Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,
instead of merging adapter weights into the base model before syncing.
Usage:
Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,
or ``trl.vllm_lora_sync: true`` to auto-select.
Benefits over merge-sync:
- Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL
- vLLM handles LoRA application natively (Punica kernels)
- No NCCL communicator needed for weight sync
"""
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Any
from trl.scripts.vllm_serve import (
ScriptArguments,
chunk_list,
extract_logprobs,
get_open_port,
)
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
logger = logging.getLogger(__name__)
@dataclass
class LoRAScriptArguments(ScriptArguments):
"""Extended script arguments with LoRA support."""
enable_lora: bool = field(
default=True,
metadata={"help": "Enable LoRA adapter support in vLLM."},
)
max_lora_rank: int = field(
default=64,
metadata={"help": "Maximum LoRA rank supported."},
)
max_loras: int = field(
default=2,
metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."},
)
lora_dtype: str = field(
default="bfloat16",
metadata={"help": "Data type for LoRA weights."},
)
def llm_worker(
script_args: LoRAScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
"""Worker process that creates a vLLM LLM with LoRA enabled."""
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
# Use batch-capable worker extension (adds batch_update_named_params + auto-close)
worker_extension_cls="axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension",
trust_remote_code=script_args.trust_remote_code,
model_impl=script_args.vllm_model_impl,
logprobs_mode="processed_logprobs",
# LoRA
enable_lora=script_args.enable_lora,
max_lora_rank=script_args.max_lora_rank,
max_loras=script_args.max_loras,
lora_dtype=script_args.lora_dtype,
)
connection.send({"status": "ready"})
while True:
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args = command.get("args", ())
kwargs = command.get("kwargs", {})
# Reconstruct LoRARequest from serialized dict (can't pickle across pipe)
if "lora_request" in kwargs and kwargs["lora_request"] is not None:
lr = kwargs["lora_request"]
kwargs["lora_request"] = LoRARequest(
lora_name=lr["lora_name"],
lora_int_id=lr["lora_int_id"],
lora_path=lr["lora_path"],
load_inplace=lr.get("load_inplace", False),
)
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
def main(script_args: ScriptArguments):
"""Start vLLM workers with LoRA support and the HTTP server."""
import asyncio
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field as PydanticField
# Request/Response models (defined locally like TRL's vllm_serve.main)
class GenerateRequest(BaseModel):
prompts: list[str]
images: list[str] | None = None
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = PydanticField(default_factory=dict)
class GenerateResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
class ChatRequest(BaseModel):
messages: list[list[dict]]
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = PydanticField(default_factory=dict)
chat_template_kwargs: dict = PydanticField(default_factory=dict)
class ChatResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
class InitCommunicatorRequest(BaseModel):
host: str
port: int
world_size: int
client_device_uuid: str
# Wrap plain ScriptArguments with LoRA defaults
if not isinstance(script_args, LoRAScriptArguments):
lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)
for f in ScriptArguments.__dataclass_fields__:
setattr(lora_args, f, getattr(script_args, f))
# Apply LoRA defaults
for f in LoRAScriptArguments.__dataclass_fields__:
if f not in ScriptArguments.__dataclass_fields__:
setattr(
lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default
)
script_args = lora_args
# Spawn workers
master_port = get_open_port()
connections: list[Connection] = []
processes: list[Process] = []
for dp_rank in range(script_args.data_parallel_size):
parent_conn, child_conn = Pipe()
process = Process(
target=llm_worker,
args=(script_args, dp_rank, master_port, child_conn),
)
process.start()
connections.append(parent_conn)
processes.append(process)
@asynccontextmanager
async def lifespan(app: FastAPI):
import time
startup_timeout = 300 # 5 minutes
start_time = time.monotonic()
ready: set[int] = set()
while len(ready) < script_args.data_parallel_size:
elapsed = time.monotonic() - start_time
if elapsed > startup_timeout:
raise RuntimeError(
f"vLLM workers failed to start within {startup_timeout}s "
f"({len(ready)}/{script_args.data_parallel_size} ready)"
)
for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):
if id(conn) in ready:
continue
if not proc.is_alive():
raise RuntimeError(
f"vLLM worker {i} exited unexpectedly during startup"
)
if conn.poll():
msg = conn.recv()
if isinstance(msg, dict) and msg.get("status") == "ready":
ready.add(id(conn))
await asyncio.sleep(0.1)
yield
for p in processes:
p.join(timeout=10)
if p.is_alive():
p.terminate()
p.join()
app = FastAPI(lifespan=lifespan)
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
# ------------------------------------------------------------------
# LoRA-specific endpoints
# ------------------------------------------------------------------
class SetLoRARequest(BaseModel):
lora_name: str
lora_int_id: int
lora_path: str
load_inplace: bool = False
@app.post("/set_lora_adapter/")
async def set_lora_adapter(request: SetLoRARequest):
"""Register a LoRA adapter for all subsequent generate/chat calls."""
active_lora["request"] = {
"lora_name": request.lora_name,
"lora_int_id": request.lora_int_id,
"lora_path": request.lora_path,
"load_inplace": request.load_inplace,
}
logger.info(
"Set active LoRA: %s (id=%d, path=%s)",
request.lora_name,
request.lora_int_id,
request.lora_path,
)
return {"status": "ok"}
@app.post("/clear_lora_adapter/")
async def clear_lora_adapter():
"""Clear active LoRA adapter (revert to base model)."""
active_lora["request"] = None
return {"status": "ok"}
# ------------------------------------------------------------------
# Standard endpoints (mirrors TRL's vllm_serve)
# ------------------------------------------------------------------
@app.get("/health/")
async def health():
return {"status": "ok"}
@app.get("/get_world_size/")
async def get_world_size():
return {
"world_size": script_args.tensor_parallel_size
* script_args.data_parallel_size
}
@app.post("/generate/", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""Generate completions with optional LoRA adapter."""
import base64
from io import BytesIO
import vllm
from packaging.version import Version
from vllm.sampling_params import GuidedDecodingParams
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
prompts: list[dict[str, Any]] = []
for prompt, image in zip(request.prompts, images, strict=True):
row: dict[str, Any] = {"prompt": prompt}
if image is not None:
from PIL import Image
row["multi_modal_data"] = {
"image": Image.open(BytesIO(base64.b64decode(image)))
}
prompts.append(row)
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
if Version(vllm.__version__) <= Version("0.10.2"):
key = "guided_decoding"
if request.structured_outputs_regex is not None:
generation_kwargs[key] = GuidedDecodingParams(
regex=request.structured_outputs_regex
)
else:
generation_kwargs.setdefault(key, None)
else:
from vllm.sampling_params import StructuredOutputsParams
key = "structured_outputs"
if request.structured_outputs_regex is not None:
generation_kwargs[key] = StructuredOutputsParams(
regex=request.structured_outputs_regex
)
elif isinstance(generation_kwargs.get(key), dict):
generation_kwargs[key] = StructuredOutputsParams(
**generation_kwargs[key]
)
else:
generation_kwargs.setdefault(key, None)
sampling_params = SamplingParams(**generation_kwargs)
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
for conn, chunk in zip(connections, chunked_prompts, strict=True):
if not chunk:
chunk = [{"prompt": "<placeholder>"}]
kwargs = {
"prompts": chunk,
"sampling_params": sampling_params,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
all_outputs = [
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
]
all_outputs = list(chain.from_iterable(all_outputs))
return {
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
"completion_ids": [
list(out.token_ids) for o in all_outputs for out in o.outputs
],
"logprobs": extract_logprobs(all_outputs)[0],
"logprob_token_ids": extract_logprobs(all_outputs)[1],
}
@app.post("/chat/", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""Chat endpoint with optional LoRA adapter."""
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
chunked = chunk_list(request.messages, script_args.data_parallel_size)
for conn, chunk in zip(connections, chunked, strict=True):
if not chunk:
chunk = [[{"role": "user", "content": "<placeholder>"}]]
kwargs = {
"messages": chunk,
"sampling_params": sampling_params,
"use_tqdm": False,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs))
return {
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
"completion_ids": [
list(out.token_ids) for o in all_outputs for out in o.outputs
],
"logprobs": extract_logprobs(all_outputs)[0],
"logprob_token_ids": extract_logprobs(all_outputs)[1],
}
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
@app.post("/init_communicator/")
async def init_communicator(request: InitCommunicatorRequest):
world_size = (
script_args.tensor_parallel_size * script_args.data_parallel_size + 1
)
kwargs = {
"method": "init_communicator",
"args": (
request.host,
request.port,
world_size,
request.client_device_uuid,
),
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": "Initializing communicator"}
class UpdateWeightsRequest(BaseModel):
name: str
dtype: str
shape: list[int]
@app.post("/update_named_param/")
async def update_named_param(request: UpdateWeightsRequest):
kwargs = {
"method": "update_named_param",
"args": (request.name, request.dtype, tuple(request.shape)),
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": "Updating parameter"}
class BatchUpdateWeightsRequest(BaseModel):
params: list[dict]
@app.post("/batch_update_named_params/")
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
params_list = [
(p["name"], p["dtype"], tuple(p["shape"])) for p in request.params
]
kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": f"Batch update for {len(params_list)} params"}
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"})
results = [conn.recv() for conn in connections]
return {"message": f"Reset prefix cache: {all(results)}"}
@app.post("/close_communicator/")
async def close_communicator():
kwargs = {"method": "close_communicator"}
for conn in connections:
conn.send(
{
"type": "fire_and_forget",
"method": "collective_rpc",
"kwargs": kwargs,
}
)
return {"message": "Closing communicator"}
uvicorn.run(
app,
host=script_args.host,
port=script_args.port,
log_level=script_args.log_level,
access_log=True,
)
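To make the adapter registration protocol concrete, here is a rough client-side sketch. The endpoint and field names come from `SetLoRARequest` above; the host, port, and adapter path are placeholder assumptions, and the training side presumably issues an equivalent request after saving the adapter to a shared filesystem when `vllm_lora_sync` is enabled:

```python
# Hypothetical client call; URL and paths are assumptions for illustration only.
import requests

server = "http://localhost:8000"  # assumed host/port of `axolotl vllm-serve`

requests.post(
    f"{server}/set_lora_adapter/",
    json={
        "lora_name": "grpo_adapter",
        "lora_int_id": 1,
        "lora_path": "/tmp/grpo_lora_sync/step_0100",  # adapter saved by the trainer
        "load_inplace": False,
    },
    timeout=60,
).raise_for_status()
# Subsequent /generate/ and /chat/ calls now use this adapter.

# Reverting to the base model:
requests.post(f"{server}/clear_lora_adapter/", timeout=60).raise_for_status()
```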

View File

@@ -0,0 +1,158 @@
"""Extended vLLM worker extension with batch weight sync support.
Subclasses TRL's WeightSyncWorkerExtension to add:
- batch_update_named_params: receives multiple params in one call
- Auto-close stale communicator on re-init
- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params,
including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy
"""
import logging
import torch
try:
from transformers import is_torch_xpu_available
except ImportError:
is_torch_xpu_available = lambda: False # noqa: E731
from trl.scripts.vllm_serve import WeightSyncWorkerExtension
logger = logging.getLogger(__name__)
# Stacked param name mapping: shard_name -> (packed_name, shard_order)
_STACKED_PARAMS = {
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
"""Worker extension that adds batch weight update and direct weight setting."""
def init_communicator(self, host, port, world_size, client_device_uuid):
"""Auto-close stale communicator before re-initializing."""
if self.communicator is not None:
self.close_communicator()
super().init_communicator(host, port, world_size, client_device_uuid)
def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None:
"""Directly copy weight data into the model, handling stacked params.
Bypasses model.load_weights() which may fail on vLLM 0.17's new
module-tree weight loader for stacked params (qkv_proj, gate_up_proj).
Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the
parameter hierarchy (e.g. ``qkv_proj.base_layer.weight``).
"""
model = self.model_runner.model
params_dict = dict(model.named_parameters())
# Check if this is a simple direct param (exists as-is)
if name in params_dict:
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
return
# Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight
parts_bl = name.rsplit(".", 1)
if len(parts_bl) == 2:
base_layer_name = f"{parts_bl[0]}.base_layer.{parts_bl[1]}"
if base_layer_name in params_dict:
params_dict[base_layer_name].data.copy_(
weight.to(params_dict[base_layer_name].dtype)
)
return
# Handle stacked params: e.g. "model.layers.0.self_attn.q_proj.weight"
# -> "model.layers.0.self_attn.qkv_proj.weight" with shard offset
parts = name.rsplit(".", 2) # [prefix, layer_name, suffix]
if len(parts) == 3:
prefix, layer_name, suffix = parts
if layer_name in _STACKED_PARAMS:
packed_name, shard_idx = _STACKED_PARAMS[layer_name]
for packed_full in [
f"{prefix}.{packed_name}.{suffix}",
f"{prefix}.{packed_name}.base_layer.{suffix}",
]:
if packed_full not in params_dict:
continue
param = params_dict[packed_full]
# Navigate to the packed module to find shard sizes
module_path = packed_full.rsplit(".", 1)[0] # strip .weight/.bias
if ".base_layer" in module_path:
module_path = module_path.replace(".base_layer", "")
module = model
for attr in module_path.split("."):
module = getattr(module, attr, None)
if module is None:
break
# LoRA wrappers don't have output_sizes directly;
# check base_layer for the underlying parallel linear
if module is not None and not hasattr(module, "output_sizes"):
base = getattr(module, "base_layer", None)
if base is not None and hasattr(base, "output_sizes"):
module = base
if module is not None and hasattr(module, "output_sizes"):
tp_size = getattr(module, "tp_size", 1)
sizes = [s // tp_size for s in module.output_sizes]
offset = sum(sizes[:shard_idx])
shard_size = sizes[shard_idx]
param.data[offset : offset + shard_size].copy_(
weight.to(param.dtype)
)
return
# Fallback: try load_weights (may work for non-stacked params)
logger.warning("Falling back to load_weights for param: %s", name)
model.load_weights(weights=[(name, weight)])
def update_named_param(self, name, dtype, shape):
"""Override to use _direct_set_weight instead of load_weights."""
if self.communicator is None:
raise RuntimeError("Communicator not initialized.")
dtype = getattr(torch, dtype.split(".")[-1])
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
self.communicator.broadcast(weight, root=self.client_rank)
self.communicator.barrier()
else:
self.communicator.broadcast(weight, src=self.client_rank)
self.communicator.group.barrier()
self._direct_set_weight(name, weight)
def batch_update_named_params(self, params_list: list[tuple[str, str, tuple]]):
"""Receive and apply multiple weight tensors in sequence.
Args:
params_list: List of (name, dtype_str, shape) tuples.
"""
if self.communicator is None:
raise RuntimeError("Communicator not initialized.")
weights_to_load = []
for name, dtype_str, shape in params_list:
dtype = getattr(torch, dtype_str.split(".")[-1])
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
self.communicator.broadcast(weight, root=self.client_rank)
else:
self.communicator.broadcast(weight, src=self.client_rank)
weights_to_load.append((name, weight))
# Single barrier after all broadcasts
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
# Load weights using direct set (handles stacked params)
for name, weight in weights_to_load:
self._direct_set_weight(name, weight)
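The trickiest part of `_direct_set_weight` is mapping HF-style shard names onto vLLM's packed (and possibly LoRA-wrapped) parameters. The following self-contained sketch shows only that name remapping; it deliberately omits the per-shard row offsets that the real method derives from the module's `output_sizes`, and the helper name is invented for illustration:

```python
# Simplified illustration of the name resolution performed by _direct_set_weight above.
_STACKED = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def remap_param_name(name: str, lora_wrapped: bool = False) -> str:
    """Map a HF-style shard name to the packed vLLM parameter name."""
    prefix, layer_name, suffix = name.rsplit(".", 2)
    packed, _shard_idx = _STACKED.get(layer_name, (layer_name, None))
    middle = f"{packed}.base_layer" if lora_wrapped else packed
    return f"{prefix}.{middle}.{suffix}"

assert (
    remap_param_name("model.layers.0.self_attn.q_proj.weight", lora_wrapped=True)
    == "model.layers.0.self_attn.qkv_proj.base_layer.weight"
)
assert (
    remap_param_name("model.layers.0.mlp.up_proj.weight")
    == "model.layers.0.mlp.gate_up_proj.weight"
)
```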

View File

@@ -189,3 +189,125 @@ class TRLConfig(BaseModel):
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)
# Async GRPO fields
use_data_producer: bool = Field(
default=False,
json_schema_extra={
"description": "Use the GRPODataProducer protocol for online data generation."
},
)
async_prefetch: bool = Field(
default=False,
json_schema_extra={
"description": "Generate rollouts in a background thread while training on the previous rollout."
},
)
prefetch_depth: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of rollouts to prefetch ahead of training."
},
)
vllm_sync_interval: int | None = Field(
default=None,
json_schema_extra={
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
},
)
streaming_partial_batch: bool | None = Field(
default=None,
json_schema_extra={
"description": "Score prompt groups incrementally instead of the full batch at once."
},
)
streaming_min_groups: int | None = Field(
default=None,
json_schema_extra={
"description": "Minimum prompt groups to score per streaming chunk."
},
)
vllm_importance_sampling_correction: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
},
)
vllm_importance_sampling_mode: (
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"]
| None
) = Field(
default=None,
json_schema_extra={
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
},
)
vllm_importance_sampling_cap: float | None = Field(
default=None,
json_schema_extra={"description": "Cap C for IS ratio clipping/masking."},
)
off_policy_mask_threshold: float | None = Field(
default=None,
json_schema_extra={
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
},
)
use_bias_correction_kl: bool | None = Field(
default=None,
json_schema_extra={"description": "Apply IS correction to KL divergence term."},
)
reward_num_workers: int = Field(
default=1,
json_schema_extra={
"description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = Field(
default=0,
json_schema_extra={
"description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = Field(
default=True,
json_schema_extra={
"description": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = Field(
default=1.0,
json_schema_extra={
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = Field(
default=1,
json_schema_extra={
"description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = Field(
default=True,
json_schema_extra={
"description": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = Field(
default=False,
json_schema_extra={
"description": "Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. "
"Auto-selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged model."
},
)
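A hedged sketch of how the new knobs above might be combined. The field names come from this class; the import path, the chosen values, and the assumption that the remaining `TRLConfig` fields are optional are illustrative, not taken from the source:

```python
# Illustrative only: constructing TRLConfig the way a YAML `trl:` block would
# populate it. Import path and values are assumptions.
from axolotl.utils.schemas.trl import TRLConfig  # assumed module path

trl_cfg = TRLConfig(
    use_data_producer=True,            # required by the replay/re-roll features below
    reward_num_workers=4,              # parallel reward computation in subprocess workers
    replay_buffer_size=64,             # experimental: cache high-variance rollout groups
    replay_recompute_logps=True,       # recompute old_per_token_logps for replayed groups
    reroll_start_fraction=0.5,         # start re-rolling zero-signal prompts halfway through training
    reroll_max_groups=1,               # replace at most one group per batch with a re-roll
    skip_zero_advantage_batches=True,  # skip micro-batches with no learning signal
    vllm_lora_sync=True,               # filesystem adapter sync instead of NCCL merge-sync
)
```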

View File

@@ -675,20 +675,6 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_rl(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("rl"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with RL at the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):

View File

@@ -57,3 +57,10 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Reasoning parser for VLLM"},
)
serve_module: str | None = Field(
default=None,
json_schema_extra={
"description": "Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' "
"for native LoRA support, or leave None for default TRL serve."
},
)

View File

@@ -0,0 +1,220 @@
"""Unit tests for async GRPO"""
import unittest
from unittest.mock import MagicMock
import torch
class TestReplayBuffer(unittest.TestCase):
"""Tests for ReplayBuffer edge cases."""
def test_add_noop_when_max_size_zero(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=0)
buf.add(1.0, {"data": "test"})
self.assertEqual(len(buf), 0)
def test_add_noop_when_max_size_negative(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=-1)
buf.add(1.0, {"data": "test"})
self.assertEqual(len(buf), 0)
def test_sample_returns_none_when_max_size_zero(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=0)
self.assertIsNone(buf.sample(1))
def test_sample_returns_none_when_empty(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=5)
self.assertIsNone(buf.sample(1))
def test_normal_add_and_sample(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=3)
buf.add(1.0, {"a": 1})
buf.add(2.0, {"a": 2})
buf.add(3.0, {"a": 3})
self.assertEqual(len(buf), 3)
result = buf.sample(1)
self.assertIsNotNone(result)
self.assertEqual(len(result), 1)
def test_replaces_lowest_when_full(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=2)
buf.add(1.0, {"a": 1})
buf.add(2.0, {"a": 2})
buf.add(3.0, {"a": 3}) # should replace score=1.0
self.assertEqual(len(buf), 2)
scores = sorted(item[0] for item in buf._heap)
self.assertEqual(scores, [2.0, 3.0])
class TestGRPOStrategyConflict(unittest.TestCase):
"""Tests for sequence_parallel + async_grpo conflict detection."""
def test_raises_on_both_enabled(self):
from axolotl.core.trainers.grpo import GRPOStrategy
with self.assertRaises(ValueError) as ctx:
GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)
self.assertIn("sequence_parallel", str(ctx.exception))
self.assertIn("async_grpo", str(ctx.exception))
def test_sequence_parallel_only(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOSequenceParallelTrainer,
)
cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False)
self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer)
def test_async_only(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True)
self.assertIs(cls, AxolotlAsyncGRPOTrainer)
def test_neither(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False)
self.assertIs(cls, AxolotlGRPOTrainer)
class TestDequantizeFP8TailBlocks(unittest.TestCase):
"""Tests for FP8 dequantization with non-divisible dimensions."""
def test_exact_divisible_shape(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (256, 128))
self.assertEqual(result.dtype, torch.bfloat16)
def test_non_divisible_rows(self):
from axolotl.kernels.quantize import dequantize_fp8
        # 131 rows with 2 scale blocks: the rows don't divide evenly across the
        # blocks, so the second block is a partial (tail) block and exercises
        # the non-divisible-row handling in dequantize_fp8.
W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (131, 128))
self.assertEqual(result.dtype, torch.bfloat16)
def test_non_divisible_cols(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (128, 200))
def test_scalar_scale(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (64, 64))
class TestLoraFP8Guard(unittest.TestCase):
"""Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights."""
def test_non_fp8_weight_skips_scale_inv(self):
"""Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
from axolotl.kernels.lora import get_lora_parameters
proj = MagicMock()
proj.disable_adapters = True
base_layer = MagicMock(spec=[]) # empty spec to control attrs precisely
# Use a real tensor for weight (bf16, no quant_state attr)
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
base_layer.bias = None
base_layer.weight_scale_inv = torch.ones(1) # should NOT be used for bf16
proj.base_layer = base_layer
W, b, quant_state, A, B, s = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
def test_fp8_weight_uses_scale_inv(self):
"""FP8 weight should pick up weight_scale_inv as quant_state."""
from axolotl.kernels.lora import get_lora_parameters
proj = MagicMock()
proj.disable_adapters = True
base_layer = MagicMock()
proj.base_layer = base_layer
# FP8 weight
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
torch.float8_e4m3fn
)
base_layer.bias = None
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)
class TestValidateQuantPatchRestore(unittest.TestCase):
"""Test that validate_quantization_for_training is restored after trainer creation."""
def test_patch_restored_on_success(self):
"""Monkeypatch should be restored even after successful trainer creation."""
import transformers.trainer as _trainer_module
original = _trainer_module.validate_quantization_for_training
# After the build() method runs, original should be restored.
# We can't easily test the full build(), but we can test the pattern.
_orig = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
pass # simulate trainer_cls() succeeding
finally:
_trainer_module.validate_quantization_for_training = _orig
self.assertIs(_trainer_module.validate_quantization_for_training, original)
def test_patch_restored_on_error(self):
"""Monkeypatch should be restored even if trainer creation raises."""
import transformers.trainer as _trainer_module
original = _trainer_module.validate_quantization_for_training
_orig = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
raise ValueError("test error")
except ValueError:
pass
finally:
_trainer_module.validate_quantization_for_training = _orig
self.assertIs(_trainer_module.validate_quantization_for_training, original)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,286 @@
"""Unit tests for TRL vLLM monkeypatches.
Tests:
- split_tensor_dict: scalar type preservation (int/float/bool)
- shuffle_sequence_dict: scalar type preservation
- extract_logprobs: NaN → 0.0 replacement
- VLLMClient.batch_update_named_params: method exists after patch
- VLLMGeneration: weight_sync_chunk_size attribute after patch
- Patch idempotency: applying patch twice doesn't break anything
"""
import unittest
from dataclasses import dataclass
from unittest.mock import MagicMock
import torch
class TestSplitTensorDict(unittest.TestCase):
"""Tests for patched split_tensor_dict."""
def setUp(self):
from axolotl.monkeypatch.trainer.trl_vllm import _patched_split_tensor_dict
self.split = _patched_split_tensor_dict
def test_scalar_int_preserved(self):
d = {"a": torch.randn(4, 3), "count": 42}
chunks = self.split(d, 2)
self.assertEqual(len(chunks), 2)
self.assertEqual(chunks[0]["count"], 42)
self.assertEqual(chunks[1]["count"], 42)
def test_scalar_float_preserved(self):
d = {"a": torch.randn(6, 2), "lr": 1e-5}
chunks = self.split(d, 3)
for c in chunks:
self.assertEqual(c["lr"], 1e-5)
def test_scalar_bool_preserved(self):
d = {"a": torch.randn(4, 2), "flag": True}
chunks = self.split(d, 2)
for c in chunks:
self.assertTrue(c["flag"])
def test_none_preserved(self):
d = {"a": torch.randn(4, 2), "b": None}
chunks = self.split(d, 2)
for c in chunks:
self.assertIsNone(c["b"])
def test_tensor_split(self):
t = torch.arange(8).reshape(4, 2)
d = {"a": t, "n": 10}
chunks = self.split(d, 2)
self.assertEqual(chunks[0]["a"].shape, (2, 2))
self.assertEqual(chunks[1]["a"].shape, (2, 2))
torch.testing.assert_close(chunks[0]["a"], t[:2])
torch.testing.assert_close(chunks[1]["a"], t[2:])
def test_0d_tensor_preserved(self):
d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
chunks = self.split(d, 2)
for c in chunks:
self.assertAlmostEqual(c["scalar_t"].item(), 3.14, places=5)
def test_list_split(self):
d = {"a": torch.randn(4, 2), "names": ["a", "b", "c", "d"]}
chunks = self.split(d, 2)
self.assertEqual(chunks[0]["names"], ["a", "b"])
self.assertEqual(chunks[1]["names"], ["c", "d"])
class TestShuffleSequenceDict(unittest.TestCase):
"""Tests for patched shuffle_sequence_dict."""
def setUp(self):
from axolotl.monkeypatch.trainer.trl_vllm import _patched_shuffle_sequence_dict
self.shuffle = _patched_shuffle_sequence_dict
def test_scalar_int_preserved(self):
d = {"a": torch.randn(4, 3), "count": 42}
result = self.shuffle(d)
self.assertEqual(result["count"], 42)
def test_scalar_float_preserved(self):
d = {"a": torch.randn(4, 3), "lr": 1e-5}
result = self.shuffle(d)
self.assertEqual(result["lr"], 1e-5)
def test_scalar_bool_preserved(self):
d = {"a": torch.randn(4, 3), "flag": False}
result = self.shuffle(d)
self.assertFalse(result["flag"])
def test_none_preserved(self):
d = {"a": torch.randn(4, 3), "b": None}
result = self.shuffle(d)
self.assertIsNone(result["b"])
def test_tensor_permuted(self):
torch.manual_seed(42)
t = torch.arange(4).float()
d = {"a": t}
result = self.shuffle(d)
# Same elements, possibly different order
self.assertEqual(sorted(result["a"].tolist()), sorted(t.tolist()))
self.assertEqual(result["a"].shape, t.shape)
def test_list_permuted(self):
torch.manual_seed(42)
d = {"a": torch.randn(3, 2), "names": ["x", "y", "z"]}
result = self.shuffle(d)
self.assertEqual(sorted(result["names"]), ["x", "y", "z"])
self.assertEqual(len(result["names"]), 3)
def test_0d_tensor_preserved(self):
d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
result = self.shuffle(d)
self.assertAlmostEqual(result["scalar_t"].item(), 3.14, places=5)
class TestExtractLogprobs(unittest.TestCase):
"""Tests for patched extract_logprobs (NaN → 0.0)."""
def setUp(self):
from axolotl.monkeypatch.trainer.trl_vllm import _patched_extract_logprobs
self.extract = _patched_extract_logprobs
def _make_output(self, logprob_values):
"""Create a mock vLLM RequestOutput with given logprob values."""
@dataclass
class LogprobItem:
logprob: float
rank: int
@dataclass
class SeqOutput:
logprobs: list[dict[int, LogprobItem]] | None
@dataclass
class RequestOutput:
outputs: list[SeqOutput]
logprobs_list = []
for vals in logprob_values:
lp_dict = {i: LogprobItem(logprob=v, rank=i) for i, v in enumerate(vals)}
logprobs_list.append(lp_dict)
return RequestOutput(outputs=[SeqOutput(logprobs=logprobs_list)])
def test_nan_replaced_with_zero(self):
output = self._make_output([[float("nan"), 0.5], [-0.3, float("nan")]])
logprobs, token_ids = self.extract([output])
self.assertEqual(logprobs[0][0][0], 0.0) # NaN → 0.0
self.assertEqual(logprobs[0][0][1], 0.5)
self.assertEqual(logprobs[0][1][0], -0.3)
self.assertEqual(logprobs[0][1][1], 0.0) # NaN → 0.0
def test_normal_values_preserved(self):
output = self._make_output([[-0.5, -1.2], [-0.1, -2.0]])
logprobs, token_ids = self.extract([output])
self.assertAlmostEqual(logprobs[0][0][0], -0.5)
self.assertAlmostEqual(logprobs[0][0][1], -1.2)
def test_none_logprobs_returns_none(self):
@dataclass
class SeqOutput:
logprobs: None = None
@dataclass
class RequestOutput:
outputs: list
output = RequestOutput(outputs=[SeqOutput()])
logprobs, token_ids = self.extract([output])
self.assertIsNone(logprobs)
self.assertIsNone(token_ids)
def test_token_ids_extracted(self):
output = self._make_output([[-0.5]])
logprobs, token_ids = self.extract([output])
self.assertEqual(token_ids[0][0], [0]) # token_id=0 from enumerate
class TestPatchApplication(unittest.TestCase):
"""Tests for patch_trl_vllm() application."""
def test_batch_update_added_to_client(self):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
from trl.generation.vllm_client import VLLMClient
self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
def test_extract_logprobs_patched(self):
from axolotl.monkeypatch.trainer.trl_vllm import (
_patched_extract_logprobs,
patch_trl_vllm,
)
patch_trl_vllm()
from trl.generation import vllm_generation
self.assertIs(vllm_generation.extract_logprobs, _patched_extract_logprobs)
def test_utils_patched(self):
from axolotl.monkeypatch.trainer.trl_vllm import (
_patched_shuffle_sequence_dict,
_patched_split_tensor_dict,
patch_trl_vllm,
)
patch_trl_vllm()
import trl.trainer.utils
self.assertIs(trl.trainer.utils.split_tensor_dict, _patched_split_tensor_dict)
self.assertIs(
trl.trainer.utils.shuffle_sequence_dict, _patched_shuffle_sequence_dict
)
def test_patch_idempotent(self):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
patch_trl_vllm() # second call should not error
from trl.generation.vllm_client import VLLMClient
self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
class TestBatchUpdateChunking(unittest.TestCase):
"""Tests for batch_update_named_params chunking logic."""
def test_no_chunk_single_batch(self):
from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
# Test that with chunk_size=None, all params go in one chunk
client = MagicMock()
client.base_url = "http://localhost:8000"
client.session.post.return_value = MagicMock(status_code=200)
client.communicator = MagicMock()
client.communicator.group = MagicMock()
client.rank = 0
params = [
("layer.0.weight", torch.randn(10, 10)),
("layer.1.weight", torch.randn(10, 10)),
]
_batch_update_named_params(client, params, chunk_size=None)
# Should make exactly 1 HTTP call
self.assertEqual(client.session.post.call_count, 1)
def test_chunk_splits_params(self):
from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
client = MagicMock()
client.base_url = "http://localhost:8000"
client.session.post.return_value = MagicMock(status_code=200)
client.communicator = MagicMock()
client.communicator.group = MagicMock()
client.rank = 0
params = [
("a", torch.randn(100)), # 100 elements
("b", torch.randn(100)), # 100 elements
("c", torch.randn(100)), # 100 elements
]
_batch_update_named_params(client, params, chunk_size=150)
        # Greedy chunking by element count with chunk_size=150: "a" (100 elements)
        # fits, adding "b" would exceed 150, so the first chunk is [a]; "b" then
        # starts a new chunk, adding "c" would again exceed 150, so the second
        # chunk is [b]; "c" goes in the final chunk. Three HTTP calls total.
self.assertEqual(client.session.post.call_count, 3)
if __name__ == "__main__":
unittest.main()