Support for Async GRPO (#3486)
* async grpo support
* implement data producer
* use fast async
* handle call to create data producer
* fix liger kernel setup
* fix replay buffer
* chore: lint
* make gpus go brrr
* chore: lint
* inplace div_, unwrap model for logits in bf16
* fuse selective softmax and empty cuda cache on each scoring step
* remove waiting for synch time and fix race
* make fp8 work and allow lora kernels w rl
* grpo with lora vllm sync and fixes for sharded distributed
* update docs
* more patches so it works against trl main
* address PR feedback for corerabbit
docs/rlhf.qmd (+207 lines)
@@ -721,6 +721,213 @@ trl:

For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

```yaml
trl:
  use_data_producer: true # Enable data producer protocol
  use_vllm: true
  async_prefetch: true # Generate rollouts in background thread
  prefetch_depth: 1 # Number of rollouts to prefetch
  vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```

::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::

##### vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```

Then start training on a separate GPU:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::

##### Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```

##### Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

```yaml
trl:
  vllm_importance_sampling_correction: true # Enable IS correction
  importance_sampling_level: token # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with the Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences whose IS ratio indicates they are too far off-policy (see the sketch below)
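
The exact formulation lives in the trainer; the following is a minimal, hypothetical sketch of the token-level idea (all names are illustrative, not the real API), assuming per-token log-probs under the current policy and under the stale vLLM weights that generated the rollout:

```python
import torch

def is_correction(train_logps, vllm_logps, mask, threshold=0.5):
    """Illustrative per-token IS ratios with off-policy sequence masking."""
    # Per-token importance sampling ratio pi_train / pi_vllm
    token_ratio = torch.exp(train_logps - vllm_logps)
    # Mean per-sequence log-ratio, used to detect far-off-policy sequences
    seq_log_ratio = ((train_logps - vllm_logps) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    seq_ratio = torch.exp(seq_log_ratio)
    # Drop whole sequences whose ratio falls below the threshold
    keep = (seq_ratio >= threshold).float().unsqueeze(-1)
    return token_ratio * mask * keep
```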

##### Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

```yaml
trl:
  replay_buffer_size: 100 # Max cached groups (0 = disabled)
  replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```

::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::

##### Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches, when the model may be better equipped to solve them.

```yaml
trl:
  reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
  reroll_max_groups: 1 # Max groups to replace per batch
```

##### Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.

```yaml
trl:
  skip_zero_advantage_batches: true # default
```

##### Parallel Reward Workers

Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

```yaml
trl:
  reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
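
The main-thread restriction is standard Python behavior: registering a `SIGALRM` handler off the main thread raises `ValueError`, while each subprocess gets its own main thread where handlers work. A quick illustration (not axolotl code):

```python
import signal
import threading

def handler(signum, frame):
    raise TimeoutError

def set_timeout():
    # Raises "ValueError: signal only works in main thread of the main interpreter"
    signal.signal(signal.SIGALRM, handler)

threading.Thread(target=set_timeout).start()
```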

##### Full Async GRPO Example

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.35
  dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

##### Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

**FSDP:**

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

**DeepSpeed ZeRO-3:**

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true # Required for ZeRO-3
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
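
In pseudo-code, the pattern looks roughly like this (a sketch assuming an `accelerator` object and a hypothetical `produce_rollouts()` helper; the real implementation is in `async_trainer.py`, whose diff is suppressed below):

```python
from accelerate.utils import broadcast_object_list

# Rank 0 produces the rollout batch; other ranks contribute a placeholder
rollouts = produce_rollouts() if accelerator.is_main_process else None
# The collective broadcast happens on the main thread of every rank
rollouts = broadcast_object_list([rollouts], from_process=0)[0]
```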

### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them, as sketched below.
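
A minimal sketch of the idea (not the exact GDPO implementation):

```python
import torch

def gdpo_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    """rewards: (num_prompts * num_generations, num_reward_funcs)."""
    grouped = rewards.view(-1, num_generations, rewards.size(-1))
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True).clamp(min=1e-6)
    # Normalize each reward function within its prompt group *independently*,
    # so a large-scale reward cannot drown out the others...
    normalized = (grouped - mean) / std
    # ...then combine the decoupled per-reward advantages
    return normalized.sum(dim=-1).view(-1)
```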

@@ -38,7 +38,18 @@ def do_vllm_serve(
     cfg = load_cfg(config)
     model = cfg.base_model

-    serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
+    # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
+    serve_module = cli_args.get("serve_module") or getattr(
+        cfg.vllm, "serve_module", None
+    )
+    if (
+        serve_module is None
+        and getattr(cfg, "trl", None)
+        and getattr(cfg.trl, "vllm_lora_sync", False)
+    ):
+        serve_module = "axolotl.scripts.vllm_serve_lora"
+    if serve_module is None:
+        serve_module = "trl.scripts.vllm_serve"
     vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
     tensor_parallel_size = 1
     data_parallel_size = 1
@@ -68,7 +79,7 @@ def do_vllm_serve(
         cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
     )

-    vllm_script_args = AxolotlScriptArguments(
+    base_kwargs = dict(
         model=model,
         tensor_parallel_size=tensor_parallel_size,
         data_parallel_size=data_parallel_size,
@@ -78,7 +89,21 @@ def do_vllm_serve(
         dtype=dtype,
         max_model_len=max_model_len,
         enable_prefix_caching=enable_prefix_caching,
-        reasoning_parser=reasoning_parser,
-        enable_reasoning=enable_reasoning,
     )

+    # Use LoRAScriptArguments when serving with native LoRA support
+    if serve_module == "axolotl.scripts.vllm_serve_lora":
+        from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
+
+        lora_kwargs = {}
+        if hasattr(cfg, "lora_r") and cfg.lora_r:
+            lora_kwargs["max_lora_rank"] = cfg.lora_r
+        vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
+    else:
+        vllm_script_args = AxolotlScriptArguments(
+            **base_kwargs,
+            reasoning_parser=reasoning_parser,
+            enable_reasoning=enable_reasoning,
+        )
+
     vllm_serve_main(vllm_script_args)
@@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
             from axolotl.core.trainers.grpo import GRPOStrategy

+            async_grpo = bool(
+                self.cfg.trl
+                and (
+                    getattr(self.cfg.trl, "async_prefetch", False)
+                    or getattr(self.cfg.trl, "use_data_producer", False)
+                )
+            )
             trainer_cls = GRPOStrategy.get_trainer_class(
-                sequence_parallel=self.cfg.context_parallel_size > 1
+                sequence_parallel=self.cfg.context_parallel_size > 1,
+                async_grpo=async_grpo,
             )
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
             trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
             from axolotl.core.trainers.grpo import GRPOStrategy

-            training_args_cls = GRPOStrategy.get_training_args_class()
+            async_grpo = bool(
+                self.cfg.trl
+                and (
+                    getattr(self.cfg.trl, "async_prefetch", False)
+                    or getattr(self.cfg.trl, "use_data_producer", False)
+                )
+            )
+            training_args_cls = GRPOStrategy.get_training_args_class(
+                async_grpo=async_grpo
+            )
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
             if self.cfg.rl is RLType.GDPO:
@@ -217,13 +234,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_kwargs, trainer_cls
         )

-        trainer = trainer_cls(
-            *trainer_cls_args,
-            args=training_args,
-            train_dataset=self.train_dataset,
-            callbacks=self.get_callbacks(),
-            **trainer_kwargs,
-        )
+        # Allow FP8-quantized models to be fine-tuned with LoRA adapters.
+        # transformers' validate_quantization_for_training blocks FP8 because
+        # hf_quantizer.is_trainable is False, but LoRA only trains the adapters
+        # (base weights stay frozen in FP8).
+        _orig_validate_quant = None
+        if (
+            self.cfg.adapter
+            and hasattr(self.model, "is_quantized")
+            and self.model.is_quantized
+        ):
+            import transformers.trainer as _trainer_module
+
+            _orig_validate_quant = _trainer_module.validate_quantization_for_training
+            _trainer_module.validate_quantization_for_training = lambda model: None
+
+        try:
+            trainer = trainer_cls(
+                *trainer_cls_args,
+                args=training_args,
+                train_dataset=self.train_dataset,
+                callbacks=self.get_callbacks(),
+                **trainer_kwargs,
+            )
+        finally:
+            if _orig_validate_quant is not None:
+                import transformers.trainer as _trainer_module
+
+                _trainer_module.validate_quantization_for_training = (
+                    _orig_validate_quant
+                )
         if self.cfg.fsdp_config or self.cfg.fsdp:
             ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
         if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
 from requests import HTTPError
 from trl.trainer.grpo_trainer import RewardFunc

-from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
+from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
 from axolotl.core.trainers.grpo.trainer import (
+    AxolotlAsyncGRPOTrainer,
     AxolotlGRPOSequenceParallelTrainer,
     AxolotlGRPOTrainer,
 )
@@ -27,14 +28,31 @@ class GRPOStrategy:

     @classmethod
     def get_trainer_class(
-        cls, sequence_parallel: bool
-    ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
+        cls,
+        sequence_parallel: bool,
+        async_grpo: bool = False,
+    ) -> (
+        type[AxolotlGRPOTrainer]
+        | type[AxolotlGRPOSequenceParallelTrainer]
+        | type[AxolotlAsyncGRPOTrainer]
+    ):
+        if sequence_parallel and async_grpo:
+            raise ValueError(
+                "sequence_parallel and async_grpo cannot both be enabled. "
+                "Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
+            )
         if sequence_parallel:
             return AxolotlGRPOSequenceParallelTrainer
+        if async_grpo:
+            return AxolotlAsyncGRPOTrainer
         return AxolotlGRPOTrainer

     @classmethod
-    def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
+    def get_training_args_class(
+        cls, async_grpo: bool = False
+    ) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
+        if async_grpo:
+            return AxolotlAsyncGRPOConfig
         return AxolotlGRPOConfig

     @classmethod
@@ -124,13 +142,63 @@ class GRPOStrategy:
             grpo_args_kwargs["epsilon_high"] = trl.epsilon_high

         if trl.use_liger_loss is not None:
-            grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
+            grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss

         if trl.multi_objective_aggregation is not None:
             grpo_args_kwargs["multi_objective_aggregation"] = (
                 trl.multi_objective_aggregation
             )

+        # Async GRPO fields
+        if getattr(trl, "use_data_producer", None) is not None:
+            grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
+        if getattr(trl, "async_prefetch", None) is not None:
+            grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
+        if getattr(trl, "prefetch_depth", None) is not None:
+            grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
+        if getattr(trl, "vllm_sync_interval", None) is not None:
+            grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
+        if getattr(trl, "streaming_partial_batch", None) is not None:
+            grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
+        if getattr(trl, "streaming_min_groups", None) is not None:
+            grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
+        if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
+            grpo_args_kwargs["vllm_importance_sampling_correction"] = (
+                trl.vllm_importance_sampling_correction
+            )
+        if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
+            grpo_args_kwargs["vllm_importance_sampling_mode"] = (
+                trl.vllm_importance_sampling_mode
+            )
+        if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
+            grpo_args_kwargs["vllm_importance_sampling_cap"] = (
+                trl.vllm_importance_sampling_cap
+            )
+        if getattr(trl, "off_policy_mask_threshold", None) is not None:
+            grpo_args_kwargs["off_policy_mask_threshold"] = (
+                trl.off_policy_mask_threshold
+            )
+        if getattr(trl, "use_bias_correction_kl", None) is not None:
+            grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
+
+        # Fast Async GRPO fields
+        if getattr(trl, "reward_num_workers", None) is not None:
+            grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
+        if getattr(trl, "replay_buffer_size", None) is not None:
+            grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
+        if getattr(trl, "replay_recompute_logps", None) is not None:
+            grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
+        if getattr(trl, "reroll_start_fraction", None) is not None:
+            grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
+        if getattr(trl, "reroll_max_groups", None) is not None:
+            grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
+        if getattr(trl, "skip_zero_advantage_batches", None) is not None:
+            grpo_args_kwargs["skip_zero_advantage_batches"] = (
+                trl.skip_zero_advantage_batches
+            )
+        if getattr(trl, "vllm_lora_sync", None) is not None:
+            grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
+
         return grpo_args_kwargs

     @classmethod
@@ -6,6 +6,7 @@ from dataclasses import dataclass

 from trl import GRPOConfig

+from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
 from axolotl.core.training_args import AxolotlTrainingMixins


@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """Axolotl GRPO Config for GRPO training"""

     context_parallel_size: int | None = None
+
+
+@dataclass
+class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
+    """Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
+
+    context_parallel_size: int | None = None

src/axolotl/core/trainers/grpo/async_trainer.py (new file, 2657 lines)
File diff suppressed because it is too large

src/axolotl/core/trainers/grpo/fast_async_trainer.py (new file, 768 lines)
@@ -0,0 +1,768 @@
# Copyright 2020-2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Experimental GRPO extensions: parallel reward workers, replay buffer,
deferred re-roll, and zero-advantage skipping.

These features are built as subclasses of GRPOTrainer and GRPODataProducer,
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
_pre_produce_hook) defined in the base classes.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from dataclasses import dataclass, field

import torch
from torch import nn
from trl import GRPOTrainer

from axolotl.core.trainers.grpo.async_trainer import (
    AsyncGRPOConfig,
    AsyncGRPOTrainer,
    GRPODataProducer,
)
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Extended config
# ---------------------------------------------------------------------------

@dataclass
class FastAsyncGRPOConfig(AsyncGRPOConfig):
    """GRPOConfig with additional experimental parameters."""

    reward_num_workers: int = field(
        default=1,
        metadata={
            "help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
            "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
            "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
        },
    )
    replay_buffer_size: int = field(
        default=0,
        metadata={
            "help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
            "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
            "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
        },
    )
    replay_recompute_logps: bool = field(
        default=True,
        metadata={
            "help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
            "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
            "Only relevant when replay_buffer_size > 0."
        },
    )
    reroll_start_fraction: float = field(
        default=0.5,
        metadata={
            "help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
            "(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
            "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
        },
    )
    reroll_max_groups: int = field(
        default=1,
        metadata={
            "help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
            "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
        },
    )
    skip_zero_advantage_batches: bool = field(
        default=True,
        metadata={
            "help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
            "signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
            "logged with skipped_zero_adv_batches=1 for monitoring."
        },
    )
    vllm_lora_sync: bool = field(
        default=False,
        metadata={
            "help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
            "and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
            "Requires vllm_serve_lora serve module (auto-selected when this is True). "
            "Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
        },
    )


# ---------------------------------------------------------------------------
# Extended data producer with re-roll injection
# ---------------------------------------------------------------------------

class RerollDataProducer(GRPODataProducer):
    """GRPODataProducer that injects re-roll candidates into prompt batches.

    Reads from the trainer's ``_reroll_buffer`` (populated by
    ``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
    last N prompt groups with previously-failed prompts.
    """

    def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
        trainer = self._trainer
        reroll_buf = getattr(trainer, "_reroll_buffer", None)
        reroll_lock = getattr(trainer, "_reroll_lock", None)
        if reroll_buf is None or reroll_lock is None:
            return inputs

        max_steps = getattr(trainer.args, "max_steps", -1)
        start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
        max_groups = getattr(trainer.args, "reroll_max_groups", 1)
        reroll_start_step = (
            max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
        )

        if global_step < reroll_start_step:
            return inputs

        with reroll_lock:
            n_to_take = min(max_groups, len(reroll_buf))
            reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]

        if reroll_prompts:
            num_gen = self._num_generations
            n_groups = len(inputs) // num_gen
            for i, reroll_prompt in enumerate(reroll_prompts):
                group_idx = n_groups - 1 - i
                if group_idx < 0:
                    break
                start = group_idx * num_gen
                for j in range(num_gen):
                    inputs[start + j] = reroll_prompt
            logger.info(
                f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
                f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
            )

        return inputs


# ---------------------------------------------------------------------------
# Persistent reward subprocess pool
# ---------------------------------------------------------------------------

def _persistent_reward_worker(conn):
    """Long-lived reward worker. Receives work items, returns results."""
    while True:
        try:
            msg = conn.recv()
        except EOFError:
            break
        if msg is None:  # Shutdown signal
            break
        (
            reward_funcs,
            prompts,
            completions,
            completion_ids_list,
            inputs,
            reward_func_names,
        ) = msg
        try:
            keys = [
                key
                for key in inputs[0]
                if key not in ["prompt", "completion", "completion_ids"]
            ]
            reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
            results = []
            for reward_func, _reward_func_name in zip(
                reward_funcs, reward_func_names, strict=True
            ):
                output = reward_func(
                    prompts=prompts,
                    completions=completions,
                    completion_ids=completion_ids_list,
                    **reward_kwargs,
                )
                results.append(
                    [float(r) if r is not None else float("nan") for r in output]
                )
            conn.send(results)
        except Exception:
            conn.send(None)


# ---------------------------------------------------------------------------
# Extended trainer
# ---------------------------------------------------------------------------

class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
    """GRPOTrainer with experimental extensions.

    Adds:
    - Parallel reward subprocess workers (``reward_num_workers``)
    - Replay buffer for high-signal group reuse (``replay_buffer_size``)
    - Deferred re-roll of failed prompts (``reroll_start_fraction``)
    - Zero-advantage micro-batch skipping
    """

    def __init__(self, *args, **kwargs):
        # These must be initialized before super().__init__() because
        # _create_data_producer (called during super().__init__) needs them.
        self._reroll_buffer: list = []
        self._reroll_lock = threading.Lock()

        # Temporarily suppress the base class's Liger + OPSM validation check,
        # since this subclass supports it via a custom compute_liger_loss override.
        grpo_args = kwargs.get("args")
        if grpo_args is None:
            for a in args:
                if hasattr(a, "off_policy_mask_threshold"):
                    grpo_args = a
                    break
        saved_threshold = None
        if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
            saved_threshold = grpo_args.off_policy_mask_threshold
            grpo_args.off_policy_mask_threshold = None

        super().__init__(*args, **kwargs)

        if saved_threshold is not None:
            grpo_args.off_policy_mask_threshold = saved_threshold
            self.off_policy_mask_threshold = saved_threshold

        # Replay buffer
        if getattr(self.args, "replay_buffer_size", 0) > 0:
            self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
        else:
            self._replay_buffer = None
        self._replay_recompute_logps = getattr(
            self.args, "replay_recompute_logps", True
        )

        # Reward worker pool (lazy-initialized)
        self._reward_workers = None

    # -- Factory override: use RerollDataProducer ----------------------------

    def _create_data_producer(self, args, train_dataset):
        """Override to use RerollDataProducer for re-roll prompt injection."""
        from axolotl.core.trainers.grpo.async_trainer import (
            AsyncDataProducer,
            ProducerConfig,
        )

        producer_config = ProducerConfig(
            mini_epochs=args.num_iterations,
            max_rollouts=None,
            eval_during_produce=False,
            empty_cache_before_produce=True,
            empty_cache_after_produce=True,
            async_prefetch=args.async_prefetch,
            prefetch_depth=args.prefetch_depth,
        )
        data_producer = RerollDataProducer(
            config=producer_config,
            prompt_dataset=train_dataset,
            num_generations=self.num_generations,
            generation_batch_size=args.generation_batch_size,
            train_batch_size=args.per_device_train_batch_size,
            steps_per_generation=args.steps_per_generation,
            shuffle_dataset=self.shuffle_dataset,
            seed=args.seed,
        )
        data_producer.set_trainer(self)
        if args.async_prefetch:
            data_producer = AsyncDataProducer(
                data_producer,
                background_produce_kwargs={"skip_policy_logps": True},
            )
        return data_producer

    # -- Reward worker pool --------------------------------------------------

    def _get_reward_workers(self):
        """Return a list of persistent reward worker subprocesses (lazy-initialized)."""
        import multiprocessing as _mp

        num_workers = getattr(self.args, "reward_num_workers", 1)
        if num_workers < 1:
            num_workers = 1

        if self._reward_workers is not None:
            alive = all(proc.is_alive() for conn, proc in self._reward_workers)
            if alive and len(self._reward_workers) == num_workers:
                return self._reward_workers
            self._shutdown_reward_workers()

        workers = []
        for _ in range(num_workers):
            parent_conn, child_conn = _mp.Pipe()
            proc = _mp.Process(
                target=_persistent_reward_worker, args=(child_conn,), daemon=True
            )
            proc.start()
            child_conn.close()
            workers.append((parent_conn, proc))

        self._reward_workers = workers
        return workers

    def _shutdown_reward_workers(self):
        """Shut down all persistent reward workers."""
        if self._reward_workers is None:
            return
        for conn, proc in self._reward_workers:
            try:
                conn.send(None)
                proc.join(timeout=5)
            except Exception:
                pass
            try:
                conn.close()
            except Exception:
                pass
        self._reward_workers = None

    # -- Hook overrides ------------------------------------------------------

    def _compute_rewards_for_batch(
        self, inputs, prompts, completions, completion_ids_list
    ):
        """Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
        return self._collect_reward_workers(
            inputs, prompts, completions, completion_ids_list
        )

    def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
        """Send reward work to subprocess workers (non-blocking).

        Results are collected later by _collect_reward_workers, allowing GPU
        logprob computation to overlap with CPU reward computation.
        """
        reward_can_bg = all(
            callable(rf)
            and not isinstance(rf, nn.Module)
            and not asyncio.iscoroutinefunction(rf)
            for rf in self.reward_funcs
        )
        num_workers = getattr(self.args, "reward_num_workers", 1)

        if not reward_can_bg or num_workers <= 1:
            # Can't parallelize — store args for sync fallback in collect
            self._reward_workers_used = None
            self._pending_reward_args = (
                inputs,
                prompts,
                completions,
                completion_ids_list,
            )
            return

        workers = self._get_reward_workers()
        num_generations = self.num_generations
        num_prompts = len(prompts)
        num_groups = num_prompts // num_generations

        # Shard by prompt groups across workers
        groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
        workers_used = []
        for w_idx, (conn, _proc) in enumerate(workers):
            g_start = w_idx * groups_per_worker
            g_end = min((w_idx + 1) * groups_per_worker, num_groups)
            if g_start >= num_groups:
                break
            s_start = g_start * num_generations
            s_end = g_end * num_generations
            conn.send(
                (
                    self.reward_funcs,
                    prompts[s_start:s_end],
                    completions[s_start:s_end],
                    completion_ids_list[s_start:s_end],
                    inputs[s_start:s_end],
                    self.reward_func_names,
                )
            )
            workers_used.append(conn)

        self._reward_workers_used = workers_used
        self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)

    def _collect_reward_workers(
        self, inputs, prompts, completions, completion_ids_list
    ):
        """Collect reward results from subprocess workers (blocks until done)."""
        from accelerate.utils import gather

        workers_used = getattr(self, "_reward_workers_used", None)
        args = getattr(self, "_pending_reward_args", None)
        self._reward_workers_used = None
        self._pending_reward_args = None

        if workers_used is None:
            # Sync fallback — compute on main thread
            if args is not None:
                return self._calculate_rewards(*args)
            return self._calculate_rewards(
                inputs, prompts, completions, completion_ids_list
            )

        device = self.accelerator.device
        num_prompts = len(args[1]) if args else len(prompts)

        # Collect results from workers
        all_worker_results = []
        any_failed = False
        for conn in workers_used:
            result = conn.recv()
            if result is None:
                any_failed = True
                # Drain remaining workers to prevent stale results in pipes
                for remaining_conn in workers_used:
                    if remaining_conn is not conn:
                        try:
                            remaining_conn.recv()
                        except Exception:
                            pass
                break
            all_worker_results.append(result)

        if not any_failed:
            rewards_per_func = torch.zeros(
                num_prompts, len(self.reward_funcs), device=device
            )
            offset = 0
            for worker_result in all_worker_results:
                chunk_size = len(worker_result[0])
                for i, result in enumerate(worker_result):
                    rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
                        result, dtype=torch.float32, device=device
                    )
                offset += chunk_size
            return gather(rewards_per_func)

        # Fallback to main thread on failure
        if args is not None:
            return self._calculate_rewards(*args)
        return self._calculate_rewards(
            inputs, prompts, completions, completion_ids_list
        )

    def _post_advantage_hook(
        self,
        data: dict,
        rewards_per_func,
        advantages,
        inputs: list,
        num_generations: int,
        mode: str,
        s_start: int | None = None,
        s_end: int | None = None,
        is_last_chunk: bool = True,
    ) -> None:
        """Replay buffer store/replace + re-roll buffering."""
        from trl.models.utils import disable_gradient_checkpointing

        # -- Replay buffer: store high-signal groups --
        if self._replay_buffer is not None:
            local_grouped = rewards_per_func.view(
                -1, num_generations, len(self.reward_funcs)
            )
            per_group_std = local_grouped.std(dim=1)
            has_signal = (per_group_std > 0).any(dim=1)
            offset = s_start or 0

            if has_signal.any():
                grouped_adv = advantages.view(-1, num_generations)
                replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
                for group_idx in has_signal.nonzero(as_tuple=True)[0]:
                    gi = group_idx.item()
                    start = offset + gi * num_generations
                    end = start + num_generations
                    group_data = {}
                    for key in data:
                        val = data[key]
                        if (
                            isinstance(val, torch.Tensor)
                            and val.dim() > 0
                            and val.size(0) >= end
                        ):
                            group_data[key] = val[start:end].clone()
                    self._replay_buffer.add(replay_scores[gi].item(), group_data)

            # Replace zero-signal groups with high-signal replay buffer entries
            # Only in non-streaming path (s_start is None) — streaming scores
            # groups incrementally, so replacement + logprob recompute would be
            # too expensive per chunk.
            n_replaced = 0
            if s_start is None:
                no_signal = ~has_signal
                replaced_ranges = []
                if no_signal.any() and len(self._replay_buffer) > 0:
                    for group_idx in no_signal.nonzero(as_tuple=True)[0]:
                        sampled = self._replay_buffer.sample(1)
                        if sampled is None:
                            break
                        sampled_group = sampled[0]
                        gi = group_idx.item()
                        start = offset + gi * num_generations
                        end = start + num_generations
                        for key, val in sampled_group.items():
                            if key in data and isinstance(data[key], torch.Tensor):
                                src = val.to(data[key].device)
                                tgt_seq_len = (
                                    data[key].size(1) if data[key].dim() > 1 else None
                                )
                                if start >= data[key].size(0) or end > data[key].size(
                                    0
                                ):
                                    continue
                                if tgt_seq_len is not None:
                                    if src.size(1) <= tgt_seq_len:
                                        data[key][start:end] = 0
                                        data[key][start:end, : src.size(1)] = src
                                    else:
                                        data[key][start:end] = src[:, :tgt_seq_len]
                                else:
                                    data[key][start:end] = src
                        replaced_ranges.append((start, end))
                        n_replaced += 1

                # Recompute old_per_token_logps for replayed groups
                if (
                    n_replaced > 0
                    and self._replay_recompute_logps
                    and "old_per_token_logps" in data
                ):
                    with (
                        torch.no_grad(),
                        disable_gradient_checkpointing(
                            self.model, self.args.gradient_checkpointing_kwargs
                        ),
                    ):
                        for r_start, r_end in replaced_ranges:
                            r_ids = torch.cat(
                                [
                                    data["prompt_ids"][r_start:r_end],
                                    data["completion_ids"][r_start:r_end],
                                ],
                                dim=1,
                            )
                            r_mask = torch.cat(
                                [
                                    data["prompt_mask"][r_start:r_end],
                                    data["completion_mask"][r_start:r_end],
                                ],
                                dim=1,
                            )
                            r_logits_to_keep = data["completion_ids"].size(1)
                            r_fwd_kwargs = {}
                            for fk in (
                                "pixel_values",
                                "image_grid_thw",
                                "pixel_attention_mask",
                                "image_sizes",
                                "token_type_ids",
                                "mm_token_type_ids",
                            ):
                                if fk in data:
                                    r_fwd_kwargs[fk] = data[fk]
                            r_logps, _ = self._get_per_token_logps_and_entropies(
                                self.model,
                                r_ids,
                                r_mask,
                                r_logits_to_keep,
                                r_end - r_start,
                                **r_fwd_kwargs,
                            )
                            data["old_per_token_logps"][r_start:r_end] = r_logps

                if n_replaced > 0:
                    self._metrics[mode]["replay_buffer_replacements"].append(
                        float(n_replaced)
                    )

            if is_last_chunk:
                self._metrics[mode]["replay_buffer_size"].append(
                    float(len(self._replay_buffer))
                )

        # -- Re-roll buffer: store failed prompts --
        if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
            grouped_rewards = rewards_per_func.view(
                -1, num_generations, len(self.reward_funcs)
            )
            per_group_std = grouped_rewards.std(dim=1)
            per_group_mean = grouped_rewards.mean(dim=1)
            zero_signal = (per_group_std == 0).all(dim=1)
            all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
            should_reroll = zero_signal & all_failed
            _n_buffered = 0
            with self._reroll_lock:
                for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
                    idx = group_idx.item() * num_generations
                    if idx >= len(inputs):
                        continue
                    prompt_input = inputs[idx]
                    self._reroll_buffer.append(prompt_input)
                    _n_buffered += 1
            if _n_buffered > 0:
                self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
            if is_last_chunk:
                self._metrics[mode]["reroll_buffer_size"].append(
                    float(len(self._reroll_buffer))
                )

    # -- Zero-advantage skipping + Liger OPSM ---------------------------------

    def compute_liger_loss(self, unwrapped_model, inputs):
        """Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).

        The base class Liger path doesn't support OPSM because the fused kernel
        doesn't expose per-token logprobs needed for the KL computation. This
        override computes them via chunked lm_head matmul (no grad, low memory)
        and applies the OPSM to the loss mask before calling the kernel.
        """
        if self.args.skip_zero_advantage_batches and torch.all(
            inputs["advantages"] == 0
        ):
            mode = "train" if self.model.training else "eval"
            self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
            return torch.tensor(
                0.0, device=inputs["advantages"].device, requires_grad=True
            )

        if self.off_policy_mask_threshold is None:
            return super().compute_liger_loss(unwrapped_model, inputs)

        # OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = (
            inputs["completion_ids"],
            inputs["completion_mask"],
        )
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)

        last_hidden_state = self._get_last_hidden_state(
            unwrapped_model,
            input_ids,
            attention_mask,
            logits_to_keep,
            inputs.get("pixel_values"),
            inputs.get("image_grid_thw"),
            inputs.get("pixel_attention_mask"),
            inputs.get("image_sizes"),
        )

        loss_mask = (
            completion_mask
            if "tool_mask" not in inputs
            else completion_mask * inputs["tool_mask"]
        )

        # Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
        lm_weight = unwrapped_model.lm_head.weight
        lm_bias = unwrapped_model.lm_head.bias
        with torch.no_grad():
            per_token_logps_chunks = []
            for i in range(last_hidden_state.size(0)):
                chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
                if lm_bias is not None:
                    chunk_logits = chunk_logits + lm_bias
                chunk_lps = (
                    chunk_logits.float()
                    .log_softmax(-1)
                    .gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
                    .squeeze(-1)
                )
                per_token_logps_chunks.append(chunk_lps)
                del chunk_logits
            per_token_logps = torch.cat(per_token_logps_chunks, dim=0)

        advantages = inputs["advantages"]
        if advantages.dim() == 1:
            advantages_2d = advantages.unsqueeze(1)
        else:
            advantages_2d = advantages

        sampling_per_token_logps = inputs.get("sampling_per_token_logps")
        if sampling_per_token_logps is None:
            sampling_per_token_logps = inputs.get("old_per_token_logps")
        if sampling_per_token_logps is None:
            sampling_per_token_logps = per_token_logps

        off_policy_mask = GRPOTrainer.get_off_policy_mask(
            advantages=advantages_2d,
            per_token_logps=per_token_logps,
            sampling_per_token_logps=sampling_per_token_logps,
            mask=loss_mask,
            off_policy_threshold=self.off_policy_mask_threshold,
        )
        loss_mask = loss_mask * off_policy_mask

        # Call the Liger fused kernel with OPSM-modified mask
        loss, metrics = self.liger_grpo_loss(
            _input=last_hidden_state,
            lin_weight=unwrapped_model.lm_head.weight,
            selected_token_ids=completion_ids,
            attention_mask=loss_mask,
            advantages=inputs["advantages"],
            bias=unwrapped_model.lm_head.bias,
            old_per_token_logps=inputs.get("old_per_token_logps"),
            ref_per_token_logps=inputs.get("ref_per_token_logps"),
            vllm_is_ratio=inputs.get("importance_sampling_ratio"),
        )

        mean_kl = metrics[0] if self.beta != 0.0 else None
        clip_ratio = metrics[-1]

        mode = "train" if self.model.training else "eval"
        if self.beta != 0.0:
            self._metrics[mode]["kl"].append(
                self.accelerator.gather(mean_kl).mean().item()
            )
        self._metrics[mode]["clip_ratio"].append(
            self.accelerator.gather(clip_ratio).mean().item()
        )
        normalizer = (
            self.current_gradient_accumulation_steps if mode == "train" else 1.0
        )
        return loss / normalizer

    def _compute_loss(self, model, inputs):
        if self.args.skip_zero_advantage_batches and torch.all(
            inputs["advantages"] == 0
        ):
            mode = "train" if self.model.training else "eval"
            self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
            # Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
            # With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
            # so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
            # with a single token to create a proper computation graph.
            prompt_ids = inputs["prompt_ids"][:1, :1]  # (1, 1)
            attn = torch.ones_like(prompt_ids)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                out = model(input_ids=prompt_ids, attention_mask=attn)
            return out.logits.sum() * 0
        return super()._compute_loss(model, inputs)

src/axolotl/core/trainers/grpo/replay_buffer.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""

import heapq

import torch


class ReplayBuffer:
    """Min-heap replay buffer that keeps the highest-scoring rollout groups.
    Groups are scored by signal quality (advantage magnitude * reward variance).
    When sampling, groups are drawn proportional to their scores.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap: list[tuple[float, int, dict]] = []  # min-heap of (score, id, data)
        self._counter = 0  # unique tiebreaker for heap

    def __len__(self):
        return len(self._heap)

    def add(self, score: float, data: dict):
        """Add a group to the buffer. If full, replaces lowest-scoring entry."""
        if self.max_size <= 0:
            return
        self._counter += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, (score, self._counter, data))
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, (score, self._counter, data))

    def sample(self, num_samples: int) -> list[dict] | None:
        """Sample groups weighted by their scores. Returns None if buffer is empty."""
        if self.max_size <= 0 or not self._heap:
            return None

        scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
        scores = scores.clamp(min=1e-8)  # avoid zero probabilities
        probs = scores / scores.sum()
        replacement = num_samples > len(self._heap)
        indices = torch.multinomial(
            probs, num_samples, replacement=replacement
        ).tolist()
        return [self._heap[i][2] for i in indices]
@@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
 from trl.trainer.utils import pad

+from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
 from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
 from axolotl.core.trainers.mixins import (
     DistributedParallelMixin,
@@ -66,6 +67,19 @@ class AxolotlGRPOTrainer(
     _tag_names = ["trl", "grpo", "axolotl"]


+class AxolotlAsyncGRPOTrainer(
+    RngLoaderMixin,
+    SchedulerMixin,
+    OptimizerMixin,
+    OptimizerInitMixin,
+    DistributedParallelMixin,
+    FastAsyncGRPOTrainer,
+):
+    """Extend AsyncGRPOTrainer with axolotl helpers"""
+
+    _tag_names = ["trl", "grpo", "async", "axolotl"]
+
+
 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""

@@ -25,7 +25,7 @@ def get_lora_parameters(
 ) -> tuple[
     torch.Tensor,
     torch.Tensor | None,
-    QuantState | None,
+    QuantState | torch.Tensor | None,
     torch.Tensor | None,
     torch.Tensor | None,
     float | None,
@@ -48,9 +48,13 @@ def get_lora_parameters(

     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
         quant_state = getattr(W, "quant_state", None)
+        if quant_state is None and W.dtype == torch.float8_e4m3fn:
+            quant_state = getattr(base_layer, "weight_scale_inv", None)
         return W, b, quant_state, None, None, None

     quant_state = getattr(W, "quant_state", None)
+    if quant_state is None and W.dtype == torch.float8_e4m3fn:
+        quant_state = getattr(base_layer, "weight_scale_inv", None)

     active_adapter = (
         proj.active_adapters[0]
@@ -81,7 +85,7 @@ def matmul_lora(
     X: torch.Tensor,
     W: torch.Tensor,
     b: torch.Tensor | None,
-    W_quant: QuantState | None,
+    W_quant: QuantState | torch.Tensor | None,
     A: torch.Tensor | None,
     B: torch.Tensor | None,
     s: float | None,
@@ -1,4 +1,4 @@
-"""Dequantization utilities for `bitsandbytes` integration."""
+"""Dequantization utilities for `bitsandbytes` and FP8 integration."""

 import ctypes

@@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None
 HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")


+def dequantize_fp8(
+    W: torch.Tensor,
+    scale_inv: torch.Tensor,
+    dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
+
+    Args:
+        W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.
+        scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]
+            or per-tensor scalar.
+        dtype: Output dtype (default bf16).
+
+    Returns:
+        Dequantized tensor in the specified dtype.
+    """
+    W_float = W.to(dtype)
+    if scale_inv.numel() == 1:
+        return W_float * scale_inv.to(dtype)
+    if scale_inv.dim() == 2 and W.dim() == 2:
+        sr, sc = scale_inv.shape
+        br = W.shape[0] // sr
+        bc = W.shape[1] // sc
+        # If dimensions are exactly divisible, use fast reshape path
+        if sr * br == W.shape[0] and sc * bc == W.shape[1]:
+            return (
+                W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)
+            ).reshape(W.shape)
+        # Tail-block handling: compute actual block size (ceil division),
+        # tile scale_inv to cover full shape, then crop to W's dimensions
+        br_ceil = -(-W.shape[0] // sr)  # ceil(rows / scale_rows) = block_size
+        bc_ceil = -(-W.shape[1] // sc)
+        scale_expanded = (
+            scale_inv.to(dtype)
+            .repeat_interleave(br_ceil, dim=0)
+            .repeat_interleave(bc_ceil, dim=1)
+        )[: W.shape[0], : W.shape[1]]
+        return W_float * scale_expanded
+    return W_float * scale_inv.to(dtype)
+
+
 def dequantize(
     W: torch.Tensor,
-    quant_state: QuantState | list | None = None,
+    quant_state: QuantState | list | torch.Tensor | None = None,
     out: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """
@@ -49,6 +90,15 @@ def dequantize(
     if quant_state is None:
         return W

+    # FP8 path: quant_state is actually scale_inv tensor
+    if W.dtype == torch.float8_e4m3fn:
+        scale_inv = quant_state
+        # Caller may pass W.t() (non-contiguous) — dequantize in original
+        # layout then transpose back so the result shape matches the input.
+        if not W.is_contiguous() and W.dim() == 2:
+            return dequantize_fp8(W.t(), scale_inv).t()
+        return dequantize_fp8(W, scale_inv)
+
     # Get the target device from input tensor W
     target_device = W.device

@@ -160,6 +160,18 @@ def load_lora(
     else:
         model = get_peft_model(model, lora_config, **model_kwargs)

+    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
+    # requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
+    if cfg.torch_dtype:
+        _fp8_cast_dtype = cfg.torch_dtype
+    elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+        _fp8_cast_dtype = torch.bfloat16
+    else:
+        _fp8_cast_dtype = torch.float16
+    for _name, param in model.named_parameters():
+        if param.requires_grad and param.dtype == torch.float8_e4m3fn:
+            param.data = param.data.to(_fp8_cast_dtype)
+
     if rank == 0:
         try:
             model.print_trainable_parameters()
@@ -215,6 +215,8 @@ class ModelLoader:
             self.model_kwargs["revision"] = self.cfg.revision_of_model
         if self.cfg.use_kernels:
             self.model_kwargs["use_kernels"] = self.cfg.use_kernels
+            if "allow_all_kernels" not in self.model_kwargs:
+                self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
         self._set_quantization_config()
         self._set_attention_config()
         self._check_model_requirements()
@@ -116,6 +116,7 @@ class PatchManager:
         self._apply_patch_deepspeed_zero3()
         self._apply_voxtral_patches()
         self._apply_apertus_patches()
+        self._apply_trl_vllm_patches()

     def apply_post_plugin_pre_model_load_patches(self):
         """Apply post plugin-pre_model_load load patches based on config."""
@@ -667,6 +668,17 @@ class PatchManager:

         patch_apertus_xielu_activation()

+    def _apply_trl_vllm_patches(self):
+        """Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
+        if (
+            self.cfg.rl
+            and getattr(self.cfg, "trl", None)
+            and getattr(self.cfg.trl, "use_vllm", False)
+        ):
+            from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
+
+            patch_trl_vllm()
+
     def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
         """Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
         if self.cfg.scaling_softmax:
245 src/axolotl/monkeypatch/trainer/trl_vllm.py Normal file
@@ -0,0 +1,245 @@
"""Monkeypatches for TRL's vLLM integration and trainer utils.

Adds:
- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips)
- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation)
- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO
- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough)
"""

import logging
import math
from functools import wraps

import torch
from torch import nn

LOG = logging.getLogger(__name__)


def _batch_update_named_params(
    self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
    """Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
    from transformers import is_torch_xpu_available

    if chunk_size is None:
        chunks = [params]
    else:
        chunks = []
        current_chunk: list[tuple[str, torch.Tensor]] = []
        current_elements = 0
        for name, weights in params:
            n_elem = weights.numel()
            if current_chunk and current_elements + n_elem > chunk_size:
                chunks.append(current_chunk)
                current_chunk = []
                current_elements = 0
            current_chunk.append((name, weights))
            current_elements += n_elem
        if current_chunk:
            chunks.append(current_chunk)

    for chunk in chunks:
        param_metadata = [
            {"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
            for name, weights in chunk
        ]
        url = f"{self.base_url}/batch_update_named_params/"
        response = self.session.post(url, json={"params": param_metadata})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

        for _name, weights in chunk:
            if is_torch_xpu_available():
                self.communicator.broadcast(weights, root=self.rank)
            else:
                self.communicator.broadcast(weights, src=self.rank)

        if is_torch_xpu_available():
            self.communicator.barrier()
        else:
            self.communicator.group.barrier()


def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
    """Updates all model params using batch_update_named_params."""
    params = [(name, param.data) for name, param in model.named_parameters()]
    self.batch_update_named_params(params, chunk_size=chunk_size)


def _patched_extract_logprobs(all_outputs):
    """extract_logprobs with NaN→0.0 fix (stock TRL uses None, which causes downstream errors)."""
    all_logprobs = []
    all_token_ids = []

    for outputs in all_outputs:
        for output in outputs.outputs:
            if output.logprobs is None:
                return None, None
            seq_logprobs = []
            seq_token_ids = []
            for lp in output.logprobs:
                sorted_items = sorted(lp.items(), key=lambda x: x[1].rank)
                seq_token_ids.append([token_id for token_id, _ in sorted_items])
                seq_logprobs.append(
                    [
                        0.0 if math.isnan(item.logprob) else item.logprob
                        for _, item in sorted_items
                    ]
                )
            all_logprobs.append(seq_logprobs)
            all_token_ids.append(seq_token_ids)

    return all_logprobs, all_token_ids


def _patched_split_tensor_dict(tensor_dict, num_chunks):
    """split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch."""
    first_tensor = next(
        tensor
        for tensor in tensor_dict.values()
        if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0
    )
    chunk_size = first_tensor.shape[0] // num_chunks
    chunks = []
    for i in range(num_chunks):
        chunk_dict = {}
        for key, tensor in tensor_dict.items():
            if isinstance(tensor, (int, float, bool)):
                chunk_dict[key] = tensor
            elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):
                chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
            elif tensor is not None and tensor.ndim == 0:
                chunk_dict[key] = tensor
            else:
                chunk_dict[key] = None
        chunks.append(chunk_dict)
    return chunks


def _patched_shuffle_sequence_dict(seq_dict):
    """shuffle_sequence_dict that handles scalar types (int/float/bool)."""
    first_seq = next(
        v
        for v in seq_dict.values()
        if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0
    )
    perm = torch.randperm(len(first_seq))

    def permute(v):
        if v is None:
            return None
        if isinstance(v, (int, float, bool)):
            return v
        if isinstance(v, torch.Tensor) and v.ndim == 0:
            return v
        if isinstance(v, torch.Tensor) and v.ndim >= 1:
            return v[perm]
        if isinstance(v, list):
            return [v[i] for i in perm.tolist()]
        return v

    return {k: permute(v) for k, v in seq_dict.items()}


def _patch_sync_weights_batched(original_init):
    """Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size."""

    @wraps(original_init)
    def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs):
        original_init(self, *args, **kwargs)
        self.weight_sync_chunk_size = weight_sync_chunk_size

    return patched_init


def _make_batched_sync_weights(original_sync_weights):
    """Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths."""

    @wraps(original_sync_weights)
    def patched_sync_weights(self):
        from accelerate.utils import is_peft_model

        # Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps
        accelerator = self.accelerator
        model = self.model
        is_fsdp_enabled = self.is_fsdp_enabled

        deepspeed_plugin = accelerator.state.deepspeed_plugin
        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3

        is_peft = is_peft_model(model)

        # If PEFT, FSDP, or ZeRO-3, fall back to the original (which handles those cases)
        if is_peft or is_fsdp_enabled or zero_stage_3:
            return original_sync_weights(self)

        # Non-PEFT, non-FSDP, non-ZeRO: use batched sync
        if self.mode == "colocate" and getattr(self, "enable_sleep_mode", False):
            from vllm.distributed.device_communicators.cuda_wrapper import (
                empty_cache,
            )

            empty_cache()
            self.llm.wake_up(tags=["weights"])

        if self.mode == "server" and accelerator.is_main_process:
            params = [
                (self._fix_param_name_to_vllm(name), param.data)
                for name, param in model.named_parameters()
            ]
            self.vllm_client.batch_update_named_params(
                params, chunk_size=getattr(self, "weight_sync_chunk_size", None)
            )
        elif self.mode == "colocate":
            llm_model = (
                self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            )
            weights = [
                (self._fix_param_name_to_vllm(name), param.data)
                for name, param in model.named_parameters()
            ]
            llm_model.load_weights(weights=weights)

        # Reset cache
        if self.mode == "server" and accelerator.is_main_process:
            self.vllm_client.reset_prefix_cache()
        elif self.mode == "colocate":
            self.llm.reset_prefix_cache()

    return patched_sync_weights


def patch_trl_vllm():
    """Apply all TRL vLLM monkeypatches."""
    import trl.generation.vllm_client
    import trl.generation.vllm_generation
    import trl.trainer.utils

    VLLMClient = trl.generation.vllm_client.VLLMClient
    VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration

    # 1. Add batch_update_named_params to VLLMClient
    if not hasattr(VLLMClient, "batch_update_named_params"):
        VLLMClient.batch_update_named_params = _batch_update_named_params
        VLLMClient.update_model_params = _update_model_params
        LOG.info("Patched VLLMClient with batch_update_named_params")

    # 2. Patch extract_logprobs (NaN→0.0)
    trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs
    LOG.info("Patched extract_logprobs with NaN→0.0 fix")

    # 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size
    VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__)

    # 4. Patch sync_weights for batched non-FSDP/non-ZeRO path
    VLLMGeneration.sync_weights = _make_batched_sync_weights(
        VLLMGeneration.sync_weights
    )
    LOG.info("Patched VLLMGeneration with batched sync_weights")

    # 5. Patch split_tensor_dict and shuffle_sequence_dict
    trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict
    trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict
    LOG.info("Patched split_tensor_dict and shuffle_sequence_dict for scalar types")
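The greedy element-count chunking in `_batch_update_named_params` is what bounds peak NCCL message size per HTTP round-trip. A minimal standalone sketch of the same algorithm (function name illustrative, not part of the patch):

```python
import torch

def chunk_by_numel(params, chunk_size):
    """Greedily pack (name, tensor) pairs; flush when adding one would exceed chunk_size."""
    chunks, current, elems = [], [], 0
    for name, tensor in params:
        if current and elems + tensor.numel() > chunk_size:
            chunks.append(current)
            current, elems = [], 0
        current.append((name, tensor))
        elems += tensor.numel()
    if current:
        chunks.append(current)
    return chunks

params = [("a", torch.zeros(100)), ("b", torch.zeros(100)), ("c", torch.zeros(100))]
# Each 100-element tensor overflows the 150-element budget as soon as a second
# one is added, so every tensor lands in its own chunk (and its own HTTP call).
print([len(c) for c in chunk_by_numel(params, 150)])  # [1, 1, 1]
```

Note that a single tensor larger than `chunk_size` still travels alone in its own chunk, since a chunk is only flushed once it is non-empty.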
0 src/axolotl/scripts/__init__.py Normal file
503 src/axolotl/scripts/vllm_serve_lora.py Normal file
@@ -0,0 +1,503 @@
"""vLLM serve script with native LoRA adapter support.

Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,
instead of merging adapter weights into the base model before syncing.

Usage:
    Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,
    or ``trl.vllm_lora_sync: true`` to auto-select.

Benefits over merge-sync:
- Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL
- vLLM handles LoRA application natively (Punica kernels)
- No NCCL communicator needed for weight sync
"""

import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Any

from trl.scripts.vllm_serve import (
    ScriptArguments,
    chunk_list,
    extract_logprobs,
    get_open_port,
)
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

logger = logging.getLogger(__name__)


@dataclass
class LoRAScriptArguments(ScriptArguments):
    """Extended script arguments with LoRA support."""

    enable_lora: bool = field(
        default=True,
        metadata={"help": "Enable LoRA adapter support in vLLM."},
    )
    max_lora_rank: int = field(
        default=64,
        metadata={"help": "Maximum LoRA rank supported."},
    )
    max_loras: int = field(
        default=2,
        metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."},
    )
    lora_dtype: str = field(
        default="bfloat16",
        metadata={"help": "Data type for LoRA weights."},
    )


def llm_worker(
    script_args: LoRAScriptArguments,
    data_parallel_rank: int,
    master_port: int,
    connection: Connection,
) -> None:
    """Worker process that creates a vLLM LLM with LoRA enabled."""
    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)

    llm = LLM(
        model=script_args.model,
        revision=script_args.revision,
        tensor_parallel_size=script_args.tensor_parallel_size,
        gpu_memory_utilization=script_args.gpu_memory_utilization,
        enforce_eager=script_args.enforce_eager,
        dtype=script_args.dtype,
        enable_prefix_caching=script_args.enable_prefix_caching,
        kv_cache_dtype=script_args.kv_cache_dtype,
        max_model_len=script_args.max_model_len,
        # Use batch-capable worker extension (adds batch_update_named_params + auto-close)
        worker_extension_cls="axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension",
        trust_remote_code=script_args.trust_remote_code,
        model_impl=script_args.vllm_model_impl,
        logprobs_mode="processed_logprobs",
        # LoRA
        enable_lora=script_args.enable_lora,
        max_lora_rank=script_args.max_lora_rank,
        max_loras=script_args.max_loras,
        lora_dtype=script_args.lora_dtype,
    )

    connection.send({"status": "ready"})

    while True:
        try:
            command = connection.recv()
        except KeyboardInterrupt:
            llm.collective_rpc(method="close_communicator")
            break

        if command["type"] in ["call", "fire_and_forget"]:
            method_name = command["method"]
            args = command.get("args", ())
            kwargs = command.get("kwargs", {})

            # Reconstruct LoRARequest from serialized dict (can't pickle across pipe)
            if "lora_request" in kwargs and kwargs["lora_request"] is not None:
                lr = kwargs["lora_request"]
                kwargs["lora_request"] = LoRARequest(
                    lora_name=lr["lora_name"],
                    lora_int_id=lr["lora_int_id"],
                    lora_path=lr["lora_path"],
                    load_inplace=lr.get("load_inplace", False),
                )

            method = getattr(llm, method_name)
            result = method(*args, **kwargs)
            if command["type"] == "call":
                connection.send(result)
        elif command["type"] == "shutdown":
            break


def main(script_args: ScriptArguments):
    """Start vLLM workers with LoRA support and the HTTP server."""
    import asyncio

    import uvicorn
    from fastapi import FastAPI
    from pydantic import BaseModel, Field as PydanticField

    # Request/Response models (defined locally like TRL's vllm_serve.main)
    class GenerateRequest(BaseModel):
        prompts: list[str]
        images: list[str] | None = None
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        logprobs: int | None = 0
        truncate_prompt_tokens: int | None = None
        structured_outputs_regex: str | None = None
        generation_kwargs: dict = PydanticField(default_factory=dict)

    class GenerateResponse(BaseModel):
        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[list[float]]]
        logprob_token_ids: list[list[list[int]]]

    class ChatRequest(BaseModel):
        messages: list[list[dict]]
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        logprobs: int | None = 0
        truncate_prompt_tokens: int | None = None
        structured_outputs_regex: str | None = None
        generation_kwargs: dict = PydanticField(default_factory=dict)
        chat_template_kwargs: dict = PydanticField(default_factory=dict)

    class ChatResponse(BaseModel):
        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[list[float]]]
        logprob_token_ids: list[list[list[int]]]

    class InitCommunicatorRequest(BaseModel):
        host: str
        port: int
        world_size: int
        client_device_uuid: str

    # Wrap plain ScriptArguments with LoRA defaults
    if not isinstance(script_args, LoRAScriptArguments):
        lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)
        for f in ScriptArguments.__dataclass_fields__:
            setattr(lora_args, f, getattr(script_args, f))
        # Apply LoRA defaults
        for f in LoRAScriptArguments.__dataclass_fields__:
            if f not in ScriptArguments.__dataclass_fields__:
                setattr(
                    lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default
                )
        script_args = lora_args

    # Spawn workers
    master_port = get_open_port()
    connections: list[Connection] = []
    processes: list[Process] = []
    for dp_rank in range(script_args.data_parallel_size):
        parent_conn, child_conn = Pipe()
        process = Process(
            target=llm_worker,
            args=(script_args, dp_rank, master_port, child_conn),
        )
        process.start()
        connections.append(parent_conn)
        processes.append(process)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        import time

        startup_timeout = 300  # 5 minutes
        start_time = time.monotonic()
        ready: set[int] = set()
        while len(ready) < script_args.data_parallel_size:
            elapsed = time.monotonic() - start_time
            if elapsed > startup_timeout:
                raise RuntimeError(
                    f"vLLM workers failed to start within {startup_timeout}s "
                    f"({len(ready)}/{script_args.data_parallel_size} ready)"
                )
            for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):
                if id(conn) in ready:
                    continue
                if not proc.is_alive():
                    raise RuntimeError(
                        f"vLLM worker {i} exited unexpectedly during startup"
                    )
                if conn.poll():
                    msg = conn.recv()
                    if isinstance(msg, dict) and msg.get("status") == "ready":
                        ready.add(id(conn))
            await asyncio.sleep(0.1)
        yield
        for p in processes:
            p.join(timeout=10)
            if p.is_alive():
                p.terminate()
                p.join()

    app = FastAPI(lifespan=lifespan)

    # --- Active LoRA state (shared across endpoints via closure) ---
    active_lora: dict = {"request": None}

    # ------------------------------------------------------------------
    # LoRA-specific endpoints
    # ------------------------------------------------------------------

    class SetLoRARequest(BaseModel):
        lora_name: str
        lora_int_id: int
        lora_path: str
        load_inplace: bool = False

    @app.post("/set_lora_adapter/")
    async def set_lora_adapter(request: SetLoRARequest):
        """Register a LoRA adapter for all subsequent generate/chat calls."""
        active_lora["request"] = {
            "lora_name": request.lora_name,
            "lora_int_id": request.lora_int_id,
            "lora_path": request.lora_path,
            "load_inplace": request.load_inplace,
        }
        logger.info(
            "Set active LoRA: %s (id=%d, path=%s)",
            request.lora_name,
            request.lora_int_id,
            request.lora_path,
        )
        return {"status": "ok"}

    @app.post("/clear_lora_adapter/")
    async def clear_lora_adapter():
        """Clear active LoRA adapter (revert to base model)."""
        active_lora["request"] = None
        return {"status": "ok"}

    # ------------------------------------------------------------------
    # Standard endpoints (mirrors TRL's vllm_serve)
    # ------------------------------------------------------------------

    @app.get("/health/")
    async def health():
        return {"status": "ok"}

    @app.get("/get_world_size/")
    async def get_world_size():
        return {
            "world_size": script_args.tensor_parallel_size
            * script_args.data_parallel_size
        }

    @app.post("/generate/", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """Generate completions with optional LoRA adapter."""
        import base64
        from io import BytesIO

        import vllm
        from packaging.version import Version
        from vllm.sampling_params import GuidedDecodingParams

        images: list[str | None] = request.images or [None] * len(request.prompts)  # type: ignore[assignment,list-item]
        prompts: list[dict[str, Any]] = []
        for prompt, image in zip(request.prompts, images, strict=True):
            row: dict[str, Any] = {"prompt": prompt}
            if image is not None:
                from PIL import Image

                row["multi_modal_data"] = {
                    "image": Image.open(BytesIO(base64.b64decode(image)))
                }
            prompts.append(row)

        generation_kwargs = {
            "n": request.n,
            "repetition_penalty": request.repetition_penalty,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
            "logprobs": request.logprobs,
        }
        generation_kwargs.update(request.generation_kwargs)

        if Version(vllm.__version__) <= Version("0.10.2"):
            key = "guided_decoding"
            if request.structured_outputs_regex is not None:
                generation_kwargs[key] = GuidedDecodingParams(
                    regex=request.structured_outputs_regex
                )
            else:
                generation_kwargs.setdefault(key, None)
        else:
            from vllm.sampling_params import StructuredOutputsParams

            key = "structured_outputs"
            if request.structured_outputs_regex is not None:
                generation_kwargs[key] = StructuredOutputsParams(
                    regex=request.structured_outputs_regex
                )
            elif isinstance(generation_kwargs.get(key), dict):
                generation_kwargs[key] = StructuredOutputsParams(
                    **generation_kwargs[key]
                )
            else:
                generation_kwargs.setdefault(key, None)

        sampling_params = SamplingParams(**generation_kwargs)
        chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)

        for conn, chunk in zip(connections, chunked_prompts, strict=True):
            if not chunk:
                chunk = [{"prompt": "<placeholder>"}]
            kwargs = {
                "prompts": chunk,
                "sampling_params": sampling_params,
                "lora_request": active_lora["request"],
            }
            conn.send({"type": "call", "method": "generate", "kwargs": kwargs})

        all_outputs = [conn.recv() for conn in connections]
        all_outputs = [
            o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
        ]
        all_outputs = list(chain.from_iterable(all_outputs))

        return {
            "prompt_ids": [o.prompt_token_ids for o in all_outputs],
            "completion_ids": [
                list(out.token_ids) for o in all_outputs for out in o.outputs
            ],
            "logprobs": extract_logprobs(all_outputs)[0],
            "logprob_token_ids": extract_logprobs(all_outputs)[1],
        }

    @app.post("/chat/", response_model=ChatResponse)
    async def chat(request: ChatRequest):
        """Chat endpoint with optional LoRA adapter."""
        generation_kwargs = {
            "n": request.n,
            "repetition_penalty": request.repetition_penalty,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
            "logprobs": request.logprobs,
        }
        generation_kwargs.update(request.generation_kwargs)
        sampling_params = SamplingParams(**generation_kwargs)
        chunked = chunk_list(request.messages, script_args.data_parallel_size)
        for conn, chunk in zip(connections, chunked, strict=True):
            if not chunk:
                chunk = [[{"role": "user", "content": "<placeholder>"}]]
            kwargs = {
                "messages": chunk,
                "sampling_params": sampling_params,
                "use_tqdm": False,
                "lora_request": active_lora["request"],
            }
            conn.send({"type": "call", "method": "chat", "kwargs": kwargs})

        all_outputs = [conn.recv() for conn in connections]
        all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
        all_outputs = list(chain.from_iterable(all_outputs))

        return {
            "prompt_ids": [o.prompt_token_ids for o in all_outputs],
            "completion_ids": [
                list(out.token_ids) for o in all_outputs for out in o.outputs
            ],
            "logprobs": extract_logprobs(all_outputs)[0],
            "logprob_token_ids": extract_logprobs(all_outputs)[1],
        }

    # --- Weight sync endpoints (legacy fallback, same as TRL) ---

    @app.post("/init_communicator/")
    async def init_communicator(request: InitCommunicatorRequest):
        world_size = (
            script_args.tensor_parallel_size * script_args.data_parallel_size + 1
        )
        kwargs = {
            "method": "init_communicator",
            "args": (
                request.host,
                request.port,
                world_size,
                request.client_device_uuid,
            ),
        }
        msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        loop = asyncio.get_running_loop()
        await asyncio.gather(
            *(loop.run_in_executor(None, c.send, msg) for c in connections)
        )
        return {"message": "Initializing communicator"}

    class UpdateWeightsRequest(BaseModel):
        name: str
        dtype: str
        shape: list[int]

    @app.post("/update_named_param/")
    async def update_named_param(request: UpdateWeightsRequest):
        kwargs = {
            "method": "update_named_param",
            "args": (request.name, request.dtype, tuple(request.shape)),
        }
        msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        loop = asyncio.get_running_loop()
        await asyncio.gather(
            *(loop.run_in_executor(None, c.send, msg) for c in connections)
        )
        return {"message": "Updating parameter"}

    class BatchUpdateWeightsRequest(BaseModel):
        params: list[dict]

    @app.post("/batch_update_named_params/")
    async def batch_update_named_params(request: BatchUpdateWeightsRequest):
        params_list = [
            (p["name"], p["dtype"], tuple(p["shape"])) for p in request.params
        ]
        kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
        msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        loop = asyncio.get_running_loop()
        await asyncio.gather(
            *(loop.run_in_executor(None, c.send, msg) for c in connections)
        )
        return {"message": f"Batch update for {len(params_list)} params"}

    @app.post("/reset_prefix_cache/")
    async def reset_prefix_cache():
        for conn in connections:
            conn.send({"type": "call", "method": "reset_prefix_cache"})
        results = [conn.recv() for conn in connections]
        return {"message": f"Reset prefix cache: {all(results)}"}

    @app.post("/close_communicator/")
    async def close_communicator():
        kwargs = {"method": "close_communicator"}
        for conn in connections:
            conn.send(
                {
                    "type": "fire_and_forget",
                    "method": "collective_rpc",
                    "kwargs": kwargs,
                }
            )
        return {"message": "Closing communicator"}

    uvicorn.run(
        app,
        host=script_args.host,
        port=script_args.port,
        log_level=script_args.log_level,
        access_log=True,
    )
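For orientation, a minimal client-side sketch of the adapter hand-off this server enables (host, port, names, and paths are assumed placeholders, not values from this commit):

```python
import requests

BASE = "http://localhost:8000"  # assumed serve host/port

# Point the server at a freshly saved adapter checkpoint; lora_int_id should
# change between syncs so vLLM treats it as a new adapter to load.
resp = requests.post(
    f"{BASE}/set_lora_adapter/",
    json={
        "lora_name": "grpo_step_42",           # illustrative
        "lora_int_id": 42,                     # illustrative
        "lora_path": "/tmp/adapters/step_42",  # must be visible to the server
        "load_inplace": False,
    },
)
resp.raise_for_status()

# Subsequent POSTs to /generate/ or /chat/ now apply this adapter natively;
# POST /clear_lora_adapter/ reverts to the base model.
```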
158 src/axolotl/scripts/vllm_worker_ext.py Normal file
@@ -0,0 +1,158 @@
"""Extended vLLM worker extension with batch weight sync support.

Subclasses TRL's WeightSyncWorkerExtension to add:
- batch_update_named_params: receives multiple params in one call
- Auto-close stale communicator on re-init
- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params,
  including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy
"""

import logging

import torch

try:
    from transformers import is_torch_xpu_available
except ImportError:
    is_torch_xpu_available = lambda: False  # noqa: E731

from trl.scripts.vllm_serve import WeightSyncWorkerExtension

logger = logging.getLogger(__name__)

# Stacked param name mapping: shard_name -> (packed_name, shard_order)
_STACKED_PARAMS = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
    """Worker extension that adds batch weight update and direct weight setting."""

    def init_communicator(self, host, port, world_size, client_device_uuid):
        """Auto-close stale communicator before re-initializing."""
        if self.communicator is not None:
            self.close_communicator()
        super().init_communicator(host, port, world_size, client_device_uuid)

    def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None:
        """Directly copy weight data into the model, handling stacked params.

        Bypasses model.load_weights() which may fail on vLLM 0.17's new
        module-tree weight loader for stacked params (qkv_proj, gate_up_proj).

        Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the
        parameter hierarchy (e.g. ``qkv_proj.base_layer.weight``).
        """
        model = self.model_runner.model
        params_dict = dict(model.named_parameters())

        # Check if this is a simple direct param (exists as-is)
        if name in params_dict:
            params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
            return

        # Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight
        parts_bl = name.rsplit(".", 1)
        if len(parts_bl) == 2:
            base_layer_name = f"{parts_bl[0]}.base_layer.{parts_bl[1]}"
            if base_layer_name in params_dict:
                params_dict[base_layer_name].data.copy_(
                    weight.to(params_dict[base_layer_name].dtype)
                )
                return

        # Handle stacked params: e.g. "model.layers.0.self_attn.q_proj.weight"
        # -> "model.layers.0.self_attn.qkv_proj.weight" with shard offset
        parts = name.rsplit(".", 2)  # [prefix, layer_name, suffix]
        if len(parts) == 3:
            prefix, layer_name, suffix = parts
            if layer_name in _STACKED_PARAMS:
                packed_name, shard_idx = _STACKED_PARAMS[layer_name]
                for packed_full in [
                    f"{prefix}.{packed_name}.{suffix}",
                    f"{prefix}.{packed_name}.base_layer.{suffix}",
                ]:
                    if packed_full not in params_dict:
                        continue
                    param = params_dict[packed_full]
                    # Navigate to the packed module to find shard sizes
                    module_path = packed_full.rsplit(".", 1)[0]  # strip .weight/.bias
                    if ".base_layer" in module_path:
                        module_path = module_path.replace(".base_layer", "")
                    module = model
                    for attr in module_path.split("."):
                        module = getattr(module, attr, None)
                        if module is None:
                            break
                    # LoRA wrappers don't have output_sizes directly;
                    # check base_layer for the underlying parallel linear
                    if module is not None and not hasattr(module, "output_sizes"):
                        base = getattr(module, "base_layer", None)
                        if base is not None and hasattr(base, "output_sizes"):
                            module = base
                    if module is not None and hasattr(module, "output_sizes"):
                        tp_size = getattr(module, "tp_size", 1)
                        sizes = [s // tp_size for s in module.output_sizes]
                        offset = sum(sizes[:shard_idx])
                        shard_size = sizes[shard_idx]
                        param.data[offset : offset + shard_size].copy_(
                            weight.to(param.dtype)
                        )
                        return

        # Fallback: try load_weights (may work for non-stacked params)
        logger.warning("Falling back to load_weights for param: %s", name)
        model.load_weights(weights=[(name, weight)])

    def update_named_param(self, name, dtype, shape):
        """Override to use _direct_set_weight instead of load_weights."""
        if self.communicator is None:
            raise RuntimeError("Communicator not initialized.")

        dtype = getattr(torch, dtype.split(".")[-1])
        weight = torch.empty(shape, dtype=dtype, device=self.device)

        if is_torch_xpu_available():
            self.communicator.broadcast(weight, root=self.client_rank)
            self.communicator.barrier()
        else:
            self.communicator.broadcast(weight, src=self.client_rank)
            self.communicator.group.barrier()

        self._direct_set_weight(name, weight)

    def batch_update_named_params(self, params_list: list[tuple[str, str, tuple]]):
        """Receive and apply multiple weight tensors in sequence.

        Args:
            params_list: List of (name, dtype_str, shape) tuples.
        """
        if self.communicator is None:
            raise RuntimeError("Communicator not initialized.")

        weights_to_load = []
        for name, dtype_str, shape in params_list:
            dtype = getattr(torch, dtype_str.split(".")[-1])
            weight = torch.empty(shape, dtype=dtype, device=self.device)

            if is_torch_xpu_available():
                self.communicator.broadcast(weight, root=self.client_rank)
            else:
                self.communicator.broadcast(weight, src=self.client_rank)

            weights_to_load.append((name, weight))

        # Single barrier after all broadcasts
        if is_torch_xpu_available():
            self.communicator.barrier()
        else:
            self.communicator.group.barrier()

        # Load weights using direct set (handles stacked params)
        for name, weight in weights_to_load:
            self._direct_set_weight(name, weight)
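The shard arithmetic in `_direct_set_weight` is easiest to see with numbers. A worked sketch assuming a packed `qkv_proj` with `output_sizes = [4096, 1024, 1024]` and no tensor parallelism (sizes illustrative, not from this commit):

```python
output_sizes = [4096, 1024, 1024]  # q, k, v rows of the packed qkv_proj weight
tp_size = 1
sizes = [s // tp_size for s in output_sizes]

# k_proj maps to ("qkv_proj", 1) in _STACKED_PARAMS
shard_idx = 1
offset = sum(sizes[:shard_idx])  # 4096
shard_size = sizes[shard_idx]    # 1024

# An incoming k_proj weight is copied into rows [4096, 5120) of qkv_proj
print(offset, offset + shard_size)  # 4096 5120
```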
@@ -189,3 +189,125 @@ class TRLConfig(BaseModel):
            "'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
        },
    )

    # Async GRPO fields
    use_data_producer: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Use the GRPODataProducer protocol for online data generation."
        },
    )
    async_prefetch: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Generate rollouts in a background thread while training on the previous rollout."
        },
    )
    prefetch_depth: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of rollouts to prefetch ahead of training."
        },
    )
    vllm_sync_interval: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
        },
    )
    streaming_partial_batch: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Score prompt groups incrementally instead of the full batch at once."
        },
    )
    streaming_min_groups: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Minimum prompt groups to score per streaming chunk."
        },
    )
    vllm_importance_sampling_correction: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Apply IS correction for distribution mismatch between vLLM and training model."
        },
    )
    vllm_importance_sampling_mode: (
        Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"]
        | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
        },
    )
    vllm_importance_sampling_cap: float | None = Field(
        default=None,
        json_schema_extra={"description": "Cap C for IS ratio clipping/masking."},
    )
    off_policy_mask_threshold: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
        },
    )
    use_bias_correction_kl: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Apply IS correction to KL divergence term."},
    )

    reward_num_workers: int = Field(
        default=1,
        json_schema_extra={
            "description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
            "own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
            "workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
        },
    )
    replay_buffer_size: int = Field(
        default=0,
        json_schema_extra={
            "description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
            "groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
            "(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
        },
    )
    replay_recompute_logps: bool = Field(
        default=True,
        json_schema_extra={
            "description": "When True (default), recompute old_per_token_logps for replayed groups using the current "
            "training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
            "Only relevant when replay_buffer_size > 0."
        },
    )
    reroll_start_fraction: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
            "(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
            "model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
        },
    )
    reroll_max_groups: int = Field(
        default=1,
        json_schema_extra={
            "description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
            "increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
        },
    )
    skip_zero_advantage_batches: bool = Field(
        default=True,
        json_schema_extra={
            "description": "When True, skip the forward/backward pass entirely for micro-batches where all advantages "
            "are zero (no learning signal). The step is logged with skipped_zero_adv_batches=1 for monitoring."
        },
    )
    vllm_lora_sync: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. "
            "Auto-selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged model."
        },
    )

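As a quick sanity check of the new fields, a hypothetical snippet (the import path is assumed, and `TRLConfig` is presumed constructible with defaults for its remaining fields):

```python
from axolotl.utils.schemas.trl import TRLConfig  # assumed import path

cfg = TRLConfig(
    use_data_producer=True,
    async_prefetch=True,
    prefetch_depth=1,
    vllm_sync_interval=2,
    replay_buffer_size=64,       # enable the experimental replay buffer
    reroll_start_fraction=0.5,   # start re-rolling zero-signal prompts at 50% of training
)
print(cfg.replay_recompute_logps)  # True by default
```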
@@ -675,20 +675,6 @@ class LoRAValidationMixin:
        )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_lora_kernels_rl(cls, data):
        if (
            data.get("lora_mlp_kernel")
            or data.get("lora_qkv_kernel")
            or data.get("lora_o_kernel")
        ) and data.get("rl"):
            raise ValueError(
                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
                "compatible with RL at the moment."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_lora_kernels_trust_remote_code(cls, data):

@@ -57,3 +57,10 @@ class VllmConfig(BaseModel):
        default=None,
        json_schema_extra={"description": "Reasoning parser for VLLM"},
    )
    serve_module: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' "
            "for native LoRA support, or leave None for default TRL serve."
        },
    )

220 tests/core/test_async_grpo.py Normal file
@@ -0,0 +1,220 @@
"""Unit tests for async GRPO"""

import unittest
from unittest.mock import MagicMock

import torch


class TestReplayBuffer(unittest.TestCase):
    """Tests for ReplayBuffer edge cases."""

    def test_add_noop_when_max_size_zero(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=0)
        buf.add(1.0, {"data": "test"})
        self.assertEqual(len(buf), 0)

    def test_add_noop_when_max_size_negative(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=-1)
        buf.add(1.0, {"data": "test"})
        self.assertEqual(len(buf), 0)

    def test_sample_returns_none_when_max_size_zero(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=0)
        self.assertIsNone(buf.sample(1))

    def test_sample_returns_none_when_empty(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=5)
        self.assertIsNone(buf.sample(1))

    def test_normal_add_and_sample(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=3)
        buf.add(1.0, {"a": 1})
        buf.add(2.0, {"a": 2})
        buf.add(3.0, {"a": 3})
        self.assertEqual(len(buf), 3)
        result = buf.sample(1)
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 1)

    def test_replaces_lowest_when_full(self):
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        buf = ReplayBuffer(max_size=2)
        buf.add(1.0, {"a": 1})
        buf.add(2.0, {"a": 2})
        buf.add(3.0, {"a": 3})  # should replace score=1.0
        self.assertEqual(len(buf), 2)
        scores = sorted(item[0] for item in buf._heap)
        self.assertEqual(scores, [2.0, 3.0])


class TestGRPOStrategyConflict(unittest.TestCase):
    """Tests for sequence_parallel + async_grpo conflict detection."""

    def test_raises_on_both_enabled(self):
        from axolotl.core.trainers.grpo import GRPOStrategy

        with self.assertRaises(ValueError) as ctx:
            GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)
        self.assertIn("sequence_parallel", str(ctx.exception))
        self.assertIn("async_grpo", str(ctx.exception))

    def test_sequence_parallel_only(self):
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import (
            AxolotlGRPOSequenceParallelTrainer,
        )

        cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False)
        self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer)

    def test_async_only(self):
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer

        cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True)
        self.assertIs(cls, AxolotlAsyncGRPOTrainer)

    def test_neither(self):
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer

        cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False)
        self.assertIs(cls, AxolotlGRPOTrainer)


class TestDequantizeFP8TailBlocks(unittest.TestCase):
    """Tests for FP8 dequantization with non-divisible dimensions."""

    def test_exact_divisible_shape(self):
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (256, 128))
        self.assertEqual(result.dtype, torch.bfloat16)

    def test_non_divisible_rows(self):
        from axolotl.kernels.quantize import dequantize_fp8

        # 131 rows with 2 scale blocks: 131 is not divisible by 2, so this
        # exercises the tail-block path (ceil block size 66, cropped to 131).
        W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (131, 128))
        self.assertEqual(result.dtype, torch.bfloat16)

    def test_non_divisible_cols(self):
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (128, 200))

    def test_scalar_scale(self):
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (64, 64))


class TestLoraFP8Guard(unittest.TestCase):
    """Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights."""

    def test_non_fp8_weight_skips_scale_inv(self):
        """Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
        from axolotl.kernels.lora import get_lora_parameters

        proj = MagicMock()
        proj.disable_adapters = True
        base_layer = MagicMock(spec=[])  # empty spec to control attrs precisely

        # Use a real tensor for weight (bf16, no quant_state attr)
        base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
        base_layer.bias = None
        base_layer.weight_scale_inv = torch.ones(1)  # should NOT be used for bf16

        proj.base_layer = base_layer

        W, b, quant_state, A, B, s = get_lora_parameters(proj)
        # quant_state should be None since weight is bf16, not FP8
        self.assertIsNone(quant_state)

    def test_fp8_weight_uses_scale_inv(self):
        """FP8 weight should pick up weight_scale_inv as quant_state."""
        from axolotl.kernels.lora import get_lora_parameters

        proj = MagicMock()
        proj.disable_adapters = True
        base_layer = MagicMock()
        proj.base_layer = base_layer

        # FP8 weight
        base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
            torch.float8_e4m3fn
        )
        base_layer.bias = None
        scale_inv = torch.ones(1)
        base_layer.weight_scale_inv = scale_inv

        W, b, quant_state, A, B, s = get_lora_parameters(proj)
        self.assertIs(quant_state, scale_inv)


class TestValidateQuantPatchRestore(unittest.TestCase):
    """Test that validate_quantization_for_training is restored after trainer creation."""

    def test_patch_restored_on_success(self):
        """Monkeypatch should be restored even after successful trainer creation."""
        import transformers.trainer as _trainer_module

        original = _trainer_module.validate_quantization_for_training

        # After the build() method runs, the original should be restored.
        # We can't easily test the full build(), but we can test the pattern.
        _orig = _trainer_module.validate_quantization_for_training
        _trainer_module.validate_quantization_for_training = lambda model: None
        try:
            pass  # simulate trainer_cls() succeeding
        finally:
            _trainer_module.validate_quantization_for_training = _orig

        self.assertIs(_trainer_module.validate_quantization_for_training, original)

    def test_patch_restored_on_error(self):
        """Monkeypatch should be restored even if trainer creation raises."""
        import transformers.trainer as _trainer_module

        original = _trainer_module.validate_quantization_for_training

        _orig = _trainer_module.validate_quantization_for_training
        _trainer_module.validate_quantization_for_training = lambda model: None
        try:
            raise ValueError("test error")
        except ValueError:
            pass
        finally:
            _trainer_module.validate_quantization_for_training = _orig

        self.assertIs(_trainer_module.validate_quantization_for_training, original)


if __name__ == "__main__":
    unittest.main()
286 tests/monkeypatch/test_trl_vllm.py Normal file
@@ -0,0 +1,286 @@
|
||||
"""Unit tests for TRL vLLM monkeypatches.
|
||||
|
||||
Tests:
|
||||
- split_tensor_dict: scalar type preservation (int/float/bool)
|
||||
- shuffle_sequence_dict: scalar type preservation
|
||||
- extract_logprobs: NaN → 0.0 replacement
|
||||
- VLLMClient.batch_update_named_params: method exists after patch
|
||||
- VLLMGeneration: weight_sync_chunk_size attribute after patch
|
||||
- Patch idempotency: applying patch twice doesn't break anything
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestSplitTensorDict(unittest.TestCase):
|
||||
"""Tests for patched split_tensor_dict."""
|
||||
|
||||
def setUp(self):
|
||||
from axolotl.monkeypatch.trainer.trl_vllm import _patched_split_tensor_dict
|
||||
|
||||
self.split = _patched_split_tensor_dict
|
||||
|
||||
def test_scalar_int_preserved(self):
|
||||
d = {"a": torch.randn(4, 3), "count": 42}
|
||||
chunks = self.split(d, 2)
|
||||
self.assertEqual(len(chunks), 2)
|
||||
self.assertEqual(chunks[0]["count"], 42)
|
||||
self.assertEqual(chunks[1]["count"], 42)
|
||||
|
||||
def test_scalar_float_preserved(self):
|
||||
d = {"a": torch.randn(6, 2), "lr": 1e-5}
|
||||
chunks = self.split(d, 3)
|
||||
for c in chunks:
|
||||
self.assertEqual(c["lr"], 1e-5)
|
||||
|
||||
def test_scalar_bool_preserved(self):
|
||||
d = {"a": torch.randn(4, 2), "flag": True}
|
||||
chunks = self.split(d, 2)
|
||||
for c in chunks:
|
||||
self.assertTrue(c["flag"])
|
||||
|
||||
def test_none_preserved(self):
|
||||
d = {"a": torch.randn(4, 2), "b": None}
|
||||
chunks = self.split(d, 2)
|
||||
for c in chunks:
|
||||
self.assertIsNone(c["b"])
|
||||
|
||||
def test_tensor_split(self):
|
||||
t = torch.arange(8).reshape(4, 2)
|
||||
d = {"a": t, "n": 10}
|
||||
chunks = self.split(d, 2)
|
||||
self.assertEqual(chunks[0]["a"].shape, (2, 2))
|
||||
self.assertEqual(chunks[1]["a"].shape, (2, 2))
|
||||
torch.testing.assert_close(chunks[0]["a"], t[:2])
|
||||
torch.testing.assert_close(chunks[1]["a"], t[2:])
|
||||
|
||||
def test_0d_tensor_preserved(self):
|
||||
d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
|
||||
chunks = self.split(d, 2)
|
||||
for c in chunks:
|
||||
self.assertAlmostEqual(c["scalar_t"].item(), 3.14, places=5)
|
||||
|
||||
def test_list_split(self):
|
||||
d = {"a": torch.randn(4, 2), "names": ["a", "b", "c", "d"]}
|
||||
chunks = self.split(d, 2)
|
||||
self.assertEqual(chunks[0]["names"], ["a", "b"])
|
||||
self.assertEqual(chunks[1]["names"], ["c", "d"])
|
||||
|
||||
|
||||
class TestShuffleSequenceDict(unittest.TestCase):
|
||||
"""Tests for patched shuffle_sequence_dict."""
|
||||
|
||||
def setUp(self):
|
||||
from axolotl.monkeypatch.trainer.trl_vllm import _patched_shuffle_sequence_dict
|
||||
|
||||
self.shuffle = _patched_shuffle_sequence_dict
|
||||
|
||||
def test_scalar_int_preserved(self):
|
||||
d = {"a": torch.randn(4, 3), "count": 42}
|
||||
result = self.shuffle(d)
|
||||
self.assertEqual(result["count"], 42)
|
||||
|
||||
def test_scalar_float_preserved(self):
|
||||
d = {"a": torch.randn(4, 3), "lr": 1e-5}
|
||||
result = self.shuffle(d)
|
||||
self.assertEqual(result["lr"], 1e-5)
|
||||
|
||||
def test_scalar_bool_preserved(self):
|
||||
d = {"a": torch.randn(4, 3), "flag": False}
|
||||
result = self.shuffle(d)
|
||||
self.assertFalse(result["flag"])
|
||||
|
||||
def test_none_preserved(self):
|
||||
d = {"a": torch.randn(4, 3), "b": None}
|
||||
result = self.shuffle(d)
|
||||
self.assertIsNone(result["b"])
|
||||
|
||||
def test_tensor_permuted(self):
|
||||
torch.manual_seed(42)
|
||||
t = torch.arange(4).float()
|
||||
d = {"a": t}
|
||||
result = self.shuffle(d)
|
||||
# Same elements, possibly different order
|
||||
self.assertEqual(sorted(result["a"].tolist()), sorted(t.tolist()))
|
||||
self.assertEqual(result["a"].shape, t.shape)
|
||||
|
||||
def test_list_permuted(self):
|
||||
torch.manual_seed(42)
|
||||
d = {"a": torch.randn(3, 2), "names": ["x", "y", "z"]}
|
||||
result = self.shuffle(d)
|
||||
self.assertEqual(sorted(result["names"]), ["x", "y", "z"])
|
||||
self.assertEqual(len(result["names"]), 3)
|
||||
|
||||
def test_0d_tensor_preserved(self):
|
||||
d = {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
|
||||
result = self.shuffle(d)
|
||||
self.assertAlmostEqual(result["scalar_t"].item(), 3.14, places=5)
|
||||
|
||||
|
||||


class TestExtractLogprobs(unittest.TestCase):
    """Tests for patched extract_logprobs (NaN → 0.0)."""

    def setUp(self):
        from axolotl.monkeypatch.trainer.trl_vllm import _patched_extract_logprobs

        self.extract = _patched_extract_logprobs

    def _make_output(self, logprob_values):
        """Create a mock vLLM RequestOutput with given logprob values."""

        @dataclass
        class LogprobItem:
            logprob: float
            rank: int

        @dataclass
        class SeqOutput:
            logprobs: list[dict[int, LogprobItem]] | None

        @dataclass
        class RequestOutput:
            outputs: list[SeqOutput]

        logprobs_list = []
        for vals in logprob_values:
            lp_dict = {i: LogprobItem(logprob=v, rank=i) for i, v in enumerate(vals)}
            logprobs_list.append(lp_dict)

        return RequestOutput(outputs=[SeqOutput(logprobs=logprobs_list)])

    def test_nan_replaced_with_zero(self):
        output = self._make_output([[float("nan"), 0.5], [-0.3, float("nan")]])
        logprobs, token_ids = self.extract([output])
        self.assertEqual(logprobs[0][0][0], 0.0)  # NaN → 0.0
        self.assertEqual(logprobs[0][0][1], 0.5)
        self.assertEqual(logprobs[0][1][0], -0.3)
        self.assertEqual(logprobs[0][1][1], 0.0)  # NaN → 0.0

    def test_normal_values_preserved(self):
        output = self._make_output([[-0.5, -1.2], [-0.1, -2.0]])
        logprobs, token_ids = self.extract([output])
        self.assertAlmostEqual(logprobs[0][0][0], -0.5)
        self.assertAlmostEqual(logprobs[0][0][1], -1.2)

    def test_none_logprobs_returns_none(self):
        @dataclass
        class SeqOutput:
            logprobs: None = None

        @dataclass
        class RequestOutput:
            outputs: list

        output = RequestOutput(outputs=[SeqOutput()])
        logprobs, token_ids = self.extract([output])
        self.assertIsNone(logprobs)
        self.assertIsNone(token_ids)

    def test_token_ids_extracted(self):
        output = self._make_output([[-0.5]])
        logprobs, token_ids = self.extract([output])
        self.assertEqual(token_ids[0][0], [0])  # token_id=0 from enumerate
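

# Illustrative-only sketch of the NaN handling these tests pin down, assuming
# the patch walks each sequence's per-position logprob dicts, maps NaN to 0.0
# (vLLM can emit NaN logprobs, e.g. for forced tokens), and returns
# (None, None) when logprobs were never requested. This is NOT the patched
# implementation, just the behavior the assertions above imply.
def _sanitize_logprobs_sketch(request_outputs):
    all_logprobs, all_token_ids = [], []
    for request in request_outputs:
        for seq in request.outputs:
            if seq.logprobs is None:
                return None, None  # nothing to extract
            seq_lps, seq_ids = [], []
            for lp_dict in seq.logprobs:
                ids = sorted(lp_dict)  # dict keys are token ids (== rank in the mock above)
                seq_ids.append(ids)
                row = []
                for tok_id in ids:
                    lp = lp_dict[tok_id].logprob
                    row.append(0.0 if lp != lp else lp)  # NaN != NaN, so NaN → 0.0
                seq_lps.append(row)
            all_logprobs.append(seq_lps)
            all_token_ids.append(seq_ids)
    return all_logprobs, all_token_ids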


class TestPatchApplication(unittest.TestCase):
    """Tests for patch_trl_vllm() application."""

    def test_batch_update_added_to_client(self):
        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm

        patch_trl_vllm()
        from trl.generation.vllm_client import VLLMClient

        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))

    def test_extract_logprobs_patched(self):
        from axolotl.monkeypatch.trainer.trl_vllm import (
            _patched_extract_logprobs,
            patch_trl_vllm,
        )

        patch_trl_vllm()
        from trl.generation import vllm_generation

        self.assertIs(vllm_generation.extract_logprobs, _patched_extract_logprobs)

    def test_utils_patched(self):
        from axolotl.monkeypatch.trainer.trl_vllm import (
            _patched_shuffle_sequence_dict,
            _patched_split_tensor_dict,
            patch_trl_vllm,
        )

        patch_trl_vllm()
        import trl.trainer.utils

        self.assertIs(trl.trainer.utils.split_tensor_dict, _patched_split_tensor_dict)
        self.assertIs(
            trl.trainer.utils.shuffle_sequence_dict, _patched_shuffle_sequence_dict
        )

    def test_patch_idempotent(self):
        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm

        patch_trl_vllm()
        patch_trl_vllm()  # second call should not error
        from trl.generation.vllm_client import VLLMClient

        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
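

# The idempotency test above is consistent with a guard of roughly this shape.
# The sentinel attribute name below is hypothetical, invented for illustration;
# it is not taken from the actual patch.
def _patch_once_sketch(module, attr_name, replacement):
    sentinel = f"_axolotl_patched_{attr_name}"
    if getattr(module, sentinel, False):
        return  # already patched; a second call is a no-op
    setattr(module, attr_name, replacement)
    setattr(module, sentinel, True)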


class TestBatchUpdateChunking(unittest.TestCase):
    """Tests for batch_update_named_params chunking logic."""

    def test_no_chunk_single_batch(self):
        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params

        # With chunk_size=None, all params go in one chunk
        client = MagicMock()
        client.base_url = "http://localhost:8000"
        client.session.post.return_value = MagicMock(status_code=200)
        client.communicator = MagicMock()
        client.communicator.group = MagicMock()
        client.rank = 0

        params = [
            ("layer.0.weight", torch.randn(10, 10)),
            ("layer.1.weight", torch.randn(10, 10)),
        ]
        _batch_update_named_params(client, params, chunk_size=None)

        # Should make exactly 1 HTTP call
        self.assertEqual(client.session.post.call_count, 1)

    def test_chunk_splits_params(self):
        from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params

        client = MagicMock()
        client.base_url = "http://localhost:8000"
        client.session.post.return_value = MagicMock(status_code=200)
        client.communicator = MagicMock()
        client.communicator.group = MagicMock()
        client.rank = 0

        params = [
            ("a", torch.randn(100)),  # 100 elements
            ("b", torch.randn(100)),  # 100 elements
            ("c", torch.randn(100)),  # 100 elements
        ]
        _batch_update_named_params(client, params, chunk_size=150)

        # Greedy accumulation by element count: "a" fits (100 <= 150), adding
        # "b" would overflow (200 > 150) so [a] is flushed; likewise [b]; the
        # final chunk is [c]. Three chunks → three HTTP calls.
        self.assertEqual(client.session.post.call_count, 3)
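

# Reference sketch of the greedy chunking the test above pins down, assuming
# parameters accumulate until adding the next tensor would push the running
# element count past chunk_size, and chunk_size=None means a single chunk.
# Illustrative only -- NOT the implementation in trl_vllm; there each chunk
# would correspond to one HTTP call.
def _chunk_params_sketch(params, chunk_size=None):
    if chunk_size is None:
        return [list(params)]
    chunks, current, current_numel = [], [], 0
    for name, tensor in params:
        if current and current_numel + tensor.numel() > chunk_size:
            chunks.append(current)  # flush before overflowing the budget
            current, current_numel = [], 0
        current.append((name, tensor))
        current_numel += tensor.numel()
    if current:
        chunks.append(current)  # final partial chunk
    return chunks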


if __name__ == "__main__":
    unittest.main()