validate batch shape against num_generations at config time
Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:
- effective generation_batch_size (or mb*GA fallback) must be divisible
by trl.num_generations, with a hint pointing at the smallest GA bump
that fixes the violation
- num_generations >= 2 (group-relative advantage needs variance; with
num_gen=1 the policy never updates)
- When world_size > 1, effective gbs >= num_generations * world_size
11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
This commit is contained in:
@@ -770,6 +770,88 @@ class RLValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_batch_size_divisibility(cls, data):
|
||||
"""Surface GRPO batch-shape mismatches at config-parse time.
|
||||
|
||||
TRL's GRPOTrainer requires that the per-step generation batch size be
|
||||
evenly divisible by ``num_generations`` so that every prompt can be
|
||||
replicated exactly ``num_generations`` times. The runtime check inside
|
||||
``GRPOTrainer.__init__`` only fires after the model has been loaded —
|
||||
too late and too cryptic for the user. We replicate the check here so
|
||||
the failure is immediate and actionable.
|
||||
|
||||
Also enforces:
|
||||
- ``num_generations >= 2`` (group-relative advantage needs variance)
|
||||
- ``effective_gbs >= num_generations * world_size`` when capabilities
|
||||
indicate multiple ranks (each rank needs at least one full group)
|
||||
"""
|
||||
if data.get("rl") != "grpo":
|
||||
return data
|
||||
|
||||
trl_cfg = data.get("trl") or {}
|
||||
num_gen = trl_cfg.get("num_generations")
|
||||
if num_gen is None:
|
||||
# TRL's own default is 8 — but if the user didn't set it, we
|
||||
# don't have enough info to validate anything. Let TRL's own
|
||||
# init handle the default-vs-batch interaction.
|
||||
return data
|
||||
if num_gen < 2:
|
||||
raise ValueError(
|
||||
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
|
||||
"With num_generations=1, every group has zero advantage and "
|
||||
"the policy never updates."
|
||||
)
|
||||
|
||||
explicit_gbs = trl_cfg.get("generation_batch_size")
|
||||
if explicit_gbs is not None:
|
||||
effective_gbs = int(explicit_gbs)
|
||||
gbs_source = "trl.generation_batch_size"
|
||||
else:
|
||||
mb = data.get("micro_batch_size") or 1
|
||||
ga = data.get("gradient_accumulation_steps") or 1
|
||||
effective_gbs = int(mb) * int(ga)
|
||||
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
|
||||
|
||||
if effective_gbs % num_gen != 0:
|
||||
# Suggest the smallest GA bump that fixes it for the common case
|
||||
# where the user hasn't set generation_batch_size explicitly.
|
||||
hint = ""
|
||||
if explicit_gbs is None:
|
||||
from math import gcd
|
||||
|
||||
mb_val = int(data.get("micro_batch_size") or 1)
|
||||
# smallest GA such that mb*GA is a multiple of num_gen
|
||||
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
|
||||
suggested_ga = lcm // mb_val
|
||||
hint = (
|
||||
f" Smallest fix: set `gradient_accumulation_steps: "
|
||||
f"{suggested_ga}` (so micro_batch_size * GA = "
|
||||
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
|
||||
)
|
||||
raise ValueError(
|
||||
f"GRPO: generation batch size must be divisible by "
|
||||
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
|
||||
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
|
||||
)
|
||||
|
||||
# Multi-rank check: each rank must receive at least one full group
|
||||
# per step. Without `capabilities` populated yet (mode='before'), we
|
||||
# fall back to user-set distributed fields.
|
||||
world_size = (
|
||||
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
|
||||
)
|
||||
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
|
||||
raise ValueError(
|
||||
f"GRPO with world_size={world_size} requires effective_gbs "
|
||||
f">= num_generations * world_size = {num_gen * world_size}, "
|
||||
f"got {effective_gbs}. Increase gradient_accumulation_steps "
|
||||
f"or micro_batch_size."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class OptimizationValidationMixin:
|
||||
"""Validation methods related to optimization and performance."""
|
||||
|
||||
@@ -5,6 +5,8 @@ Covers:
|
||||
- save_strategy: 'best' requires metric_for_best_model
|
||||
- streaming=True with val_set_size > 0 is rejected
|
||||
- lora_target_modules with invalid regex patterns is rejected
|
||||
- GRPO: generation batch size must be divisible by num_generations,
|
||||
num_generations >= 2, and effective_gbs >= num_generations * world_size
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
|
||||
)
|
||||
with pytest.raises(ValueError, match="invalid regex pattern"):
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestGRPOBatchSizeValidator:
|
||||
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
|
||||
|
||||
These call the @model_validator(mode="before") classmethod directly on a
|
||||
plain dict — same input shape it receives during full Pydantic validation,
|
||||
just without dragging in unrelated fields (datasets / model loading / etc.)
|
||||
that aren't relevant to what's under test. The validator is registered on
|
||||
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
|
||||
same code path ``axolotl train`` exercises.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _check(data):
|
||||
from axolotl.utils.schemas.validation import RLValidationMixin
|
||||
|
||||
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
|
||||
|
||||
def test_divisible_passes(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
# Should return data unchanged (no exception)
|
||||
out = self._check(data)
|
||||
assert out["trl"]["num_generations"] == 4
|
||||
|
||||
def test_non_divisible_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match="num_generations"):
|
||||
self._check(data)
|
||||
|
||||
def test_non_divisible_error_includes_fix_hint(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
|
||||
self._check(data)
|
||||
|
||||
def test_num_generations_one_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"trl": {"num_generations": 1},
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"num_generations >= 2"):
|
||||
self._check(data)
|
||||
|
||||
def test_explicit_generation_batch_size_divisible_passes(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"trl": {"num_generations": 4, "generation_batch_size": 8},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["trl"]["generation_batch_size"] == 8
|
||||
|
||||
def test_explicit_generation_batch_size_non_divisible_raises(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"trl": {"num_generations": 4, "generation_batch_size": 6},
|
||||
}
|
||||
with pytest.raises(ValueError, match="trl.generation_batch_size"):
|
||||
self._check(data)
|
||||
|
||||
def test_non_grpo_skips_check(self):
|
||||
# Anything other than rl=grpo should pass through untouched, even
|
||||
# with non-divisible batch sizes — they're irrelevant to other RL
|
||||
# methods that don't use group-relative advantages.
|
||||
data = {
|
||||
"rl": "dpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
assert self._check(data) is data
|
||||
|
||||
def test_no_rl_set_skips_check(self):
|
||||
data = {
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
}
|
||||
assert self._check(data) is data
|
||||
|
||||
def test_grpo_without_num_generations_skips_check(self):
|
||||
# If num_generations isn't set, TRL uses its own default — we don't
|
||||
# have enough info to validate, so the validator must short-circuit
|
||||
# rather than guess.
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 3,
|
||||
"trl": {},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["rl"] == "grpo"
|
||||
|
||||
def test_multi_rank_group_size_check(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4, # gbs=4
|
||||
"world_size": 2, # need gbs >= 4*2 = 8
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"world_size=2"):
|
||||
self._check(data)
|
||||
|
||||
def test_multi_rank_group_size_satisfied(self):
|
||||
data = {
|
||||
"rl": "grpo",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
|
||||
"world_size": 2,
|
||||
"trl": {"num_generations": 4},
|
||||
}
|
||||
out = self._check(data)
|
||||
assert out["gradient_accumulation_steps"] == 8
|
||||
|
||||
Reference in New Issue
Block a user