validate batch shape against num_generations at config time

Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:

  - effective generation_batch_size (or mb*GA fallback) must be divisible
    by trl.num_generations, with a hint pointing at the smallest GA bump
    that fixes the violation
  - num_generations >= 2 (group-relative advantage needs variance; with
    num_gen=1 the policy never updates)
  - When world_size > 1, effective gbs >= num_generations * world_size

11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
This commit is contained in:
Wing Lian
2026-04-13 18:29:22 +00:00
parent 323da791eb
commit 80a97f192b
2 changed files with 217 additions and 0 deletions

View File

@@ -770,6 +770,88 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_grpo_batch_size_divisibility(cls, data):
"""Surface GRPO batch-shape mismatches at config-parse time.
TRL's GRPOTrainer requires that the per-step generation batch size be
evenly divisible by ``num_generations`` so that every prompt can be
replicated exactly ``num_generations`` times. The runtime check inside
``GRPOTrainer.__init__`` only fires after the model has been loaded —
too late and too cryptic for the user. We replicate the check here so
the failure is immediate and actionable.
Also enforces:
- ``num_generations >= 2`` (group-relative advantage needs variance)
- ``effective_gbs >= num_generations * world_size`` when capabilities
indicate multiple ranks (each rank needs at least one full group)
"""
if data.get("rl") != "grpo":
return data
trl_cfg = data.get("trl") or {}
num_gen = trl_cfg.get("num_generations")
if num_gen is None:
# TRL's own default is 8 — but if the user didn't set it, we
# don't have enough info to validate anything. Let TRL's own
# init handle the default-vs-batch interaction.
return data
if num_gen < 2:
raise ValueError(
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
"With num_generations=1, every group has zero advantage and "
"the policy never updates."
)
explicit_gbs = trl_cfg.get("generation_batch_size")
if explicit_gbs is not None:
effective_gbs = int(explicit_gbs)
gbs_source = "trl.generation_batch_size"
else:
mb = data.get("micro_batch_size") or 1
ga = data.get("gradient_accumulation_steps") or 1
effective_gbs = int(mb) * int(ga)
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
if effective_gbs % num_gen != 0:
# Suggest the smallest GA bump that fixes it for the common case
# where the user hasn't set generation_batch_size explicitly.
hint = ""
if explicit_gbs is None:
from math import gcd
mb_val = int(data.get("micro_batch_size") or 1)
# smallest GA such that mb*GA is a multiple of num_gen
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
suggested_ga = lcm // mb_val
hint = (
f" Smallest fix: set `gradient_accumulation_steps: "
f"{suggested_ga}` (so micro_batch_size * GA = "
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
)
raise ValueError(
f"GRPO: generation batch size must be divisible by "
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
)
# Multi-rank check: each rank must receive at least one full group
# per step. Without `capabilities` populated yet (mode='before'), we
# fall back to user-set distributed fields.
world_size = (
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
)
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
raise ValueError(
f"GRPO with world_size={world_size} requires effective_gbs "
f">= num_generations * world_size = {num_gen * world_size}, "
f"got {effective_gbs}. Increase gradient_accumulation_steps "
f"or micro_batch_size."
)
return data
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""

View File

@@ -5,6 +5,8 @@ Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
- GRPO: generation batch size must be divisible by num_generations,
num_generations >= 2, and effective_gbs >= num_generations * world_size
"""
import pytest
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
)
with pytest.raises(ValueError, match="invalid regex pattern"):
validate_config(cfg)
class TestGRPOBatchSizeValidator:
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
These call the @model_validator(mode="before") classmethod directly on a
plain dict — same input shape it receives during full Pydantic validation,
just without dragging in unrelated fields (datasets / model loading / etc.)
that aren't relevant to what's under test. The validator is registered on
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
same code path ``axolotl train`` exercises.
"""
@staticmethod
def _check(data):
from axolotl.utils.schemas.validation import RLValidationMixin
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
def test_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 4},
}
# Should return data unchanged (no exception)
out = self._check(data)
assert out["trl"]["num_generations"] == 4
def test_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="num_generations"):
self._check(data)
def test_non_divisible_error_includes_fix_hint(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
self._check(data)
def test_num_generations_one_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 1},
}
with pytest.raises(ValueError, match=r"num_generations >= 2"):
self._check(data)
def test_explicit_generation_batch_size_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 8},
}
out = self._check(data)
assert out["trl"]["generation_batch_size"] == 8
def test_explicit_generation_batch_size_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 6},
}
with pytest.raises(ValueError, match="trl.generation_batch_size"):
self._check(data)
def test_non_grpo_skips_check(self):
# Anything other than rl=grpo should pass through untouched, even
# with non-divisible batch sizes — they're irrelevant to other RL
# methods that don't use group-relative advantages.
data = {
"rl": "dpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
assert self._check(data) is data
def test_no_rl_set_skips_check(self):
data = {
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
}
assert self._check(data) is data
def test_grpo_without_num_generations_skips_check(self):
# If num_generations isn't set, TRL uses its own default — we don't
# have enough info to validate, so the validator must short-circuit
# rather than guess.
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {},
}
out = self._check(data)
assert out["rl"] == "grpo"
def test_multi_rank_group_size_check(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4, # gbs=4
"world_size": 2, # need gbs >= 4*2 = 8
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match=r"world_size=2"):
self._check(data)
def test_multi_rank_group_size_satisfied(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
"world_size": 2,
"trl": {"num_generations": 4},
}
out = self._check(data)
assert out["gradient_accumulation_steps"] == 8