From 80a97f192b0f7b7289275c3c772881433c8bfd0b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 13 Apr 2026 18:29:22 +0000 Subject: [PATCH] validate batch shape against num_generations at config time Surfaces a class of GRPO config errors at axolotl-train startup instead of letting them bubble out of GRPOTrainer.__init__ after the model loads. Three checks under RLValidationMixin.check_grpo_batch_size_divisibility: - effective generation_batch_size (or mb*GA fallback) must be divisible by trl.num_generations, with a hint pointing at the smallest GA bump that fixes the violation - num_generations >= 2 (group-relative advantage needs variance; with num_gen=1 the policy never updates) - When world_size > 1, effective gbs >= num_generations * world_size 11 unit tests cover the table: divisible/non-divisible, explicit and implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and unset num_generations. --- src/axolotl/utils/schemas/validation.py | 82 +++++++++++ .../validation/test_config_validators.py | 135 ++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index ff7813600..1df39a44f 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -770,6 +770,88 @@ class RLValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_grpo_batch_size_divisibility(cls, data): + """Surface GRPO batch-shape mismatches at config-parse time. + + TRL's GRPOTrainer requires that the per-step generation batch size be + evenly divisible by ``num_generations`` so that every prompt can be + replicated exactly ``num_generations`` times. The runtime check inside + ``GRPOTrainer.__init__`` only fires after the model has been loaded — + too late and too cryptic for the user. We replicate the check here so + the failure is immediate and actionable. + + Also enforces: + - ``num_generations >= 2`` (group-relative advantage needs variance) + - ``effective_gbs >= num_generations * world_size`` when capabilities + indicate multiple ranks (each rank needs at least one full group) + """ + if data.get("rl") != "grpo": + return data + + trl_cfg = data.get("trl") or {} + num_gen = trl_cfg.get("num_generations") + if num_gen is None: + # TRL's own default is 8 — but if the user didn't set it, we + # don't have enough info to validate anything. Let TRL's own + # init handle the default-vs-batch interaction. + return data + if num_gen < 2: + raise ValueError( + f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). " + "With num_generations=1, every group has zero advantage and " + "the policy never updates." + ) + + explicit_gbs = trl_cfg.get("generation_batch_size") + if explicit_gbs is not None: + effective_gbs = int(explicit_gbs) + gbs_source = "trl.generation_batch_size" + else: + mb = data.get("micro_batch_size") or 1 + ga = data.get("gradient_accumulation_steps") or 1 + effective_gbs = int(mb) * int(ga) + gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})" + + if effective_gbs % num_gen != 0: + # Suggest the smallest GA bump that fixes it for the common case + # where the user hasn't set generation_batch_size explicitly. + hint = "" + if explicit_gbs is None: + from math import gcd + + mb_val = int(data.get("micro_batch_size") or 1) + # smallest GA such that mb*GA is a multiple of num_gen + lcm = num_gen * mb_val // gcd(num_gen, mb_val) + suggested_ga = lcm // mb_val + hint = ( + f" Smallest fix: set `gradient_accumulation_steps: " + f"{suggested_ga}` (so micro_batch_size * GA = " + f"{mb_val * suggested_ga} is a multiple of {num_gen})." + ) + raise ValueError( + f"GRPO: generation batch size must be divisible by " + f"`trl.num_generations`. Got effective_gbs={effective_gbs} " + f"(from {gbs_source}) and num_generations={num_gen}.{hint}" + ) + + # Multi-rank check: each rank must receive at least one full group + # per step. Without `capabilities` populated yet (mode='before'), we + # fall back to user-set distributed fields. + world_size = ( + (data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1 + ) + if world_size and world_size > 1 and effective_gbs < num_gen * world_size: + raise ValueError( + f"GRPO with world_size={world_size} requires effective_gbs " + f">= num_generations * world_size = {num_gen * world_size}, " + f"got {effective_gbs}. Increase gradient_accumulation_steps " + f"or micro_batch_size." + ) + + return data + class OptimizationValidationMixin: """Validation methods related to optimization and performance.""" diff --git a/tests/utils/schemas/validation/test_config_validators.py b/tests/utils/schemas/validation/test_config_validators.py index c756f1362..fbfa79ad8 100644 --- a/tests/utils/schemas/validation/test_config_validators.py +++ b/tests/utils/schemas/validation/test_config_validators.py @@ -5,6 +5,8 @@ Covers: - save_strategy: 'best' requires metric_for_best_model - streaming=True with val_set_size > 0 is rejected - lora_target_modules with invalid regex patterns is rejected + - GRPO: generation batch size must be divisible by num_generations, + num_generations >= 2, and effective_gbs >= num_generations * world_size """ import pytest @@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator: ) with pytest.raises(ValueError, match="invalid regex pattern"): validate_config(cfg) + + +class TestGRPOBatchSizeValidator: + """GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2. + + These call the @model_validator(mode="before") classmethod directly on a + plain dict — same input shape it receives during full Pydantic validation, + just without dragging in unrelated fields (datasets / model loading / etc.) + that aren't relevant to what's under test. The validator is registered on + ``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the + same code path ``axolotl train`` exercises. + """ + + @staticmethod + def _check(data): + from axolotl.utils.schemas.validation import RLValidationMixin + + return RLValidationMixin.check_grpo_batch_size_divisibility(data) + + def test_divisible_passes(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "trl": {"num_generations": 4}, + } + # Should return data unchanged (no exception) + out = self._check(data) + assert out["trl"]["num_generations"] == 4 + + def test_non_divisible_raises(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "trl": {"num_generations": 4}, + } + with pytest.raises(ValueError, match="num_generations"): + self._check(data) + + def test_non_divisible_error_includes_fix_hint(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 3, + "trl": {"num_generations": 4}, + } + with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"): + self._check(data) + + def test_num_generations_one_raises(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "trl": {"num_generations": 1}, + } + with pytest.raises(ValueError, match=r"num_generations >= 2"): + self._check(data) + + def test_explicit_generation_batch_size_divisible_passes(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "trl": {"num_generations": 4, "generation_batch_size": 8}, + } + out = self._check(data) + assert out["trl"]["generation_batch_size"] == 8 + + def test_explicit_generation_batch_size_non_divisible_raises(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "trl": {"num_generations": 4, "generation_batch_size": 6}, + } + with pytest.raises(ValueError, match="trl.generation_batch_size"): + self._check(data) + + def test_non_grpo_skips_check(self): + # Anything other than rl=grpo should pass through untouched, even + # with non-divisible batch sizes — they're irrelevant to other RL + # methods that don't use group-relative advantages. + data = { + "rl": "dpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 3, + "trl": {"num_generations": 4}, + } + assert self._check(data) is data + + def test_no_rl_set_skips_check(self): + data = { + "micro_batch_size": 1, + "gradient_accumulation_steps": 3, + } + assert self._check(data) is data + + def test_grpo_without_num_generations_skips_check(self): + # If num_generations isn't set, TRL uses its own default — we don't + # have enough info to validate, so the validator must short-circuit + # rather than guess. + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 3, + "trl": {}, + } + out = self._check(data) + assert out["rl"] == "grpo" + + def test_multi_rank_group_size_check(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, # gbs=4 + "world_size": 2, # need gbs >= 4*2 = 8 + "trl": {"num_generations": 4}, + } + with pytest.raises(ValueError, match=r"world_size=2"): + self._check(data) + + def test_multi_rank_group_size_satisfied(self): + data = { + "rl": "grpo", + "micro_batch_size": 1, + "gradient_accumulation_steps": 8, # gbs=8 >= 4*2 + "world_size": 2, + "trl": {"num_generations": 4}, + } + out = self._check(data) + assert out["gradient_accumulation_steps"] == 8