Compare commits
15 Commits
main
...
swe-rebenc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d17ed89a3c | ||
|
|
02e4f2350d | ||
|
|
4195605ab2 | ||
|
|
37acb28d02 | ||
|
|
4a5281e61a | ||
|
|
a892d8cce1 | ||
|
|
78de2919a6 | ||
|
|
28283ff373 | ||
|
|
dc16859983 | ||
|
|
d4e9cf2eec | ||
|
|
53391a10d7 | ||
|
|
7617b951a8 | ||
|
|
e993ed5208 | ||
|
|
69f165b39b | ||
|
|
80a97f192b |
@@ -242,6 +242,85 @@ class ProducerConfig:
|
||||
)
|
||||
|
||||
|
||||
class _GroupShardedSampler:
|
||||
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
|
||||
|
||||
``RepeatSampler`` yields ``num_generations`` consecutive copies of
|
||||
each prompt, forming a GRPO group. For distributed training each
|
||||
rank must see a disjoint slice of prompts (otherwise every rank
|
||||
dogpiles on the first 1/world_size of the batch) while keeping each
|
||||
group intact on a single rank so advantage normalization sees all
|
||||
peer generations.
|
||||
|
||||
``accelerator.prepare(DataLoader)`` does not handle this correctly
|
||||
for custom samplers with ``split_batches=False`` (the default): it
|
||||
leaves the sampler alone and every rank replays identical indices.
|
||||
This wrapper fixes that by consuming the inner sampler's full
|
||||
output, chunking it into ``num_generations``-sized groups, and
|
||||
round-robining whole groups across ranks.
|
||||
|
||||
Intended to be used ONLY when distributed training is active
|
||||
(``num_replicas > 1``); for single-rank it is a no-op but still
|
||||
correct.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner: Any,
|
||||
num_generations: int,
|
||||
rank: int,
|
||||
num_replicas: int,
|
||||
):
|
||||
if num_generations < 1:
|
||||
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
|
||||
if num_replicas < 1:
|
||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||
if not (0 <= rank < num_replicas):
|
||||
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
||||
self.inner = inner
|
||||
self.num_generations = num_generations
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
all_indices = list(self.inner)
|
||||
if len(all_indices) % self.num_generations != 0:
|
||||
raise ValueError(
|
||||
f"inner sampler yielded {len(all_indices)} indices, "
|
||||
f"not a multiple of num_generations={self.num_generations}"
|
||||
)
|
||||
# Chunk the flat index sequence into groups of num_generations
|
||||
# consecutive indices. ``RepeatSampler`` guarantees that each
|
||||
# group contains num_generations copies of the same prompt id.
|
||||
groups = [
|
||||
all_indices[i : i + self.num_generations]
|
||||
for i in range(0, len(all_indices), self.num_generations)
|
||||
]
|
||||
# Round-robin whole groups across ranks. Round-robin (vs.
|
||||
# contiguous chunking) preserves approximate shuffled order on
|
||||
# each rank even when the group count is small relative to the
|
||||
# world size.
|
||||
for group in groups[self.rank :: self.num_replicas]:
|
||||
yield from group
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
inner_len = len(self.inner)
|
||||
except TypeError:
|
||||
# Non-sized inner sampler — we can't know the per-rank
|
||||
# length without materializing. Return 0 as a hint that the
|
||||
# DataLoader should fall back to iteration.
|
||||
return 0
|
||||
total_groups = inner_len // self.num_generations
|
||||
# Ceiling division for the trailing groups that don't divide
|
||||
# evenly — extra groups go to the first ``total_groups %
|
||||
# num_replicas`` ranks, matching the round-robin above.
|
||||
my_groups = (
|
||||
total_groups + self.num_replicas - self.rank - 1
|
||||
) // self.num_replicas
|
||||
return my_groups * self.num_generations
|
||||
|
||||
|
||||
class DataProducer(ABC):
|
||||
"""Abstract base class for online data producers.
|
||||
|
||||
@@ -556,6 +635,34 @@ class GRPODataProducer(BaseDataProducer):
|
||||
seed=self._seed,
|
||||
)
|
||||
|
||||
# Shard the sampler across distributed ranks so each rank sees
|
||||
# a disjoint slice of prompts. ``RepeatSampler`` groups each
|
||||
# prompt with ``num_generations`` consecutive copies — our
|
||||
# wrapper round-robins WHOLE groups across ranks so all
|
||||
# generations of a given prompt stay on the same rank (needed
|
||||
# for GRPO advantage normalization within a group).
|
||||
#
|
||||
# Without this, ``accelerator.prepare(dl)`` with the default
|
||||
# ``split_batches=False`` leaves the custom sampler alone, so
|
||||
# every rank iterates the identical index sequence and the
|
||||
# cluster dogpiles on the first 1/world_size of the prompts.
|
||||
num_replicas = max(1, trainer.accelerator.num_processes)
|
||||
if num_replicas > 1:
|
||||
sampler = _GroupShardedSampler(
|
||||
inner=sampler,
|
||||
num_generations=self._num_generations,
|
||||
rank=trainer.accelerator.process_index,
|
||||
num_replicas=num_replicas,
|
||||
)
|
||||
logger.info(
|
||||
"[RANK:%d] _GroupShardedSampler active "
|
||||
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
|
||||
trainer.accelerator.process_index,
|
||||
num_replicas,
|
||||
self._num_generations,
|
||||
self._generation_batch_size,
|
||||
)
|
||||
|
||||
# Use identity collator (same as stock GRPOTrainer)
|
||||
def _identity(x):
|
||||
return x
|
||||
@@ -574,12 +681,11 @@ class GRPODataProducer(BaseDataProducer):
|
||||
rank=trainer.args.process_index,
|
||||
),
|
||||
)
|
||||
self._prompt_dl = trainer.accelerator.prepare(dl)
|
||||
|
||||
# Don't let accelerator track this dataloader
|
||||
acc_dls = trainer.accelerator._dataloaders
|
||||
if self._prompt_dl in acc_dls:
|
||||
acc_dls.remove(self._prompt_dl)
|
||||
# Skip accelerator.prepare — we're handling per-rank sharding
|
||||
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
|
||||
# otherwise try to wrap the DataLoader with its own sharding
|
||||
# logic which does not understand our group structure.
|
||||
self._prompt_dl = dl
|
||||
|
||||
self._prompt_iter = iter(self._prompt_dl)
|
||||
|
||||
@@ -1103,11 +1209,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
||||
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
||||
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
||||
|
||||
This is the canonical sync trigger and runs in BOTH async and
|
||||
synchronous modes from ``_prepare_inputs_with_data_producer`` /
|
||||
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
|
||||
patch is a parallel backup for non-data-producer paths (vanilla
|
||||
GRPO without NeMo Gym), where the data producer is bypassed
|
||||
entirely and TRL's stock generate-then-sync flow is used instead.
|
||||
"""
|
||||
if not (self.use_vllm and self.args.async_prefetch):
|
||||
if not self.use_vllm:
|
||||
return
|
||||
step = self.state.global_step
|
||||
interval = self.args.vllm_sync_interval
|
||||
# Default to syncing every step when no interval is configured —
|
||||
# otherwise ``step % None`` would TypeError, and the previous
|
||||
# behavior of crashing on the first sync was strictly worse than
|
||||
# the standard "sync every optimizer step".
|
||||
interval = self.args.vllm_sync_interval or 1
|
||||
if step != self._last_synced_step and step % interval == 0:
|
||||
if step == 0:
|
||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||
@@ -1202,13 +1319,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
# Permanently replace vllm_generation.sync_weights with our custom
|
||||
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
||||
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
|
||||
# handles the sync with proper interval tracking.
|
||||
#
|
||||
# The design has two modes that have to be threaded carefully:
|
||||
#
|
||||
# - Async prefetch ON: BG generation thread can't safely call
|
||||
# sync_weights mid-rollout (it races with the trainer's optimizer
|
||||
# step and can corrupt weights). We no-op the stock sync hook and
|
||||
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
|
||||
# optimizer step on the main thread.
|
||||
#
|
||||
# - Async prefetch OFF (synchronous mode): TRL's stock
|
||||
# ``_generate_single_turn`` calls ``sync_weights`` once per step
|
||||
# boundary. There's no BG thread to race with, and
|
||||
# ``_maybe_sync_vllm_weights`` short-circuits with
|
||||
# ``if not async_prefetch: return``, so we MUST wire the stock
|
||||
# hook directly to our LoRA sync helper — otherwise nothing ever
|
||||
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
|
||||
# keeps serving the base model, every rollout in every group
|
||||
# produces identical outputs, advantages are zero, optimizer
|
||||
# step gets skipped, repeat).
|
||||
if not getattr(self, "_patched_sync_weights", False):
|
||||
if self.use_vllm and hasattr(self, "vllm_generation"):
|
||||
if getattr(self.args, "vllm_lora_sync", False):
|
||||
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
if getattr(self.args, "async_prefetch", False):
|
||||
# Async: drive sync from main thread via
|
||||
# _maybe_sync_vllm_weights instead.
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
# Sync mode: TRL's _generate_single_turn already
|
||||
# calls sync_weights once per step boundary. Wire
|
||||
# it directly to our LoRA filesystem sync helper.
|
||||
sync_helper = self._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
self.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
self._patched_sync_weights = True
|
||||
else:
|
||||
from accelerate.utils import is_peft_model
|
||||
|
||||
@@ -2,17 +2,35 @@
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
from .lora_layout import (
|
||||
peft_down_proj_lora_to_scattermoe,
|
||||
peft_lora_B_to_scattermoe,
|
||||
peft_lora_to_scattermoe,
|
||||
validate_scattermoe_lora_shapes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
"peft_down_proj_lora_to_scattermoe",
|
||||
"peft_lora_B_to_scattermoe",
|
||||
"peft_lora_to_scattermoe",
|
||||
"validate_scattermoe_lora_shapes",
|
||||
]
|
||||
|
||||
try:
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
except ModuleNotFoundError as exc:
|
||||
if exc.name != "triton":
|
||||
raise
|
||||
else:
|
||||
__all__ += [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
]
|
||||
|
||||
@@ -35,81 +35,19 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .lora_layout import (
|
||||
peft_down_proj_lora_to_scattermoe,
|
||||
peft_lora_B_to_scattermoe,
|
||||
peft_lora_to_scattermoe,
|
||||
)
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||
|
||||
# =============================================================================
|
||||
# LoRA layout conversion utilities (peft <-> scattermoe)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||
expert-major ``[N, r*E]``.
|
||||
|
||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||
"""
|
||||
N = peft_B.shape[0]
|
||||
return (
|
||||
peft_B.reshape(N, rank, num_experts)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.reshape(N, num_experts * rank)
|
||||
)
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
||||
are swapped relative to scattermoe's convention.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
||||
Uses vectorized tensor operations (no Python loop over experts).
|
||||
|
||||
Works for **both** gate_up_proj and down_proj since the transposition
|
||||
issue is the same for any parameter.
|
||||
"""
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
||||
|
||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
||||
smoe_A = (
|
||||
peft_B_em.reshape(dim2, num_experts, rank)
|
||||
.permute(1, 2, 0)
|
||||
.contiguous()
|
||||
.reshape(rank * num_experts, dim2)
|
||||
)
|
||||
|
||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
||||
smoe_B = (
|
||||
peft_A.reshape(num_experts, rank, dim1)
|
||||
.permute(2, 0, 1)
|
||||
.contiguous()
|
||||
.reshape(dim1, num_experts * rank)
|
||||
)
|
||||
|
||||
return smoe_A, smoe_B
|
||||
|
||||
|
||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
__all__ = [
|
||||
"peft_down_proj_lora_to_scattermoe",
|
||||
"peft_lora_B_to_scattermoe",
|
||||
"peft_lora_to_scattermoe",
|
||||
]
|
||||
|
||||
# =============================================================================
|
||||
# ParamWrapper unwrapping
|
||||
@@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract gate_up_proj LoRA
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
if gup_wrapper is not None:
|
||||
@@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||
|
||||
# Extract down_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract down_proj LoRA
|
||||
down_lora = None
|
||||
down_wrapper = wrappers.get("down_proj")
|
||||
if down_wrapper is not None:
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Pure tensor layout helpers for ScatterMoE LoRA weights."""
|
||||
|
||||
|
||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||
expert-major ``[N, r*E]``.
|
||||
|
||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||
"""
|
||||
N = peft_B.shape[0]
|
||||
return (
|
||||
peft_B.reshape(N, rank, num_experts)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.reshape(N, num_experts * rank)
|
||||
)
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout.
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``out_features=dim1, in_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``), giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim2]``, lora_B ``[dim1, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
peft's A already matches ScatterMoE's A shape. Only B needs conversion from
|
||||
peft's rank-major layout to ScatterMoE's expert-major layout.
|
||||
"""
|
||||
smoe_A = peft_A
|
||||
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
return smoe_A, smoe_B
|
||||
|
||||
|
||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
|
||||
def validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
|
||||
"""Validate LoRA tensor layout before dispatching ScatterMoE kernels."""
|
||||
E, K, N = expert_weights.shape
|
||||
if lora_A.dim() != 2 or lora_B.dim() != 2:
|
||||
raise ValueError(
|
||||
"ScatterMoE LoRA expects 2D lora_A and lora_B tensors, got "
|
||||
f"lora_A={tuple(lora_A.shape)} and lora_B={tuple(lora_B.shape)}."
|
||||
)
|
||||
|
||||
if lora_A.size(0) % E != 0:
|
||||
raise ValueError(
|
||||
"ScatterMoE LoRA expects lora_A rows to be divisible by the number "
|
||||
f"of experts ({E}), got lora_A={tuple(lora_A.shape)}."
|
||||
)
|
||||
|
||||
rank = lora_A.size(0) // E
|
||||
expected_A = (E * rank, K)
|
||||
expected_B = (N, E * rank)
|
||||
if tuple(lora_A.shape) != expected_A or tuple(lora_B.shape) != expected_B:
|
||||
raise ValueError(
|
||||
"Invalid ScatterMoE LoRA layout for expert_weights "
|
||||
f"{tuple(expert_weights.shape)}. Expected lora_A={expected_A} and "
|
||||
f"lora_B={expected_B}, got lora_A={tuple(lora_A.shape)} and "
|
||||
f"lora_B={tuple(lora_B.shape)}. For PEFT target_parameters, keep "
|
||||
"lora_A as [E*r, K] and only convert lora_B from rank-major to "
|
||||
"expert-major layout."
|
||||
)
|
||||
@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
|
||||
scatter2scatter_lora,
|
||||
scatter2scatter_lora_dX,
|
||||
)
|
||||
from .lora_layout import validate_scattermoe_lora_shapes
|
||||
|
||||
|
||||
class ScatterMoELoRA(torch.autograd.Function):
|
||||
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
|
||||
return lora_A, lora_B, scaling
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Drop-in replacement for parallel_linear
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parallel_linear_lora(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
@@ -451,6 +447,7 @@ def parallel_linear_lora(
|
||||
Otherwise falls back to standard scatter2scatter.
|
||||
"""
|
||||
if lora_A is not None and lora_B is not None:
|
||||
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
|
||||
return ScatterMoELoRA.apply(
|
||||
inputs,
|
||||
expert_weights,
|
||||
|
||||
@@ -110,11 +110,36 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
item["agent_ref"] = full_item["agent_ref"]
|
||||
dataset_items.append(item)
|
||||
|
||||
# Expand by num_generations (agent produces one rollout per call)
|
||||
expanded_items = []
|
||||
for item in dataset_items:
|
||||
for _ in range(self._num_generations):
|
||||
expanded_items.append(item)
|
||||
# NOTE: do NOT re-expand by num_generations here.
|
||||
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
|
||||
# yields ``num_generations`` consecutive copies of each unique
|
||||
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
|
||||
# num_generations)`` items — one entry per rollout. Expanding
|
||||
# again here would fire ``num_generations^2`` rollouts per
|
||||
# prompt per rank and make every step dogpile on a handful of
|
||||
# tasks.
|
||||
expanded_items = dataset_items
|
||||
|
||||
# Diagnostic: log what this rank is about to fire.
|
||||
try:
|
||||
import collections
|
||||
|
||||
iid_counts: collections.Counter[str | None] = collections.Counter()
|
||||
for it in dataset_items:
|
||||
iid_counts[
|
||||
(it.get("responses_create_params", {}).get("metadata") or {}).get(
|
||||
"instance_id"
|
||||
)
|
||||
] += 1
|
||||
LOG.info(
|
||||
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
|
||||
trainer.accelerator.process_index,
|
||||
len(dataset_items),
|
||||
len(iid_counts),
|
||||
list(iid_counts.most_common(5)),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Call NeMo Gym agents
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -140,6 +165,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
logprobs_list = []
|
||||
rewards_list = []
|
||||
|
||||
num_turns_list: list[int] = []
|
||||
for resp in responses:
|
||||
parsed = _parse_agent_response(resp, eos_token_id)
|
||||
prompt_ids_list.append(parsed["prompt_ids"])
|
||||
@@ -147,6 +173,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
env_mask_list.append(parsed["env_mask"])
|
||||
logprobs_list.append(parsed["logprobs"])
|
||||
rewards_list.append(parsed["reward"])
|
||||
num_turns_list.append(parsed.get("num_turns", 0))
|
||||
|
||||
# Pad to tensors
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||
@@ -179,22 +206,48 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
||||
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
||||
|
||||
# Inject rewards into inputs so _compute_deferred_scores can use them
|
||||
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
|
||||
# Our passthrough reward_fn reads "env_reward" from kwargs.
|
||||
# Inject per-rollout reward + num_turns into each input. Since
|
||||
# ``RepeatSampler`` already yields ``num_generations`` copies of
|
||||
# each prompt, ``inputs`` has ONE entry per rollout (matching
|
||||
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
|
||||
# GRPO advantage normalization is the trainer's job downstream.
|
||||
assert len(inputs) == len(rewards_list), (
|
||||
f"rewards/inputs length mismatch: "
|
||||
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
|
||||
)
|
||||
for i, inp in enumerate(inputs):
|
||||
# Each input gets rewards for its num_generations rollouts
|
||||
start = i * self._num_generations
|
||||
end = start + self._num_generations
|
||||
inp["env_reward"] = rewards_list[start:end]
|
||||
inp["env_reward"] = rewards_list[i]
|
||||
inp["num_turns"] = num_turns_list[i]
|
||||
|
||||
# Expand inputs to match expanded rollouts (num_generations copies)
|
||||
expanded_inputs = []
|
||||
for inp in inputs:
|
||||
for g in range(self._num_generations):
|
||||
expanded_inp = dict(inp)
|
||||
expanded_inp["env_reward"] = inp["env_reward"][g]
|
||||
expanded_inputs.append(expanded_inp)
|
||||
# One expanded_input per rollout (already correct count because
|
||||
# inputs has num_generations copies baked in by the sampler).
|
||||
expanded_inputs = [dict(inp) for inp in inputs]
|
||||
|
||||
# Log rollout-level stats to wandb from rank 0. These are the
|
||||
# true agent-side metrics (not the tokenized TRL view) — so
|
||||
# num_turns reflects how many /run iterations each rollout
|
||||
# actually took before finishing or hitting max_turns.
|
||||
if is_main and num_turns_list:
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
import statistics as _stats
|
||||
|
||||
nonzero = sum(1 for r in rewards_list if r > 0)
|
||||
log_payload = {
|
||||
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
|
||||
"rollout/num_turns/min": float(min(num_turns_list)),
|
||||
"rollout/num_turns/max": float(max(num_turns_list)),
|
||||
"rollout/reward/mean": float(_stats.mean(rewards_list)),
|
||||
"rollout/reward/nonzero_frac": (
|
||||
nonzero / len(rewards_list) if rewards_list else 0.0
|
||||
),
|
||||
"rollout/n_samples": float(len(rewards_list)),
|
||||
}
|
||||
wandb.log(log_payload, commit=False)
|
||||
except Exception as exc: # never let metric logging break training
|
||||
LOG.warning("rollout wandb log failed: %s", exc)
|
||||
|
||||
# Decode completions for reward functions
|
||||
completions = trainer.processing_class.batch_decode(
|
||||
|
||||
@@ -19,6 +19,7 @@ Supports two modes:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
@@ -30,6 +31,107 @@ if TYPE_CHECKING:
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
# ---- vLLM weight-sync transport probe ------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLLMWeightSyncCapabilities:
|
||||
"""What weight-sync routes a vLLM server actually exposes.
|
||||
|
||||
Discovered once at ``pre_model_load`` time by fetching the server's
|
||||
``/openapi.json``. Drives the transport-selection table below.
|
||||
"""
|
||||
|
||||
nccl: bool = False # /init_communicator/ + /update_named_param/
|
||||
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
|
||||
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
|
||||
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
|
||||
probed: bool = False
|
||||
probe_error: str | None = None
|
||||
routes: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def any_full_param_sync(self) -> bool:
|
||||
"""True if at least one transport can push full-model weights."""
|
||||
return self.nccl or self.http_full
|
||||
|
||||
@property
|
||||
def any_lora_sync(self) -> bool:
|
||||
"""True if at least one transport can push LoRA adapters."""
|
||||
return self.lora_filesystem or self.lora_axolotl or self.nccl
|
||||
|
||||
|
||||
def probe_vllm_weight_sync(
|
||||
base_url: str, timeout: float = 5.0
|
||||
) -> VLLMWeightSyncCapabilities:
|
||||
"""Detect which weight-sync routes the configured vLLM server exposes.
|
||||
|
||||
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
|
||||
we care about is mounted as a POST route there. Falls back to all-False
|
||||
on any error so the caller can still decide what to do (typically: raise
|
||||
a clear error rather than silently no-op).
|
||||
"""
|
||||
import requests
|
||||
|
||||
caps = VLLMWeightSyncCapabilities()
|
||||
try:
|
||||
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
|
||||
r.raise_for_status()
|
||||
spec = r.json()
|
||||
routes = sorted((spec.get("paths") or {}).keys())
|
||||
caps.routes = routes
|
||||
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
|
||||
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
|
||||
caps.lora_axolotl = "/set_lora_adapter/" in routes
|
||||
caps.http_full = "/http_update_weights/" in routes
|
||||
caps.probed = True
|
||||
except Exception as exc:
|
||||
caps.probe_error = f"{type(exc).__name__}: {exc}"
|
||||
LOG.warning(
|
||||
"NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. "
|
||||
"Will fall back to LoRA-only behavior.",
|
||||
base_url,
|
||||
caps.probe_error,
|
||||
)
|
||||
return caps
|
||||
|
||||
|
||||
def select_weight_sync_transport(
|
||||
caps: VLLMWeightSyncCapabilities,
|
||||
*,
|
||||
has_lora: bool,
|
||||
vllm_lora_sync_pref: bool,
|
||||
) -> str:
|
||||
"""Pick the right transport for a (server caps, model type) combo.
|
||||
|
||||
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
|
||||
``"none"``. The caller decides what to do with ``"none"`` (typically:
|
||||
raise an error explaining the misconfiguration).
|
||||
|
||||
Selection table:
|
||||
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
|
||||
LoRA model + lora endpoint → lora_filesystem
|
||||
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
|
||||
Full model + nccl endpoint → nccl
|
||||
Full model + http endpoint → http_full
|
||||
anything else → none
|
||||
"""
|
||||
if has_lora:
|
||||
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
|
||||
return "lora_filesystem"
|
||||
if caps.lora_filesystem or caps.lora_axolotl:
|
||||
return "lora_filesystem"
|
||||
if caps.nccl:
|
||||
return "nccl"
|
||||
return "none"
|
||||
# Full-parameter model
|
||||
if caps.nccl:
|
||||
return "nccl"
|
||||
if caps.http_full:
|
||||
return "http_full"
|
||||
return "none"
|
||||
|
||||
|
||||
class NemoGymPlugin(BasePlugin):
|
||||
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
||||
|
||||
@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
|
||||
self._reward_fn = None
|
||||
self._dataset_lookup = None
|
||||
self._agent_servers = {}
|
||||
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply monkeypatches before trainer creation."""
|
||||
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
|
||||
|
||||
Replaces the previous unconditional ``init_communicator`` monkey-patch
|
||||
with a probe of the configured vLLM server's ``/openapi.json``. We only
|
||||
bypass NCCL init when the server we're talking to actually lacks the
|
||||
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
|
||||
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
|
||||
standard TRL flow alone so full-finetune training can sync weights.
|
||||
"""
|
||||
if not cfg.nemo_gym_enabled:
|
||||
return
|
||||
|
||||
# Always skip NCCL communicator init in NeMo Gym mode.
|
||||
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
|
||||
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
|
||||
trl_cfg = getattr(cfg, "trl", None)
|
||||
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
|
||||
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
|
||||
return
|
||||
|
||||
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
|
||||
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
|
||||
base_url = f"http://{host}:{port}"
|
||||
self._vllm_caps = probe_vllm_weight_sync(base_url)
|
||||
|
||||
if self._vllm_caps.probed:
|
||||
LOG.info(
|
||||
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
|
||||
"lora_axolotl=%s http_full=%s",
|
||||
base_url,
|
||||
self._vllm_caps.nccl,
|
||||
self._vllm_caps.lora_filesystem,
|
||||
self._vllm_caps.lora_axolotl,
|
||||
self._vllm_caps.http_full,
|
||||
)
|
||||
|
||||
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
|
||||
# available we leave VLLMClient.init_communicator alone so the
|
||||
# standard TRL sync flow can run for full-parameter training.
|
||||
if not self._vllm_caps.nccl:
|
||||
self._patch_skip_nccl_init()
|
||||
|
||||
def _patch_skip_nccl_init(self):
|
||||
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
||||
|
||||
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
|
||||
serve script). The NCCL communicator is not needed and fails with both
|
||||
vLLM V1 engine and standard OpenAI server mode.
|
||||
Only called when the configured vLLM server doesn't expose
|
||||
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
|
||||
TRL's standard ``init_communicator`` would 404 inside trainer
|
||||
construction; we no-op it so the LoRA filesystem path can install
|
||||
its own sync in ``post_trainer_create``.
|
||||
"""
|
||||
try:
|
||||
from trl.generation.vllm_client import VLLMClient
|
||||
|
||||
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
||||
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
||||
"Skipping NCCL init_communicator (LoRA sync mode)"
|
||||
"Skipping NCCL init_communicator (server has no /init_communicator/)"
|
||||
)
|
||||
LOG.info(
|
||||
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
|
||||
)
|
||||
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
|
||||
except Exception as exc:
|
||||
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
||||
|
||||
@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
|
||||
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
||||
multi_turn = cfg.nemo_gym_multi_turn or False
|
||||
|
||||
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
|
||||
# - Install LoRA sync (when vllm_lora_sync=True)
|
||||
# - Or no-op sync_weights (when using standard vLLM server)
|
||||
# Pick a weight-sync transport based on what the configured vLLM
|
||||
# server actually exposes (see ``pre_model_load`` probe) and what
|
||||
# kind of model we're training. The selection table is documented
|
||||
# in ``select_weight_sync_transport``.
|
||||
trl_cfg = getattr(cfg, "trl", None)
|
||||
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
||||
vllm_gen = trainer.vllm_generation
|
||||
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
|
||||
adapter = getattr(cfg, "adapter", None)
|
||||
has_lora = adapter in ("lora", "qlora")
|
||||
vllm_lora_sync_pref = bool(
|
||||
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
|
||||
)
|
||||
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
|
||||
transport = select_weight_sync_transport(
|
||||
caps,
|
||||
has_lora=has_lora,
|
||||
vllm_lora_sync_pref=vllm_lora_sync_pref,
|
||||
)
|
||||
|
||||
if transport == "lora_filesystem":
|
||||
self._setup_lora_sync(trainer)
|
||||
# Verify the vLLM server supports runtime LoRA loading
|
||||
self._check_lora_endpoint(vllm_gen)
|
||||
else:
|
||||
# No NCCL, no LoRA sync — skip all weight sync paths
|
||||
vllm_gen.sync_weights = lambda: LOG.debug(
|
||||
"Weight sync skipped (NeMo Gym mode)"
|
||||
LOG.info("NeMo Gym weight sync: LoRA filesystem")
|
||||
elif transport == "nccl":
|
||||
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
|
||||
# alone (pre_model_load only patched it when the probe found no
|
||||
# NCCL route) so the trainer's normal weight-sync flow runs.
|
||||
LOG.info(
|
||||
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
|
||||
)
|
||||
type(vllm_gen).sync_weights = lambda self: LOG.debug(
|
||||
"Weight sync skipped (NeMo Gym mode)"
|
||||
elif transport == "http_full":
|
||||
# Full-parameter HTTP sync — implementation lands in step 3.
|
||||
# For now, fail loudly so users know the path is detected but
|
||||
# not yet wired up, instead of silently no-oping like before.
|
||||
raise NotImplementedError(
|
||||
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
|
||||
"but the client-side sync helper is not yet implemented "
|
||||
"(planned). Use `adapter: lora|qlora` for now, or use a "
|
||||
"vLLM serve module that exposes /init_communicator/ for "
|
||||
"NCCL sync."
|
||||
)
|
||||
# Also patch the async trainer's internal sync method
|
||||
if hasattr(trainer, "_maybe_sync_vllm_weights"):
|
||||
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
|
||||
"Async weight sync skipped (NeMo Gym mode)"
|
||||
else: # transport == "none"
|
||||
# No viable sync path. Build a precise error so the user knows
|
||||
# exactly what's missing and how to fix it.
|
||||
if not caps.probed:
|
||||
msg = (
|
||||
"could not probe the vLLM server's "
|
||||
f"/openapi.json: {caps.probe_error}. "
|
||||
"Verify that vLLM is reachable at "
|
||||
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
|
||||
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
|
||||
)
|
||||
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
|
||||
elif has_lora:
|
||||
msg = (
|
||||
"the vLLM server has neither NCCL routes "
|
||||
"(/init_communicator/) nor a LoRA-loading route "
|
||||
"(/v1/load_lora_adapter or /set_lora_adapter/). "
|
||||
"Restart vLLM with `--enable-lora --max-lora-rank N "
|
||||
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
|
||||
"server, or use `axolotl vllm-serve` for the "
|
||||
"NCCL-capable serve module."
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
"the vLLM server exposes no full-parameter sync route "
|
||||
"(/init_communicator/ for NCCL or /http_update_weights/ "
|
||||
"for HTTP). Use `axolotl vllm-serve` (which has both) "
|
||||
"or set `adapter: lora|qlora`."
|
||||
)
|
||||
raise ValueError(
|
||||
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
|
||||
"weight sync the trainer's gradient updates never reach the "
|
||||
"rollout policy (functionally a no-op trainer)."
|
||||
)
|
||||
|
||||
if multi_turn:
|
||||
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
||||
|
||||
@@ -130,21 +130,41 @@ def start_servers(
|
||||
)
|
||||
|
||||
|
||||
def get_server_configs(head_port: int = 11000) -> dict:
|
||||
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
|
||||
"""Fetch the global config from the NeMo Gym head server.
|
||||
|
||||
Retries up to 3 times with exponential backoff. The default per-attempt
|
||||
timeout is 30s (raised from the original 5s) because head servers can
|
||||
be slow to respond when they're concurrently serving rollouts from a
|
||||
prior training run. A 5s timeout was empirically too tight to survive
|
||||
a kill-and-relaunch cycle.
|
||||
|
||||
Returns:
|
||||
Dict mapping server_name -> server config.
|
||||
"""
|
||||
response = requests.get(
|
||||
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
|
||||
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
|
||||
last_exc: Exception | None = None
|
||||
for attempt in (1, 2, 3):
|
||||
try:
|
||||
response = requests.get(url, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
result = yaml.safe_load(response.text)
|
||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||
if isinstance(result, str):
|
||||
result = yaml.safe_load(result)
|
||||
return result
|
||||
except (requests.exceptions.RequestException, OSError) as exc:
|
||||
last_exc = exc
|
||||
LOG.warning(
|
||||
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
|
||||
attempt,
|
||||
type(exc).__name__,
|
||||
)
|
||||
if attempt < 3:
|
||||
time.sleep(2.0 * attempt)
|
||||
raise RuntimeError(
|
||||
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = yaml.safe_load(response.text)
|
||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||
if isinstance(result, str):
|
||||
result = yaml.safe_load(result)
|
||||
return result
|
||||
|
||||
|
||||
def get_agent_servers(
|
||||
|
||||
@@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel(
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
@@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel(
|
||||
):
|
||||
"""
|
||||
Fused forward:
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm)
|
||||
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
|
||||
y[..., :n_rot] = rope(x_norm[..., :n_rot])
|
||||
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
|
||||
|
||||
rotate_half swaps first/second halves and negates the first:
|
||||
rotate_half([a, b]) = [-b, a]
|
||||
rotate_half swaps first/second halves and negates the first, restricted
|
||||
to the rotary span [0, n_rot):
|
||||
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
|
||||
|
||||
For the partial-rotary pass-through region we load cos with default 1.0
|
||||
and sin with default 0.0 outside [0, n_rot), so the same formula
|
||||
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
|
||||
|
||||
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
|
||||
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
|
||||
cs_row_idx = row_idx // n_heads
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
# Load input row
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
X_dtype = X_row.dtype
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
|
||||
# RMSNorm: compute 1/rms
|
||||
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
@@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel(
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_norm = X_norm * W_row
|
||||
|
||||
# RoPE: load cos/sin (broadcast across heads)
|
||||
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
|
||||
# cos=1, sin=0 so the formula leaves X_norm untouched.
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin_row = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
|
||||
# for col >= half_dim, take X_norm[col - half_dim]
|
||||
# rotate_half within [0, n_rot):
|
||||
# for col < half_rot: take -X_norm[col + half_rot]
|
||||
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
|
||||
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
X_rot = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
|
||||
).to(tl.float32)
|
||||
# Re-normalize the rotated values
|
||||
X_rot_norm = X_rot * rstd
|
||||
if HAS_WEIGHT:
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
|
||||
tl.float32
|
||||
)
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
|
||||
X_rot_norm = X_rot_norm * W_rot
|
||||
|
||||
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
|
||||
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
|
||||
X_rot_norm = X_rot_norm * sign
|
||||
|
||||
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
@@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel(
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = RoPE(RMSNorm(X, W))
|
||||
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
|
||||
(`n_rot <= n_cols`).
|
||||
|
||||
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
|
||||
output is just the normalized row, so dN[col] = dY[col] (achieved by
|
||||
loading cos with default 1.0 and forcing the rotate-half contribution
|
||||
to zero outside the rotary span).
|
||||
|
||||
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
@@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel(
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
@@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel(
|
||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# dN = dY * cos + rotate_half^T(dY * sin)
|
||||
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
|
||||
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||
#
|
||||
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
|
||||
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
|
||||
# This is equivalent to rotating (dY * sin) because the rotation
|
||||
# just permutes which elements are multiplied.
|
||||
# For col >= n_rot the formula must collapse to dN = dY (since the
|
||||
# forward is just a pass-through). cos defaults to 1.0 above; the
|
||||
# rotate-half contribution is masked to zero below.
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
dY_rot = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
sin_rot = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
|
||||
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
|
||||
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
|
||||
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
|
||||
rotate_term = dY_rot * sin_rot * adj_sign
|
||||
# Zero out rotate-half contribution outside the rotary span.
|
||||
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
|
||||
dN = dY_row * cos_row + rotate_term
|
||||
|
||||
# Pre-weight normalized: n = rstd * x
|
||||
n = X_row * rstd
|
||||
@@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel(
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
Args:
|
||||
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||
W: (head_dim,) or None — RMSNorm weight
|
||||
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
eps: float
|
||||
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
|
||||
partial rotary). Must be even and ``<= head_dim``.
|
||||
Returns:
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||
"""
|
||||
@@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
|
||||
def rms_norm_rope_backward(
|
||||
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
|
||||
):
|
||||
n_rows, n_cols = dY.shape
|
||||
has_weight = W is not None
|
||||
|
||||
@@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads):
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, head_dim) — broadcast across heads
|
||||
sin: (B*S, head_dim) — broadcast across heads
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, n_rot) — broadcast across heads
|
||||
sin: (B*S, n_rot) — broadcast across heads
|
||||
n_heads: int
|
||||
n_rot: int — rotary dim (<= head_dim)
|
||||
"""
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||
X,
|
||||
@@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
eps,
|
||||
n_heads,
|
||||
n_rot,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.n_heads = n_heads
|
||||
ctx.n_rot = n_rot
|
||||
ctx.has_weight = W is not None
|
||||
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||
return Y
|
||||
@@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
RSTD,
|
||||
ctx.n_heads,
|
||||
ctx.n_rot,
|
||||
ctx.BLOCK_SIZE,
|
||||
ctx.num_warps,
|
||||
)
|
||||
return dX, dW, None, None, None, None
|
||||
return dX, dW, None, None, None, None, None
|
||||
|
||||
|
||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
Apply fused RMSNorm + RoPE.
|
||||
Apply fused RMSNorm + (partial) RoPE.
|
||||
|
||||
Args:
|
||||
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
|
||||
must be even and ``<= head_dim``. When ``n_rot < head_dim``
|
||||
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
|
||||
(partial-rotary pass-through), matching stock Gemma 4 with
|
||||
``partial_rotary_factor < 1.0``.
|
||||
sin: (batch, seq_len, n_rot) — same shape as ``cos``
|
||||
eps: float — RMSNorm epsilon
|
||||
|
||||
Returns:
|
||||
@@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
shape = x.shape # (B, S, H, D)
|
||||
B, S, H, D = shape
|
||||
n_rot = cos.shape[-1]
|
||||
if sin.shape[-1] != n_rot:
|
||||
raise ValueError(
|
||||
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
|
||||
f"sin={sin.shape[-1]}"
|
||||
)
|
||||
if n_rot > D:
|
||||
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
|
||||
if n_rot % 2 != 0:
|
||||
raise ValueError(f"rotary dim must be even, got {n_rot}")
|
||||
|
||||
# Flatten to 2D: (B*S*H, D)
|
||||
x_flat = x.reshape(-1, D).contiguous()
|
||||
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
|
||||
# by dividing the row_idx by H to get the cos/sin row
|
||||
cos_flat = cos.reshape(B * S, D).contiguous()
|
||||
sin_flat = sin.reshape(B * S, D).contiguous()
|
||||
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
|
||||
# all sequences share the same rotary positions). The kernel needs a
|
||||
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
|
||||
# onto a single (b, s) pair, so expand-then-contiguous to materialize
|
||||
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
|
||||
if cos.shape[0] != B:
|
||||
if cos.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
|
||||
f"to x batch dim ({B})"
|
||||
)
|
||||
cos = cos.expand(B, S, n_rot)
|
||||
sin = sin.expand(B, S, n_rot)
|
||||
cos_flat = cos.reshape(B * S, n_rot).contiguous()
|
||||
sin_flat = sin.reshape(B * S, n_rot).contiguous()
|
||||
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(
|
||||
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
|
||||
)
|
||||
return y_flat.view(shape)
|
||||
|
||||
|
||||
|
||||
@@ -156,6 +156,14 @@ class PatchManager:
|
||||
# which would clobber any earlier fix.
|
||||
self._fix_nemotron_h_conversion_mapping()
|
||||
|
||||
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
|
||||
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
|
||||
# override needs to walk the raw model tree, which is broken by
|
||||
# the post-load PEFT wrapping. The accompanying
|
||||
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
|
||||
# installation-time-independent, so both halves of the fix live
|
||||
# cleanly in the same call even though one is instance-scoped
|
||||
# and the other is module-scoped.
|
||||
self._apply_gemma_hybrid_attention(model)
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
@@ -173,12 +181,23 @@ class PatchManager:
|
||||
which exceeds flash attention's supported size. This patch loads the model
|
||||
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||
|
||||
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
|
||||
which fixes the corresponding mask construction inside
|
||||
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
|
||||
override is not enough — the forward still builds a 2D FA2-format mask
|
||||
at the model level and the SDPA layers crash at long context lengths
|
||||
with ``RuntimeError: The expanded size of the tensor ... must match``.
|
||||
"""
|
||||
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||
return
|
||||
|
||||
import copy
|
||||
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
# Navigate to the module that has 'layers' - varies by model structure:
|
||||
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||
@@ -392,6 +411,14 @@ class PatchManager:
|
||||
patch_qwen3_5_vlm_flash_attention()
|
||||
|
||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
# The fused attn path is now compatible with
|
||||
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
|
||||
# rotary (cos.shape[-1] < head_dim) and the fused forward
|
||||
# mirrors the current ``Gemma4TextAttention.forward`` API
|
||||
# for shared kv (read from / write to
|
||||
# ``past_key_values.shared_layers``). See
|
||||
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
|
||||
# for the history.
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Hybrid attention mask fix for Gemma 4.
|
||||
|
||||
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
|
||||
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
|
||||
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
|
||||
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
|
||||
``self_attn.config``, leaving sliding-window layers on FA2.
|
||||
|
||||
The per-layer config override alone is insufficient, however:
|
||||
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
|
||||
using the **model-level** config and passes the mapped mask to each
|
||||
decoder layer. With FA2 still set at the model level, the ``full_attention``
|
||||
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
|
||||
The global layers then fail with::
|
||||
|
||||
RuntimeError: The expanded size of the tensor (S) must match the existing
|
||||
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
|
||||
sizes: [B, S]
|
||||
|
||||
...when the sequence length grows past roughly 7k tokens.
|
||||
|
||||
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
|
||||
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
|
||||
the original in ``masking_utils``. The wrapper forces
|
||||
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
|
||||
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
|
||||
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
|
||||
alone, so sliding-window layers continue to receive FA2-format masks.
|
||||
|
||||
The patch is idempotent. Install once per process, before any Gemma 4
|
||||
forward pass runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
_PATCH_APPLIED = False
|
||||
|
||||
|
||||
def patch_gemma4_hybrid_mask() -> bool:
|
||||
"""Install the Gemma 4 hybrid-attention mask fix.
|
||||
|
||||
Returns ``True`` if the patch was installed (or was already installed),
|
||||
``False`` if the target module could not be imported (e.g. transformers
|
||||
version predates Gemma 4) — in which case nothing is done and the
|
||||
caller can continue unaffected.
|
||||
"""
|
||||
global _PATCH_APPLIED
|
||||
if _PATCH_APPLIED:
|
||||
return True
|
||||
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
LOG.debug(
|
||||
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
|
||||
"skipping. This is fine for non-Gemma4 training."
|
||||
)
|
||||
return False
|
||||
|
||||
if not hasattr(modeling_gemma4, "create_causal_mask"):
|
||||
LOG.warning(
|
||||
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
|
||||
"binding, skipping. Transformers API may have changed."
|
||||
)
|
||||
return False
|
||||
|
||||
original = modeling_gemma4.create_causal_mask
|
||||
|
||||
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
|
||||
"""Wrapper that forces SDPA format for the full-attention mask.
|
||||
|
||||
The global layers were patched to SDPA by
|
||||
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
|
||||
original ``create_causal_mask`` dispatches on
|
||||
``config._attn_implementation``; we shadow that with a local
|
||||
override.
|
||||
"""
|
||||
sdpa_config = copy.copy(config)
|
||||
sdpa_config._attn_implementation = "sdpa"
|
||||
return original(sdpa_config, *args, **kwargs)
|
||||
|
||||
# Preserve the original reference on the wrapper for tests / teardown.
|
||||
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
|
||||
|
||||
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
|
||||
_PATCH_APPLIED = True
|
||||
LOG.info(
|
||||
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
|
||||
"force SDPA-format masks for full-attention layers"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def unpatch_gemma4_hybrid_mask() -> None:
|
||||
"""Restore the original ``create_causal_mask``. Useful for tests."""
|
||||
global _PATCH_APPLIED
|
||||
if not _PATCH_APPLIED:
|
||||
return
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
_PATCH_APPLIED = False
|
||||
return
|
||||
current = modeling_gemma4.create_causal_mask
|
||||
original = getattr(current, "_axolotl_original", None)
|
||||
if original is not None:
|
||||
modeling_gemma4.create_causal_mask = original
|
||||
_PATCH_APPLIED = False
|
||||
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
||||
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
||||
# language-side module is separated from the vision tower. Try
|
||||
# both names before giving up.
|
||||
mlp_cls = getattr(
|
||||
module,
|
||||
f"{model_cls_prefix}MLP",
|
||||
None,
|
||||
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
||||
|
||||
if use_original_mlp:
|
||||
mlp_forward = mlp_cls.forward
|
||||
|
||||
@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
|
||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||
active_lora: dict = {"request": None}
|
||||
|
||||
# Serializes access to the worker pipe. The underlying
|
||||
# multiprocessing.Connection is a single full-duplex stream shared
|
||||
# across all HTTP handlers; concurrent requests interleave bytes on
|
||||
# the wire and corrupt the pickle framing (seen as
|
||||
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
|
||||
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
|
||||
# the round-trip so only one inflight call at a time per pipe.
|
||||
worker_pipe_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LoRA-specific endpoints
|
||||
# ------------------------------------------------------------------
|
||||
@@ -631,6 +640,150 @@ def main(script_args: ScriptArguments):
|
||||
},
|
||||
}
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_completions(request_body: dict):
|
||||
"""OpenAI-compatible text-completions endpoint.
|
||||
|
||||
Accepts either a string ``prompt`` or a list-of-int
|
||||
``prompt_token_ids`` (as the text-completions spec allows). Routes
|
||||
to the internal vLLM generate method with the active LoRA adapter
|
||||
and returns an OpenAI /v1/completions-shaped response including
|
||||
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
|
||||
``generation_log_probs`` for NeMo Gym agents that need raw
|
||||
tokens + logprobs.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
prompt_raw = request_body.get("prompt")
|
||||
temperature = request_body.get("temperature", 1.0)
|
||||
max_tokens = request_body.get("max_tokens", 512)
|
||||
top_p = request_body.get("top_p", 1.0)
|
||||
n = request_body.get("n", 1)
|
||||
logprobs = request_body.get("logprobs") or 0
|
||||
stop_token_ids = request_body.get("stop_token_ids") or None
|
||||
|
||||
# Accept either a string or a list[int] token id prompt. Lists
|
||||
# must contain ints only (raise on lists of strings so callers get
|
||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||
# rare case callers pass a single-prompt batch.
|
||||
if (
|
||||
isinstance(prompt_raw, list)
|
||||
and prompt_raw
|
||||
and isinstance(prompt_raw[0], list)
|
||||
):
|
||||
prompt_raw = prompt_raw[0]
|
||||
|
||||
prompt_dict: dict[str, Any] = {}
|
||||
if isinstance(prompt_raw, list):
|
||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||
elif isinstance(prompt_raw, str):
|
||||
prompt_dict = {"prompt": prompt_raw}
|
||||
else:
|
||||
return {
|
||||
"error": {
|
||||
"message": ("prompt must be a string or a list of token ids"),
|
||||
"type": "invalid_request",
|
||||
}
|
||||
}
|
||||
|
||||
generation_kwargs: dict[str, Any] = {
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
if stop_token_ids:
|
||||
generation_kwargs["stop_token_ids"] = stop_token_ids
|
||||
sampling_params = SamplingParams(
|
||||
**{k: v for k, v in generation_kwargs.items() if v is not None}
|
||||
)
|
||||
|
||||
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
|
||||
|
||||
# Hold the pipe lock across send+recv — concurrent requests would
|
||||
# otherwise interleave pickle frames on the worker connection.
|
||||
async with worker_pipe_lock:
|
||||
for conn, chunk in zip(connections, chunked, strict=True):
|
||||
if not chunk:
|
||||
chunk = [{"prompt": "<placeholder>"}]
|
||||
kwargs = {
|
||||
"prompts": chunk,
|
||||
"sampling_params": sampling_params,
|
||||
"lora_request": active_lora["request"],
|
||||
}
|
||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
all_outputs = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||
)
|
||||
|
||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||
for o in all_outputs:
|
||||
if isinstance(o, dict) and "error" in o:
|
||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
if not all_outputs:
|
||||
return {"choices": [], "model": script_args.model}
|
||||
|
||||
choices = []
|
||||
for i, output in enumerate(all_outputs):
|
||||
for j, out in enumerate(output.outputs):
|
||||
text = out.text
|
||||
# OpenAI-style `logprobs` block for text-completions:
|
||||
# { "tokens": [...], "token_logprobs": [...] }
|
||||
lp_block = None
|
||||
if out.logprobs:
|
||||
tokens_str: list[str] = []
|
||||
token_lps: list[float] = []
|
||||
for step in out.logprobs:
|
||||
chosen = next(iter(step.values()))
|
||||
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
|
||||
token_lps.append(float(chosen.logprob))
|
||||
lp_block = {
|
||||
"tokens": tokens_str,
|
||||
"token_logprobs": token_lps,
|
||||
}
|
||||
|
||||
choice = {
|
||||
"index": i * n + j,
|
||||
"text": text,
|
||||
"finish_reason": "stop"
|
||||
if out.finish_reason == "stop"
|
||||
else "length",
|
||||
"logprobs": lp_block,
|
||||
# NeMo-Gym / retrace agent extras — preserved on the
|
||||
# choice so callers with raw-token pipelines don't
|
||||
# have to re-tokenize.
|
||||
"prompt_token_ids": output.prompt_token_ids,
|
||||
"generation_token_ids": list(out.token_ids),
|
||||
"generation_log_probs": (
|
||||
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
|
||||
if out.logprobs
|
||||
else []
|
||||
),
|
||||
}
|
||||
choices.append(choice)
|
||||
|
||||
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
|
||||
completion_tokens = sum(
|
||||
len(out.token_ids) for o in all_outputs for out in o.outputs
|
||||
)
|
||||
|
||||
return {
|
||||
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "text_completion",
|
||||
"model": script_args.model,
|
||||
"choices": choices,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||
|
||||
@app.post("/init_communicator/")
|
||||
|
||||
@@ -770,6 +770,88 @@ class RLValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_batch_size_divisibility(cls, data):
|
||||
"""Surface GRPO batch-shape mismatches at config-parse time.
|
||||
|
||||
TRL's GRPOTrainer requires that the per-step generation batch size be
|
||||
evenly divisible by ``num_generations`` so that every prompt can be
|
||||
replicated exactly ``num_generations`` times. The runtime check inside
|
||||
``GRPOTrainer.__init__`` only fires after the model has been loaded —
|
||||
too late and too cryptic for the user. We replicate the check here so
|
||||
the failure is immediate and actionable.
|
||||
|
||||
Also enforces:
|
||||
- ``num_generations >= 2`` (group-relative advantage needs variance)
|
||||
- ``effective_gbs >= num_generations * world_size`` when capabilities
|
||||
indicate multiple ranks (each rank needs at least one full group)
|
||||
"""
|
||||
if data.get("rl") != "grpo":
|
||||
return data
|
||||
|
||||
trl_cfg = data.get("trl") or {}
|
||||
num_gen = trl_cfg.get("num_generations")
|
||||
if num_gen is None:
|
||||
# TRL's own default is 8 — but if the user didn't set it, we
|
||||
# don't have enough info to validate anything. Let TRL's own
|
||||
# init handle the default-vs-batch interaction.
|
||||
return data
|
||||
if num_gen < 2:
|
||||
raise ValueError(
|
||||
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
|
||||
"With num_generations=1, every group has zero advantage and "
|
||||
"the policy never updates."
|
||||
)
|
||||
|
||||
explicit_gbs = trl_cfg.get("generation_batch_size")
|
||||
if explicit_gbs is not None:
|
||||
effective_gbs = int(explicit_gbs)
|
||||
gbs_source = "trl.generation_batch_size"
|
||||
else:
|
||||
mb = data.get("micro_batch_size") or 1
|
||||
ga = data.get("gradient_accumulation_steps") or 1
|
||||
effective_gbs = int(mb) * int(ga)
|
||||
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
|
||||
|
||||
if effective_gbs % num_gen != 0:
|
||||
# Suggest the smallest GA bump that fixes it for the common case
|
||||
# where the user hasn't set generation_batch_size explicitly.
|
||||
hint = ""
|
||||
if explicit_gbs is None:
|
||||
from math import gcd
|
||||
|
||||
mb_val = int(data.get("micro_batch_size") or 1)
|
||||
# smallest GA such that mb*GA is a multiple of num_gen
|
||||
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
|
||||
suggested_ga = lcm // mb_val
|
||||
hint = (
|
||||
f" Smallest fix: set `gradient_accumulation_steps: "
|
||||
f"{suggested_ga}` (so micro_batch_size * GA = "
|
||||
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
|
||||
)
|
||||
raise ValueError(
|
||||
f"GRPO: generation batch size must be divisible by "
|
||||
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
|
||||
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
|
||||
)
|
||||
|
||||
# Multi-rank check: each rank must receive at least one full group
|
||||
# per step. Without `capabilities` populated yet (mode='before'), we
|
||||
# fall back to user-set distributed fields.
|
||||
world_size = (
|
||||
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
|
||||
)
|
||||
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
|
||||
raise ValueError(
|
||||
f"GRPO with world_size={world_size} requires effective_gbs "
|
||||
f">= num_generations * world_size = {num_gen * world_size}, "
|
||||
f"got {effective_gbs}. Increase gradient_accumulation_steps "
|
||||
f"or micro_batch_size."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class OptimizationValidationMixin:
|
||||
"""Validation methods related to optimization and performance."""
|
||||
|
||||
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
class TestVllmLoraSyncPatch(unittest.TestCase):
|
||||
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
|
||||
|
||||
These tests exercise the patch-installation branch in isolation. They build
|
||||
a stub trainer with just enough attributes to look like
|
||||
``AsyncGRPOTrainer`` for the duration of the relevant code path.
|
||||
|
||||
Background — there are two correct behaviors and we historically had a bug
|
||||
where both modes used the same one:
|
||||
|
||||
- Async prefetch ON: the BG generation thread can't safely call
|
||||
sync_weights mid-rollout. We no-op the stock hook and drive sync from
|
||||
the main thread via ``_maybe_sync_vllm_weights``.
|
||||
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
|
||||
calls ``sync_weights`` once per step boundary on the main thread. We
|
||||
wire that hook directly to ``_sync_lora_adapter`` because
|
||||
``_maybe_sync_vllm_weights`` short-circuits when async is off.
|
||||
|
||||
Before the fix, both modes installed ``lambda: None``, so sync mode never
|
||||
pushed any LoRA adapter to vLLM and the trainer was a no-op.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.vllm_lora_sync = vllm_lora_sync
|
||||
args.async_prefetch = async_prefetch
|
||||
|
||||
class FakeVllmGen:
|
||||
sync_weights = staticmethod(lambda: None)
|
||||
model = MagicMock()
|
||||
|
||||
# Use object.__new__ so we don't run __init__ (which needs a real
|
||||
# model, dataset, etc.). We only need the `_generate_single_turn`
|
||||
# method's patch branch to run, so we set up the minimum state.
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.vllm_generation = FakeVllmGen()
|
||||
trainer._patched_sync_weights = False
|
||||
# Spy on _sync_lora_adapter so we can assert it's the function the
|
||||
# hook delegates to in sync mode.
|
||||
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
|
||||
trainer._sync_peft_weights_no_merge = MagicMock(
|
||||
name="_sync_peft_weights_no_merge_spy"
|
||||
)
|
||||
return trainer
|
||||
|
||||
@staticmethod
|
||||
def _run_patch_branch(trainer):
|
||||
"""Execute just the sync_weights-patching branch in isolation.
|
||||
|
||||
We can't easily call the real ``_generate_single_turn`` because it
|
||||
does a full vLLM generate. Instead we copy the exact branch out of
|
||||
the source so the test verifies the same logic the trainer runs.
|
||||
"""
|
||||
if not getattr(trainer, "_patched_sync_weights", False):
|
||||
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
|
||||
if getattr(trainer.args, "vllm_lora_sync", False):
|
||||
if getattr(trainer.args, "async_prefetch", False):
|
||||
trainer.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
sync_helper = trainer._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
trainer._patched_sync_weights = True
|
||||
|
||||
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Trigger the patched hook — it must call _sync_lora_adapter.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_lora_sync_installs_noop_hook(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Hook must be a no-op so BG-thread generation doesn't fight the
|
||||
# main-thread optimizer step over the model weights.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
|
||||
"""Installing the patch should not pre-emptively sync."""
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
# _sync_lora_adapter should only be called when the patched hook
|
||||
# itself is invoked (e.g., from TRL's _generate_single_turn).
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_patch_is_idempotent(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
first_hook = trainer.vllm_generation.sync_weights
|
||||
# Second call must not re-patch (otherwise we'd lose the original).
|
||||
self._run_patch_branch(trainer)
|
||||
assert trainer.vllm_generation.sync_weights is first_hook
|
||||
|
||||
|
||||
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
|
||||
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
|
||||
|
||||
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
|
||||
on the very first call when ``vllm_sync_interval`` was ``None`` (which
|
||||
is the default for any config that doesn't explicitly set it). We now
|
||||
fall back to interval=1 so unset means "sync every step", matching the
|
||||
behavior of TRL's own ``_generate_single_turn``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(interval, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.async_prefetch = async_prefetch
|
||||
args.vllm_sync_interval = interval
|
||||
args.vllm_lora_sync = True
|
||||
|
||||
class FakeState:
|
||||
global_step = 1
|
||||
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.state = FakeState()
|
||||
trainer._last_synced_step = 0
|
||||
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
|
||||
return trainer
|
||||
|
||||
def test_interval_none_in_async_mode_does_not_crash(self):
|
||||
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# Should not raise TypeError — defaults to every-step sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_sync_mode_drives_sync(self):
|
||||
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
|
||||
|
||||
The previous behavior (early return when ``not async_prefetch``)
|
||||
assumed TRL's stock ``_generate_single_turn`` would handle sync.
|
||||
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
|
||||
where the data producer bypasses ``_generate_single_turn``
|
||||
entirely. Without this trigger no sync ever happens and the
|
||||
trainer becomes a no-op.
|
||||
"""
|
||||
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_explicit_interval_respects_modulo(self):
|
||||
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
# global_step=4 → 4 % 4 == 0 → sync
|
||||
trainer.state.global_step = 4
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError):
|
||||
)
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
|
||||
smoe_A = torch.zeros(
|
||||
rank * num_experts,
|
||||
K_inter,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
smoe_B = torch.zeros(
|
||||
N_hidden,
|
||||
rank * num_experts,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
s = e * rank
|
||||
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
|
||||
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
|
||||
return smoe_A, smoe_B
|
||||
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
def _unwrap_experts_lora(experts_module):
|
||||
return experts_module, None, None
|
||||
@@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
|
||||
|
||||
|
||||
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
|
||||
|
||||
Both gate_up_proj and down_proj need the A<->B swap because
|
||||
scattermoe transposes the parameter (W = param.T).
|
||||
"""
|
||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
|
||||
@@ -322,14 +300,16 @@ class TestLoRABLayoutConversion:
|
||||
hidden, inter = 32, 16
|
||||
scaling = 2.0
|
||||
|
||||
peft_A = torch.randn(E * r, hidden)
|
||||
peft_B = torch.randn(inter, E * r)
|
||||
peft_A = torch.randn(E * r, inter)
|
||||
peft_B = torch.randn(hidden, E * r)
|
||||
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
A_r = peft_A.reshape(E, r, inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
assert smoe_A.shape == (E * r, inter)
|
||||
assert smoe_B.shape == (hidden, E * r)
|
||||
for e in range(E):
|
||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||
@@ -342,27 +322,26 @@ class TestLoRABLayoutConversion:
|
||||
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
||||
|
||||
gate_up_proj param: [E, 2*inter, hidden].
|
||||
peft: in_features=2*inter, out_features=hidden.
|
||||
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
|
||||
peft: in_features=hidden, out_features=2*inter.
|
||||
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
|
||||
|
||||
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
||||
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
||||
|
||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
|
||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
|
||||
"""
|
||||
E, r = 4, 2
|
||||
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
||||
scaling = 2.0
|
||||
|
||||
# peft assigns: in_features=2*inter, out_features=hidden
|
||||
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
|
||||
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
|
||||
# peft assigns: in_features=hidden, out_features=2*inter
|
||||
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
|
||||
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
|
||||
|
||||
# peft delta via einsum: "o r e, e r i -> e i o"
|
||||
A_r = peft_A.reshape(E, r, 2 * inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(2 * inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
|
||||
# = param[e] shape [2*inter, hidden]
|
||||
|
||||
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
@@ -422,22 +401,22 @@ class TestPeftLoRAWeightExtraction:
|
||||
)
|
||||
|
||||
# gate_up_proj [E, 2*inter, hidden]
|
||||
# peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
|
||||
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
||||
].shape == (E * r, 2 * config.intermediate_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||
].shape == (config.hidden_size, E * r)
|
||||
|
||||
# down_proj [E, hidden, inter]
|
||||
# peft: in_features=hidden (dim 1), out_features=inter (dim 2)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_A.default.weight"
|
||||
].shape == (E * r, config.hidden_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||
].shape == (2 * config.intermediate_size, E * r)
|
||||
|
||||
# down_proj [E, hidden, inter]
|
||||
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_A.default.weight"
|
||||
].shape == (E * r, config.intermediate_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_B.default.weight"
|
||||
].shape == (config.intermediate_size, E * r)
|
||||
].shape == (config.hidden_size, E * r)
|
||||
|
||||
@requires_cuda
|
||||
def test_peft_forward_runs(self):
|
||||
@@ -488,27 +467,29 @@ class TestPeftLoRAWeightExtraction:
|
||||
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
||||
assert down_lora is not None, "down_proj LoRA not detected"
|
||||
|
||||
# Check shapes (after peft->scattermoe conversion with A<->B swap)
|
||||
# gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
|
||||
# Check shapes after peft->scattermoe conversion.
|
||||
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
|
||||
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
|
||||
E, r = config.num_experts, 4
|
||||
gup_A, gup_B, gup_s = gup_lora
|
||||
assert gup_A.shape == (E * r, config.hidden_size), (
|
||||
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
|
||||
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
|
||||
f"got {gup_A.shape}"
|
||||
)
|
||||
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
||||
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
|
||||
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
|
||||
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
||||
)
|
||||
|
||||
# down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
|
||||
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
|
||||
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
|
||||
down_A, down_B, down_s = down_lora
|
||||
assert down_A.shape == (E * r, config.intermediate_size), (
|
||||
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
|
||||
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
|
||||
f"got {down_A.shape}"
|
||||
)
|
||||
assert down_B.shape == (config.hidden_size, E * r), (
|
||||
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
|
||||
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
|
||||
f"got {down_B.shape}"
|
||||
)
|
||||
|
||||
|
||||
@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
|
||||
assert cfg.dataloader_num_workers == 0
|
||||
|
||||
|
||||
class TestSelectWeightSyncTransport(unittest.TestCase):
|
||||
"""Pure-logic table tests for ``select_weight_sync_transport``."""
|
||||
|
||||
def _caps(self, **kwargs):
|
||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||
|
||||
c = VLLMWeightSyncCapabilities(probed=True)
|
||||
for k, v in kwargs.items():
|
||||
setattr(c, k, v)
|
||||
return c
|
||||
|
||||
def test_lora_with_native_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(lora_filesystem=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||
== "lora_filesystem"
|
||||
)
|
||||
|
||||
def test_lora_with_axolotl_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(lora_axolotl=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||
== "lora_filesystem"
|
||||
)
|
||||
|
||||
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(nccl=True)
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||
== "nccl"
|
||||
)
|
||||
|
||||
def test_full_param_prefers_nccl(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(nccl=True, http_full=True)
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "nccl"
|
||||
)
|
||||
|
||||
def test_full_param_falls_back_to_http(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps(http_full=True)
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "http_full"
|
||||
)
|
||||
|
||||
def test_full_param_no_routes_returns_none(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps() # all False
|
||||
assert (
|
||||
select_weight_sync_transport(
|
||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||
)
|
||||
== "none"
|
||||
)
|
||||
|
||||
def test_lora_no_routes_returns_none(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||
|
||||
caps = self._caps()
|
||||
assert (
|
||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||
== "none"
|
||||
)
|
||||
|
||||
|
||||
class TestProbeVllmWeightSync(unittest.TestCase):
|
||||
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
|
||||
|
||||
def test_stock_vllm_with_lora_enabled(self):
|
||||
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/v1/models": {"get": {}},
|
||||
"/v1/load_lora_adapter": {"post": {}},
|
||||
"/v1/unload_lora_adapter": {"post": {}},
|
||||
"/v1/completions": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.lora_filesystem is True
|
||||
assert caps.lora_axolotl is False
|
||||
assert caps.nccl is False
|
||||
assert caps.http_full is False
|
||||
|
||||
def test_axolotl_serve_lora_full_capabilities(self):
|
||||
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/init_communicator/": {"post": {}},
|
||||
"/update_named_param/": {"post": {}},
|
||||
"/batch_update_named_params/": {"post": {}},
|
||||
"/set_lora_adapter/": {"post": {}},
|
||||
"/clear_lora_adapter/": {"post": {}},
|
||||
"/http_update_weights/": {"post": {}},
|
||||
"/v1/load_lora_adapter": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.nccl is True
|
||||
assert caps.lora_axolotl is True
|
||||
assert caps.lora_filesystem is True
|
||||
assert caps.http_full is True
|
||||
|
||||
def test_trl_vllm_serve_nccl_only(self):
|
||||
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
spec = {
|
||||
"paths": {
|
||||
"/init_communicator/": {"post": {}},
|
||||
"/update_named_param/": {"post": {}},
|
||||
"/batch_update_named_params/": {"post": {}},
|
||||
"/close_communicator/": {"post": {}},
|
||||
"/generate/": {"post": {}},
|
||||
}
|
||||
}
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.return_value.raise_for_status = lambda: None
|
||||
mock_get.return_value.json = lambda: spec
|
||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||
|
||||
assert caps.probed is True
|
||||
assert caps.nccl is True
|
||||
assert caps.lora_filesystem is False
|
||||
assert caps.lora_axolotl is False
|
||||
assert caps.http_full is False
|
||||
|
||||
def test_unreachable_server_records_error(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_get.side_effect = ConnectionError("Connection refused")
|
||||
caps = probe_vllm_weight_sync("http://localhost:9999")
|
||||
|
||||
assert caps.probed is False
|
||||
assert caps.probe_error is not None
|
||||
assert "ConnectionError" in caps.probe_error
|
||||
assert caps.nccl is False
|
||||
assert caps.lora_filesystem is False
|
||||
|
||||
|
||||
class TestPluginWeightSyncEnforcement(unittest.TestCase):
|
||||
"""End-to-end test of post_trainer_create's transport-selection branch.
|
||||
|
||||
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
|
||||
leaving the trainer learning in isolation while vLLM kept serving the
|
||||
unmodified base model. After the fix:
|
||||
|
||||
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
|
||||
- LoRA + only NCCL endpoint → uses NCCL broadcast
|
||||
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
|
||||
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
|
||||
- No usable transport → raises ValueError with a precise diagnosis
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _fake_cfg(adapter, vllm_lora_sync):
|
||||
class FakeTRL:
|
||||
pass
|
||||
|
||||
class FakeCfg:
|
||||
pass
|
||||
|
||||
trl = FakeTRL()
|
||||
trl.vllm_lora_sync = vllm_lora_sync
|
||||
trl.vllm_server_host = "127.0.0.1"
|
||||
trl.vllm_server_port = 8000
|
||||
|
||||
cfg = FakeCfg()
|
||||
cfg.nemo_gym_enabled = True
|
||||
cfg.nemo_gym_model_name = None
|
||||
cfg.base_model = "test/model"
|
||||
cfg.nemo_gym_verify_timeout = 30
|
||||
cfg.nemo_gym_multi_turn = True
|
||||
cfg.adapter = adapter
|
||||
cfg.trl = trl
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def _fake_trainer():
|
||||
class FakeVLLMGen:
|
||||
sync_weights = staticmethod(lambda: None)
|
||||
|
||||
class FakeTrainer:
|
||||
vllm_generation = FakeVLLMGen()
|
||||
|
||||
return FakeTrainer()
|
||||
|
||||
@staticmethod
|
||||
def _caps(**kwargs):
|
||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||
|
||||
c = VLLMWeightSyncCapabilities(probed=True)
|
||||
for k, v in kwargs.items():
|
||||
setattr(c, k, v)
|
||||
return c
|
||||
|
||||
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(lora_filesystem=True)
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with (
|
||||
patch.object(plugin, "_setup_lora_sync") as setup,
|
||||
patch.object(plugin, "_check_lora_endpoint") as check,
|
||||
patch.object(plugin, "_wire_multi_turn") as wire,
|
||||
):
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
setup.assert_called_once()
|
||||
check.assert_called_once()
|
||||
wire.assert_called_once()
|
||||
|
||||
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps() # all False, but probed
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
msg = str(ctx.exception)
|
||||
assert "no-op trainer" in msg
|
||||
assert "load_lora_adapter" in msg
|
||||
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
|
||||
|
||||
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(nccl=True)
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
|
||||
with patch.object(plugin, "_wire_multi_turn") as wire:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
wire.assert_called_once()
|
||||
|
||||
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps(http_full=True)
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(NotImplementedError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
assert "HTTP weight sync" in str(ctx.exception)
|
||||
|
||||
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
plugin._vllm_caps = self._caps()
|
||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
msg = str(ctx.exception)
|
||||
assert "no-op trainer" in msg
|
||||
assert "init_communicator" in msg
|
||||
assert "http_update_weights" in msg
|
||||
|
||||
def test_unprobed_caps_raises_with_probe_failure_message(self):
|
||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||
|
||||
plugin = NemoGymPlugin()
|
||||
# Plugin._vllm_caps left as default-None: the post_trainer_create
|
||||
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
|
||||
# probed=False, so the error path should mention probing.
|
||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||
trainer = self._fake_trainer()
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
assert "could not probe" in str(ctx.exception)
|
||||
|
||||
|
||||
class TestNemoGymE2E(unittest.TestCase):
|
||||
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
||||
|
||||
@@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
trainer = self._make_mock_trainer()
|
||||
producer._trainer = trainer
|
||||
|
||||
# Mock the prompt iterator (returns a batch of 1 input)
|
||||
producer._prompt_iter = iter(
|
||||
[
|
||||
[
|
||||
{
|
||||
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
||||
}
|
||||
]
|
||||
]
|
||||
)
|
||||
producer._prompt_dl = [
|
||||
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
||||
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
||||
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
||||
# copies of each unique prompt — one entry per rollout.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
# Call produce
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
@@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
producer._request_timeout = 30
|
||||
producer._num_generations = 2
|
||||
producer._trainer = self._make_mock_trainer()
|
||||
producer._prompt_iter = iter(
|
||||
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
)
|
||||
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
# RepeatSampler pre-expands by num_generations=2.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
|
||||
|
||||
@@ -21,6 +21,51 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestPeftScatterMoELoRALayout:
|
||||
"""CPU-only guards for PEFT target_parameters layout conversion."""
|
||||
|
||||
def test_peft_layout_keeps_a_and_reorders_b(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||
peft_lora_to_scattermoe,
|
||||
)
|
||||
|
||||
E, r, K, N = 3, 2, 5, 7
|
||||
scaling = 2.0
|
||||
peft_A = torch.randn(E * r, K)
|
||||
peft_B = torch.randn(N, E * r)
|
||||
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
|
||||
assert smoe_A is peft_A
|
||||
assert smoe_A.shape == (E * r, K)
|
||||
assert smoe_B.shape == (N, E * r)
|
||||
|
||||
A_r = peft_A.reshape(E, r, K)
|
||||
B_r = peft_B.reshape(N, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
for e in range(E):
|
||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
|
||||
|
||||
def test_swapped_layout_fails_before_kernel_dispatch(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||
validate_scattermoe_lora_shapes,
|
||||
)
|
||||
|
||||
E, r, K, N = 3, 2, 5, 7
|
||||
expert_weights = torch.empty(E, K, N)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
|
||||
validate_scattermoe_lora_shapes(
|
||||
expert_weights=expert_weights,
|
||||
lora_A=torch.empty(E * r, N),
|
||||
lora_B=torch.empty(K, E * r),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 1. KernelsArgs: disable_mlp_kernel validator
|
||||
# ============================================================================
|
||||
|
||||
@@ -38,6 +38,30 @@ def _reference_norm_noscale(x, eps):
|
||||
return norm(x)
|
||||
|
||||
|
||||
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
|
||||
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
|
||||
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
|
||||
the trailing columns passed through unchanged. Mirrors how Llama-style
|
||||
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
|
||||
"""
|
||||
from transformers.models.gemma4.modeling_gemma4 import (
|
||||
Gemma4RMSNorm,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
D = x.shape[-1]
|
||||
n_rot = cos.shape[-1]
|
||||
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
|
||||
norm.weight.data.copy_(weight)
|
||||
normed = norm(x)
|
||||
if n_rot == D:
|
||||
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
||||
x_rot = normed[..., :n_rot]
|
||||
x_pass = normed[..., n_rot:]
|
||||
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
|
||||
return torch.cat([rotated, x_pass], dim=-1)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
(2, 64, 32, 256), # sliding window layer shape
|
||||
@@ -194,6 +218,172 @@ class TestFusedRMSNormRoPEBackward:
|
||||
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||
|
||||
|
||||
class TestFusedRMSNormRoPEPartialRotary:
|
||||
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
|
||||
|
||||
Compares against the original primitives (`Gemma4RMSNorm` +
|
||||
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
|
||||
columns passed through. Without the kernel fix this used to crash with
|
||||
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"B,S,H,D,n_rot",
|
||||
[
|
||||
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
|
||||
(2, 16, 4, 64, 16), # quarter rotary
|
||||
(2, 32, 8, 128, 64), # half rotary, larger heads
|
||||
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
|
||||
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
|
||||
],
|
||||
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
|
||||
)
|
||||
def test_forward_matches_reference(self, B, S, H, D, n_rot):
|
||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||
|
||||
eps = 1e-6
|
||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
|
||||
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
||||
|
||||
assert y_fused.shape == y_ref.shape == (B, S, H, D)
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||
)
|
||||
assert cos_sim > 0.999, (
|
||||
f"partial rotary forward cosine_sim={cos_sim:.6f} "
|
||||
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
|
||||
)
|
||||
|
||||
# The pass-through tail must equal the reference RMSNorm output bit-
|
||||
# for-bit (any deviation would mean the kernel is touching it with a
|
||||
# spurious rotation, which is the original bug class).
|
||||
torch.testing.assert_close(
|
||||
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"B,S,H,D,n_rot",
|
||||
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
||||
ids=["half_64", "quarter_256"],
|
||||
)
|
||||
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
|
||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||
|
||||
eps = 1e-6
|
||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Reference backward via the original primitives
|
||||
x_ref = x_data.clone().requires_grad_(True)
|
||||
w_ref = weight_init.clone()
|
||||
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
|
||||
y_ref.sum().backward()
|
||||
|
||||
# Fused backward
|
||||
x_fused = x_data.clone().requires_grad_(True)
|
||||
w_fused = weight_init.clone().requires_grad_(True)
|
||||
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
|
||||
y_fused.sum().backward()
|
||||
|
||||
cos_sim_x = torch.nn.functional.cosine_similarity(
|
||||
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
|
||||
)
|
||||
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"B,S,H,D,n_rot",
|
||||
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
||||
ids=["half_64", "quarter_256"],
|
||||
)
|
||||
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
|
||||
|
||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||
|
||||
eps = 1e-6
|
||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
|
||||
# rotary applied to the rotated slice.
|
||||
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
|
||||
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
|
||||
normed = norm_ref(x_data)
|
||||
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
|
||||
|
||||
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
|
||||
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
|
||||
y_ref.sum().backward()
|
||||
|
||||
w_fused = weight_init.clone().requires_grad_(True)
|
||||
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
|
||||
|
||||
cos_sim_w = torch.nn.functional.cosine_similarity(
|
||||
w_fused.grad.flatten().float(),
|
||||
norm_ref.weight.grad.flatten().float(),
|
||||
dim=0,
|
||||
)
|
||||
assert cos_sim_w > 0.995, (
|
||||
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
|
||||
)
|
||||
|
||||
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
|
||||
"""Regression: passing cos/sin with shape == head_dim must still
|
||||
match the full-rotary reference (the partial-rotary code path must
|
||||
not perturb the existing full-rotary output)."""
|
||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||
|
||||
B, S, H, D = 2, 16, 4, 64
|
||||
eps = 1e-6
|
||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
|
||||
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||
)
|
||||
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
|
||||
|
||||
def test_validation_errors(self):
|
||||
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
|
||||
Triton crash deeper in the kernel)."""
|
||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||
|
||||
B, S, H, D = 1, 4, 2, 64
|
||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# n_rot > head_dim
|
||||
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
||||
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
||||
with pytest.raises(ValueError, match="cannot exceed head_dim"):
|
||||
fused_rms_norm_rope(x, w, cos_big, sin_big)
|
||||
|
||||
# cos/sin last-dim mismatch
|
||||
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
|
||||
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
|
||||
with pytest.raises(ValueError, match="same last dim"):
|
||||
fused_rms_norm_rope(x, w, cos, sin)
|
||||
|
||||
# odd rotary dim
|
||||
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
||||
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
||||
with pytest.raises(ValueError, match="must be even"):
|
||||
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
|
||||
|
||||
|
||||
class TestFusedRMSNormNoScale:
|
||||
"""Tests for v_norm (RMSNorm without learnable scale)."""
|
||||
|
||||
|
||||
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Tests for the Gemma 4 fused-attention monkey-patch.
|
||||
|
||||
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
||||
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
||||
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
||||
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
||||
exercised end-to-end (this is the bug originally documented in
|
||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||
|
||||
The full-model forward also pins that the fused forward keeps accepting
|
||||
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
||||
installed transformers version — so any future signature drift on
|
||||
upstream's side trips a clear failure here instead of a confusing
|
||||
TypeError deep in a training run.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
|
||||
]
|
||||
|
||||
pytest.importorskip(
|
||||
"transformers.models.gemma4",
|
||||
reason="fused_attn patch only matters when Gemma 4 is available",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restore_gemma4_attention():
|
||||
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
|
||||
so the monkey-patch does not leak across the suite."""
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||
|
||||
saved = Gemma4TextAttention.forward
|
||||
yield Gemma4TextAttention
|
||||
Gemma4TextAttention.forward = saved
|
||||
|
||||
|
||||
def _build_hybrid_config():
|
||||
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
|
||||
layer with proportional rope and partial_rotary_factor=0.25. This is
|
||||
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
|
||||
enough to fit on any GPU."""
|
||||
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||
|
||||
cfg = Gemma4TextConfig(
|
||||
vocab_size=128,
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
head_dim=32,
|
||||
global_head_dim=64,
|
||||
layer_types=["sliding_attention", "full_attention"],
|
||||
sliding_window=64,
|
||||
max_position_embeddings=2048,
|
||||
hidden_size_per_layer_input=16,
|
||||
vocab_size_per_layer_input=128,
|
||||
rope_parameters={
|
||||
"sliding_attention": {
|
||||
"rope_type": "default",
|
||||
"rope_theta": 10000.0,
|
||||
},
|
||||
"full_attention": {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 1000000.0,
|
||||
"partial_rotary_factor": 0.25,
|
||||
},
|
||||
},
|
||||
)
|
||||
cfg._attn_implementation = "sdpa"
|
||||
return cfg
|
||||
|
||||
|
||||
def _build_model(seed=0):
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
|
||||
|
||||
torch.manual_seed(seed)
|
||||
cfg = _build_hybrid_config()
|
||||
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
|
||||
|
||||
|
||||
class TestFusedAttnSignature:
|
||||
"""The fused forward must accept the same call shape as
|
||||
``Gemma4TextDecoderLayer`` produces in the installed transformers
|
||||
version. Any signature drift surfaces here as a TypeError."""
|
||||
|
||||
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
||||
"""Run a model forward that exercises the real
|
||||
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
|
||||
the fused patch installed."""
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
m = _build_model()
|
||||
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
||||
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
with torch.no_grad():
|
||||
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
||||
|
||||
assert out.shape == (2, 16, 64)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
class TestFusedAttnPerLayerCorrectness:
|
||||
"""Compare the patched attention layer to the stock implementation
|
||||
on a single forward call. This isolates the fused kernel correctness
|
||||
from cross-layer numerical drift."""
|
||||
|
||||
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
|
||||
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
|
||||
installed) for one layer and return the output."""
|
||||
attn = model.layers[layer_idx].self_attn
|
||||
layer_type = model.config.layer_types[layer_idx]
|
||||
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
|
||||
out, _ = attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=(cos, sin),
|
||||
attention_mask=None,
|
||||
shared_kv_states={},
|
||||
)
|
||||
return out
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"layer_idx",
|
||||
[0, 1],
|
||||
ids=["sliding_head32", "global_head64_proportional"],
|
||||
)
|
||||
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
m = _build_model(seed=1)
|
||||
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
|
||||
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = self._run_attention(m, layer_idx, hs, pos)
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
with torch.no_grad():
|
||||
got = self._run_attention(m, layer_idx, hs, pos)
|
||||
|
||||
assert got.shape == ref.shape
|
||||
assert torch.isfinite(got).all()
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
ref.flatten().float(), got.flatten().float(), dim=0
|
||||
)
|
||||
assert cos_sim > 0.999, (
|
||||
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
|
||||
)
|
||||
# bf16 precision: a few millis of absolute drift per element is
|
||||
# acceptable for a Q/K/V projection pipeline. Anything larger is
|
||||
# a real bug.
|
||||
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
|
||||
|
||||
|
||||
class TestFusedAttnFullModel:
|
||||
"""End-to-end model forward + backward through both layer types."""
|
||||
|
||||
def test_full_forward_matches_stock(self, restore_gemma4_attention):
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
m = _build_model(seed=2)
|
||||
ids = torch.randint(0, 128, (2, 32), device="cuda")
|
||||
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
with torch.no_grad():
|
||||
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
||||
|
||||
assert got.shape == ref.shape
|
||||
assert torch.isfinite(got).all()
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
ref.flatten().float(), got.flatten().float(), dim=0
|
||||
)
|
||||
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
|
||||
# accumulates a small amount of numerical drift; we just want to
|
||||
# pin that the two paths are computing the same function.
|
||||
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
|
||||
|
||||
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
|
||||
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
|
||||
for both the sliding and proportional-rope layers."""
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
m = _build_model(seed=3).train()
|
||||
patch_gemma4_fused_attn()
|
||||
|
||||
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
||||
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
||||
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
||||
out.sum().backward()
|
||||
|
||||
# Both layers must accumulate gradients on q_norm.weight and
|
||||
# k_norm.weight — that proves the fused kernel ran the backward.
|
||||
for i, layer in enumerate(m.layers[:2]):
|
||||
attn = layer.self_attn
|
||||
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
|
||||
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
|
||||
assert attn.q_norm.weight.grad.isfinite().all()
|
||||
assert attn.k_norm.weight.grad.isfinite().all()
|
||||
assert attn.q_norm.weight.grad.abs().sum() > 0
|
||||
assert attn.k_norm.weight.grad.abs().sum() > 0
|
||||
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tests for the Gemma 4 hybrid-attention mask fix.
|
||||
|
||||
These tests pin the single critical behavior: after installing the patch,
|
||||
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
|
||||
the underlying mask builder regardless of what the caller's config says.
|
||||
This is what keeps full-attention (head_dim=512) global layers from
|
||||
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
|
||||
the 2D FA2 mask that would be built from the model-level config.
|
||||
|
||||
The tests use a mocked ``create_causal_mask`` so they don't have to load
|
||||
a real 26B Gemma 4 model or even have access to its weights. What matters
|
||||
for the bug fix is which config is handed to the mask factory, not the
|
||||
factory's actual output.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip(
|
||||
"transformers.models.gemma4",
|
||||
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restore_gemma4_module():
|
||||
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
|
||||
each test so patch state doesn't leak across the suite."""
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
|
||||
saved = modeling_gemma4.create_causal_mask
|
||||
yield modeling_gemma4
|
||||
modeling_gemma4.create_causal_mask = saved
|
||||
# Reset the module-level flag so the next test can re-install cleanly.
|
||||
from axolotl.monkeypatch import gemma4_hybrid_mask
|
||||
|
||||
gemma4_hybrid_mask._PATCH_APPLIED = False
|
||||
|
||||
|
||||
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
original = modeling_gemma4.create_causal_mask
|
||||
assert patch_gemma4_hybrid_mask() is True
|
||||
|
||||
assert modeling_gemma4.create_causal_mask is not original
|
||||
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
|
||||
"patched wrapper must expose the original reference for teardown"
|
||||
)
|
||||
|
||||
|
||||
def test_patch_is_idempotent(restore_gemma4_module):
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
patch_gemma4_hybrid_mask()
|
||||
wrapper_first = modeling_gemma4.create_causal_mask
|
||||
|
||||
# Second call must not re-wrap the already-wrapped function (which
|
||||
# would leak the original reference through a chain of wrappers).
|
||||
patch_gemma4_hybrid_mask()
|
||||
wrapper_second = modeling_gemma4.create_causal_mask
|
||||
|
||||
assert wrapper_first is wrapper_second
|
||||
|
||||
|
||||
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
|
||||
"""Core invariant: when the patched wrapper is called with a config
|
||||
that says ``flash_attention_2``, the underlying mask factory receives
|
||||
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
|
||||
|
||||
Without this, the full-attention global layers get a 2D FA2 mask and
|
||||
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
|
||||
"""
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
# Swap in a mock BEFORE installing the patch so the wrapper captures
|
||||
# it as the "original". The mock records every call so we can inspect
|
||||
# what config got passed through.
|
||||
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
|
||||
modeling_gemma4.create_causal_mask = mock_factory
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
# Caller-supplied config says FA2 (that's the model-level setting).
|
||||
caller_config = SimpleNamespace(
|
||||
_attn_implementation="flash_attention_2",
|
||||
head_dim=512,
|
||||
some_other_attr="preserved",
|
||||
)
|
||||
result = modeling_gemma4.create_causal_mask(
|
||||
caller_config,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=None,
|
||||
)
|
||||
|
||||
# Wrapper returned whatever the mock returned — no transformation of
|
||||
# the result itself.
|
||||
assert result == "mask_4d"
|
||||
|
||||
# The mock was called exactly once with a config whose
|
||||
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
|
||||
assert mock_factory.call_count == 1
|
||||
(passed_config, *_), passed_kwargs = mock_factory.call_args
|
||||
assert passed_config._attn_implementation == "sdpa"
|
||||
|
||||
# The wrapper must NOT mutate the caller's config in place — other
|
||||
# mask builders (e.g. create_sliding_window_causal_mask) read from
|
||||
# the same config and must still see fa2.
|
||||
assert caller_config._attn_implementation == "flash_attention_2"
|
||||
|
||||
# Other attributes on the config must be preserved so the underlying
|
||||
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
|
||||
assert passed_config.head_dim == 512
|
||||
assert passed_config.some_other_attr == "preserved"
|
||||
|
||||
|
||||
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
|
||||
"""The wrapper must forward positional + keyword args to the original
|
||||
unchanged, so transformers' own call-site in Gemma4TextModel.forward
|
||||
keeps working across minor transformers-version signature drift."""
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
mock_factory = MagicMock(return_value="mask")
|
||||
modeling_gemma4.create_causal_mask = mock_factory
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
|
||||
modeling_gemma4.create_causal_mask(
|
||||
caller_config,
|
||||
"positional_arg",
|
||||
inputs_embeds="embeds",
|
||||
attention_mask="mask_2d",
|
||||
past_key_values="cache",
|
||||
position_ids="positions",
|
||||
or_mask_function="or_fn",
|
||||
)
|
||||
|
||||
args, kwargs = mock_factory.call_args
|
||||
# First positional (after config override) is preserved.
|
||||
assert args[1] == "positional_arg"
|
||||
# All kwargs are forwarded untouched.
|
||||
assert kwargs["inputs_embeds"] == "embeds"
|
||||
assert kwargs["attention_mask"] == "mask_2d"
|
||||
assert kwargs["past_key_values"] == "cache"
|
||||
assert kwargs["position_ids"] == "positions"
|
||||
assert kwargs["or_mask_function"] == "or_fn"
|
||||
|
||||
|
||||
def test_unpatch_restores_original(restore_gemma4_module):
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import (
|
||||
patch_gemma4_hybrid_mask,
|
||||
unpatch_gemma4_hybrid_mask,
|
||||
)
|
||||
|
||||
sentinel = MagicMock(name="original")
|
||||
modeling_gemma4.create_causal_mask = sentinel
|
||||
patch_gemma4_hybrid_mask()
|
||||
assert modeling_gemma4.create_causal_mask is not sentinel
|
||||
|
||||
unpatch_gemma4_hybrid_mask()
|
||||
assert modeling_gemma4.create_causal_mask is sentinel
|
||||
|
||||
|
||||
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
|
||||
|
||||
# Should be a no-op, no exception.
|
||||
unpatch_gemma4_hybrid_mask()
|
||||
|
||||
|
||||
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
|
||||
"""Only ``create_causal_mask`` is overridden — the sliding-window
|
||||
factory must remain bound to its original to preserve FA2 masks for
|
||||
the sliding-attention layers. If we accidentally patch both, the
|
||||
sliding layers get SDPA format and lose the FA2 speedup."""
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
|
||||
pytest.skip("transformers version has no create_sliding_window_causal_mask")
|
||||
|
||||
sliding_before = modeling_gemma4.create_sliding_window_causal_mask
|
||||
patch_gemma4_hybrid_mask()
|
||||
sliding_after = modeling_gemma4.create_sliding_window_causal_mask
|
||||
assert sliding_after is sliding_before
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
|
||||
#
|
||||
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
|
||||
# model with 2 layers (one sliding, one full_attention), apply the hybrid
|
||||
# attention path end-to-end, and run a forward pass with a padded
|
||||
# attention_mask at a long-ish seq len. The invariant we're pinning is that
|
||||
# the full_attention layer does not crash with the
|
||||
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
|
||||
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
|
||||
# tokens in the FSDP2 training run.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_tiny_gemma4_text_model():
|
||||
"""Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
|
||||
import torch
|
||||
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
|
||||
|
||||
cfg = Gemma4TextConfig(
|
||||
vocab_size=128,
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
head_dim=32,
|
||||
layer_types=["sliding_attention", "full_attention"],
|
||||
sliding_window=64,
|
||||
max_position_embeddings=2048,
|
||||
hidden_size_per_layer_input=16,
|
||||
vocab_size_per_layer_input=128,
|
||||
)
|
||||
# Caller-supplied attn impl simulates the pilot config (fa2 at model
|
||||
# level). The hybrid patch is what makes this survive long context.
|
||||
cfg._attn_implementation = "sdpa" # start safe; the test toggles fa2 later
|
||||
torch.manual_seed(42)
|
||||
model = Gemma4TextModel(cfg).eval()
|
||||
return model, cfg
|
||||
|
||||
|
||||
def _apply_hybrid_attn_inline(model, cfg):
|
||||
"""Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
|
||||
to a model, without needing a full PatchManager / pydantic cfg."""
|
||||
import copy
|
||||
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
for layer_idx, layer in enumerate(model.layers):
|
||||
if cfg.layer_types[layer_idx] != "sliding_attention":
|
||||
attn = getattr(layer, "self_attn", None)
|
||||
if attn is not None and hasattr(attn, "config"):
|
||||
sdpa_cfg = copy.copy(attn.config)
|
||||
sdpa_cfg._attn_implementation = "sdpa"
|
||||
attn.config = sdpa_cfg
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
|
||||
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
|
||||
"""End-to-end invariant: with the hybrid attn patch applied, a tiny
|
||||
Gemma4TextModel runs a forward at long context (1024 tokens) with
|
||||
real padding in the attention mask, producing the expected output
|
||||
shape. This exercises the actual code path that crashed the pilot
|
||||
without needing a real 26B checkpoint or CUDA."""
|
||||
import torch
|
||||
|
||||
model, cfg = _build_tiny_gemma4_text_model()
|
||||
_apply_hybrid_attn_inline(model, cfg)
|
||||
|
||||
B, S = 2, 1024
|
||||
input_ids = torch.randint(0, cfg.vocab_size, (B, S))
|
||||
attn_mask = torch.ones(B, S, dtype=torch.long)
|
||||
# Pad positions in the second row. Without padding, SDPA falls back to
|
||||
# ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
|
||||
# mask to exercise the actual bug site.
|
||||
attn_mask[1, S // 2 :] = 0
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(input_ids=input_ids, attention_mask=attn_mask)
|
||||
|
||||
assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
|
||||
assert torch.isfinite(out.last_hidden_state).all()
|
||||
|
||||
|
||||
def test_patched_create_causal_mask_returns_4d_for_real_config(
|
||||
restore_gemma4_module,
|
||||
):
|
||||
"""Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
|
||||
and verify the returned mask is a 4D tensor — which is the shape the
|
||||
SDPA-patched global layers need. Without the patch and with a
|
||||
caller-supplied FA2 config this would return a 2D mask and the layer
|
||||
would crash at long context."""
|
||||
import torch
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
patch_gemma4_hybrid_mask()
|
||||
modeling_gemma4 = restore_gemma4_module
|
||||
|
||||
cfg = Gemma4TextConfig(
|
||||
vocab_size=128,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
head_dim=32,
|
||||
layer_types=["sliding_attention", "full_attention"],
|
||||
sliding_window=64,
|
||||
max_position_embeddings=2048,
|
||||
hidden_size_per_layer_input=16,
|
||||
vocab_size_per_layer_input=128,
|
||||
)
|
||||
# Simulate the pilot: caller says flash_attention_2, but global layers
|
||||
# were switched to SDPA per-layer. Without the patch, create_causal_mask
|
||||
# would return an FA2 2D mask here and the SDPA layer would crash.
|
||||
cfg._attn_implementation = "flash_attention_2"
|
||||
|
||||
B, S = 2, 1024
|
||||
inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
|
||||
attention_mask = torch.ones((B, S), dtype=torch.long)
|
||||
attention_mask[1, S // 2 :] = 0 # force the 4D materialized path
|
||||
position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
|
||||
past_key_values = DynamicCache(config=cfg)
|
||||
|
||||
mask = modeling_gemma4.create_causal_mask(
|
||||
config=cfg,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
assert mask is not None
|
||||
assert isinstance(mask, torch.Tensor)
|
||||
assert mask.dim() == 4, (
|
||||
f"expected a 4D SDPA-format mask, got {mask.dim()}D "
|
||||
f"shape={tuple(mask.shape)}. The full_attention global layers need "
|
||||
"this shape or they crash at long context."
|
||||
)
|
||||
assert mask.shape[0] == B
|
||||
assert mask.shape[-1] == S
|
||||
assert mask.shape[-2] == S
|
||||
|
||||
# Caller's config must be untouched — other code paths still read it.
|
||||
assert cfg._attn_implementation == "flash_attention_2"
|
||||
@@ -491,7 +491,8 @@ class TestEfficientMerge:
|
||||
out_features = 4
|
||||
alpha = 4
|
||||
|
||||
base = torch.randn(num_experts, in_features, out_features)
|
||||
# PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
|
||||
base = torch.randn(num_experts, out_features, in_features)
|
||||
lora_a = torch.randn(r * num_experts, in_features)
|
||||
lora_b = torch.randn(out_features, r * num_experts)
|
||||
|
||||
@@ -507,7 +508,7 @@ class TestEfficientMerge:
|
||||
scale = alpha / r
|
||||
wa = lora_a.reshape(num_experts, r, in_features)
|
||||
wb = lora_b.reshape(out_features, r, num_experts)
|
||||
manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
|
||||
manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
|
||||
for e in range(num_experts):
|
||||
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
|
||||
f"Expert {e} mismatch"
|
||||
|
||||
@@ -5,6 +5,8 @@ Covers:
|
||||
- save_strategy: 'best' requires metric_for_best_model
|
||||
- streaming=True with val_set_size > 0 is rejected
|
||||
- lora_target_modules with invalid regex patterns is rejected
|
||||
- GRPO: generation batch size must be divisible by num_generations,
|
||||
num_generations >= 2, and effective_gbs >= num_generations * world_size
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
|
||||
)
|
||||
with pytest.raises(ValueError, match="invalid regex pattern"):
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestGRPOBatchSizeValidator:
|
||||
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
|
||||
|
||||
These call the @model_validator(mode="before") classmethod directly on a
|
||||
plain dict — same input shape it receives during full Pydantic validation,
|
||||
just without dragging in unrelated fields (datasets / model loading / etc.)
|
||||
that aren't relevant to what's under test. The validator is registered on
|
||||
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
|
||||
same code path ``axolotl train`` exercises.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _check(data):
|
||||
from axolotl.utils.schemas.validation import RLValidationMixin
|
||||
|
||||
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
|
||||
|
||||
def test_divisible_passes(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
# Should return data unchanged (no exception)
|
||||
out = self._check(data)
|
||||
assert out["trl"]["num_generations"] == 4
|
||||
|
||||
def test_non_divisible_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match="num_generations"):
|
||||
self._check(data)
|
||||
|
||||
def test_non_divisible_error_includes_fix_hint(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
|
||||
self._check(data)
|
||||
|
||||
def test_num_generations_one_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"trl": {"num_generations": 1},
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"num_generations >= 2"):
|
||||
self._check(data)
|
||||
|
||||
def test_explicit_generation_batch_size_divisible_passes(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"trl": {"num_generations": 4, "generation_batch_size": 8},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["trl"]["generation_batch_size"] == 8
|
||||
|
||||
def test_explicit_generation_batch_size_non_divisible_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"trl": {"num_generations": 4, "generation_batch_size": 6},
|
||||
}
|
||||
with pytest.raises(ValueError, match="trl.generation_batch_size"):
|
||||
self._check(data)
|
||||
|
||||
def test_non_grpo_skips_check(self):
|
||||
# Anything other than rl=grpo should pass through untouched, even
|
||||
# with non-divisible batch sizes — they're irrelevant to other RL
|
||||
# methods that don't use group-relative advantages.
|
||||
data = {
|
||||
"rl": "dpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
assert self._check(data) is data
|
||||
|
||||
def test_no_rl_set_skips_check(self):
|
||||
data = {
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
}
|
||||
assert self._check(data) is data
|
||||
|
||||
def test_grpo_without_num_generations_skips_check(self):
|
||||
# If num_generations isn't set, TRL uses its own default — we don't
|
||||
# have enough info to validate, so the validator must short-circuit
|
||||
# rather than guess.
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["rl"] == "grpo"
|
||||
|
||||
def test_multi_rank_group_size_check(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4, # gbs=4
|
||||
"world_size": 2, # need gbs >= 4*2 = 8
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"world_size=2"):
|
||||
self._check(data)
|
||||
|
||||
def test_multi_rank_group_size_satisfied(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
|
||||
"world_size": 2,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["gradient_accumulation_steps"] == 8
|
||||
|
||||
Reference in New Issue
Block a user