From 7420fd4de6aa80196bddc0520039178354d1304b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 22 Apr 2026 09:05:46 -0400 Subject: [PATCH] fix async prefetch with nemogym (#3606) --- .../core/trainers/grpo/async_trainer.py | 170 ++++++++- .../integrations/nemo_gym/data_producer.py | 91 ++++- src/axolotl/integrations/nemo_gym/plugin.py | 236 ++++++++++-- src/axolotl/integrations/nemo_gym/server.py | 38 +- src/axolotl/kernels/gemma4_fused_rope.py | 168 ++++++--- src/axolotl/loaders/patch_manager.py | 27 ++ src/axolotl/monkeypatch/gemma4_hybrid_mask.py | 115 ++++++ src/axolotl/monkeypatch/tiled_mlp/patch.py | 10 +- src/axolotl/scripts/vllm_serve_lora.py | 153 ++++++++ src/axolotl/utils/schemas/validation.py | 82 ++++ tests/core/test_async_grpo.py | 192 ++++++++++ tests/integrations/test_nemo_gym.py | 354 +++++++++++++++++- tests/kernels/test_gemma4_fused_rope.py | 190 ++++++++++ tests/monkeypatch/test_gemma4_fused_attn.py | 219 +++++++++++ tests/monkeypatch/test_gemma4_hybrid_mask.py | 343 +++++++++++++++++ .../validation/test_config_validators.py | 135 +++++++ 16 files changed, 2388 insertions(+), 135 deletions(-) create mode 100644 src/axolotl/monkeypatch/gemma4_hybrid_mask.py create mode 100644 tests/monkeypatch/test_gemma4_fused_attn.py create mode 100644 tests/monkeypatch/test_gemma4_hybrid_mask.py diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index 3388687ad..4759a30b0 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -242,6 +242,85 @@ class ProducerConfig: ) +class _GroupShardedSampler: + """Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups. + + ``RepeatSampler`` yields ``num_generations`` consecutive copies of + each prompt, forming a GRPO group. 
For distributed training each + rank must see a disjoint slice of prompts (otherwise every rank + dogpiles on the first 1/world_size of the batch) while keeping each + group intact on a single rank so advantage normalization sees all + peer generations. + + ``accelerator.prepare(DataLoader)`` does not handle this correctly + for custom samplers with ``split_batches=False`` (the default): it + leaves the sampler alone and every rank replays identical indices. + This wrapper fixes that by consuming the inner sampler's full + output, chunking it into ``num_generations``-sized groups, and + round-robining whole groups across ranks. + + Intended to be used ONLY when distributed training is active + (``num_replicas > 1``); for single-rank it is a no-op but still + correct. + """ + + def __init__( + self, + inner: Any, + num_generations: int, + rank: int, + num_replicas: int, + ): + if num_generations < 1: + raise ValueError(f"num_generations must be >= 1, got {num_generations}") + if num_replicas < 1: + raise ValueError(f"num_replicas must be >= 1, got {num_replicas}") + if not (0 <= rank < num_replicas): + raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}") + self.inner = inner + self.num_generations = num_generations + self.rank = rank + self.num_replicas = num_replicas + + def __iter__(self): + all_indices = list(self.inner) + if len(all_indices) % self.num_generations != 0: + raise ValueError( + f"inner sampler yielded {len(all_indices)} indices, " + f"not a multiple of num_generations={self.num_generations}" + ) + # Chunk the flat index sequence into groups of num_generations + # consecutive indices. ``RepeatSampler`` guarantees that each + # group contains num_generations copies of the same prompt id. + groups = [ + all_indices[i : i + self.num_generations] + for i in range(0, len(all_indices), self.num_generations) + ] + # Round-robin whole groups across ranks. Round-robin (vs. 
+ # contiguous chunking) preserves approximate shuffled order on + # each rank even when the group count is small relative to the + # world size. + for group in groups[self.rank :: self.num_replicas]: + yield from group + + def __len__(self): + try: + inner_len = len(self.inner) + except TypeError: + # Non-sized inner sampler — we can't know the per-rank + # length without materializing. Return 0 as a hint that the + # DataLoader should fall back to iteration. + return 0 + total_groups = inner_len // self.num_generations + # Ceiling division for the trailing groups that don't divide + # evenly — extra groups go to the first ``total_groups % + # num_replicas`` ranks, matching the round-robin above. + my_groups = ( + total_groups + self.num_replicas - self.rank - 1 + ) // self.num_replicas + return my_groups * self.num_generations + + class DataProducer(ABC): """Abstract base class for online data producers. @@ -556,6 +635,34 @@ class GRPODataProducer(BaseDataProducer): seed=self._seed, ) + # Shard the sampler across distributed ranks so each rank sees + # a disjoint slice of prompts. ``RepeatSampler`` groups each + # prompt with ``num_generations`` consecutive copies — our + # wrapper round-robins WHOLE groups across ranks so all + # generations of a given prompt stay on the same rank (needed + # for GRPO advantage normalization within a group). + # + # Without this, ``accelerator.prepare(dl)`` with the default + # ``split_batches=False`` leaves the custom sampler alone, so + # every rank iterates the identical index sequence and the + # cluster dogpiles on the first 1/world_size of the prompts. 
+ num_replicas = max(1, trainer.accelerator.num_processes) + if num_replicas > 1: + sampler = _GroupShardedSampler( + inner=sampler, + num_generations=self._num_generations, + rank=trainer.accelerator.process_index, + num_replicas=num_replicas, + ) + logger.info( + "[RANK:%d] _GroupShardedSampler active " + "(num_replicas=%d, num_generations=%d, gen_batch=%d)", + trainer.accelerator.process_index, + num_replicas, + self._num_generations, + self._generation_batch_size, + ) + # Use identity collator (same as stock GRPOTrainer) def _identity(x): return x @@ -574,12 +681,11 @@ class GRPODataProducer(BaseDataProducer): rank=trainer.args.process_index, ), ) - self._prompt_dl = trainer.accelerator.prepare(dl) - - # Don't let accelerator track this dataloader - acc_dls = trainer.accelerator._dataloaders - if self._prompt_dl in acc_dls: - acc_dls.remove(self._prompt_dl) + # Skip accelerator.prepare — we're handling per-rank sharding + # ourselves via ``_GroupShardedSampler``. ``prepare()`` would + # otherwise try to wrap the DataLoader with its own sharding + # logic which does not understand our group structure. + self._prompt_dl = dl self._prompt_iter = iter(self._prompt_dl) @@ -1103,11 +1209,22 @@ class AsyncGRPOTrainer(GRPOTrainer): - vllm_lora_sync: saves adapter to filesystem, vLLM loads natively - PEFT no-merge: computes merged weights as new tensors, NCCL broadcast - Non-PEFT: stock sync_weights via merge_adapter + NCCL + + This is the canonical sync trigger and runs in BOTH async and + synchronous modes from ``_prepare_inputs_with_data_producer`` / + ``_prepare_inputs_legacy_async``. The ``_generate_single_turn`` + patch is a parallel backup for non-data-producer paths (vanilla + GRPO without NeMo Gym), where the data producer is bypassed + entirely and TRL's stock generate-then-sync flow is used instead. 
""" - if not (self.use_vllm and self.args.async_prefetch): + if not self.use_vllm: return step = self.state.global_step - interval = self.args.vllm_sync_interval + # Default to syncing every step when no interval is configured — + # otherwise ``step % None`` would TypeError, and the previous + # behavior of crashing on the first sync was strictly worse than + # the standard "sync every optimizer step". + interval = self.args.vllm_sync_interval or 1 if step != self._last_synced_step and step % interval == 0: if step == 0: logger.info("Skipping vLLM weight sync at step 0 (no training yet)") @@ -1202,13 +1319,42 @@ class AsyncGRPOTrainer(GRPOTrainer): # Permanently replace vllm_generation.sync_weights with our custom # sync to avoid merge_adapter (fails on FP8 / races with training). - # For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights - # handles the sync with proper interval tracking. + # + # The design has two modes that have to be threaded carefully: + # + # - Async prefetch ON: BG generation thread can't safely call + # sync_weights mid-rollout (it races with the trainer's optimizer + # step and can corrupt weights). We no-op the stock sync hook and + # drive sync ourselves from ``_maybe_sync_vllm_weights`` after the + # optimizer step on the main thread. + # + # - Async prefetch OFF (synchronous mode): TRL's stock + # ``_generate_single_turn`` calls ``sync_weights`` once per step + # boundary. There's no BG thread to race with, and + # ``_maybe_sync_vllm_weights`` short-circuits with + # ``if not async_prefetch: return``, so we MUST wire the stock + # hook directly to our LoRA sync helper — otherwise nothing ever + # pushes weights to vLLM and the trainer becomes a no-op (vLLM + # keeps serving the base model, every rollout in every group + # produces identical outputs, advantages are zero, optimizer + # step gets skipped, repeat). 
if not getattr(self, "_patched_sync_weights", False): if self.use_vllm and hasattr(self, "vllm_generation"): if getattr(self.args, "vllm_lora_sync", False): - # No-op: LoRA sync is driven by _maybe_sync_vllm_weights - self.vllm_generation.sync_weights = lambda: None + if getattr(self.args, "async_prefetch", False): + # Async: drive sync from main thread via + # _maybe_sync_vllm_weights instead. + self.vllm_generation.sync_weights = lambda: None + else: + # Sync mode: TRL's _generate_single_turn already + # calls sync_weights once per step boundary. Wire + # it directly to our LoRA filesystem sync helper. + sync_helper = self._sync_lora_adapter + + def _lora_filesystem_sync(): + sync_helper() + + self.vllm_generation.sync_weights = _lora_filesystem_sync self._patched_sync_weights = True else: from accelerate.utils import is_peft_model diff --git a/src/axolotl/integrations/nemo_gym/data_producer.py b/src/axolotl/integrations/nemo_gym/data_producer.py index 64b76d780..3a9635d15 100644 --- a/src/axolotl/integrations/nemo_gym/data_producer.py +++ b/src/axolotl/integrations/nemo_gym/data_producer.py @@ -110,11 +110,36 @@ class NemoGymDataProducer(GRPODataProducer): item["agent_ref"] = full_item["agent_ref"] dataset_items.append(item) - # Expand by num_generations (agent produces one rollout per call) - expanded_items = [] - for item in dataset_items: - for _ in range(self._num_generations): - expanded_items.append(item) + # NOTE: do NOT re-expand by num_generations here. + # ``RepeatSampler(mini_repeat_count=num_generations)`` already + # yields ``num_generations`` consecutive copies of each unique + # prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank * + # num_generations)`` items — one entry per rollout. Expanding + # again here would fire ``num_generations^2`` rollouts per + # prompt per rank and make every step dogpile on a handful of + # tasks. + expanded_items = dataset_items + + # Diagnostic: log what this rank is about to fire. 
+ try: + import collections + + iid_counts: collections.Counter[str | None] = collections.Counter() + for it in dataset_items: + iid_counts[ + (it.get("responses_create_params", {}).get("metadata") or {}).get( + "instance_id" + ) + ] += 1 + LOG.info( + "[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s", + trainer.accelerator.process_index, + len(dataset_items), + len(iid_counts), + list(iid_counts.most_common(5)), + ) + except Exception: + pass # Call NeMo Gym agents loop = asyncio.new_event_loop() @@ -140,6 +165,7 @@ class NemoGymDataProducer(GRPODataProducer): logprobs_list = [] rewards_list = [] + num_turns_list: list[int] = [] for resp in responses: parsed = _parse_agent_response(resp, eos_token_id) prompt_ids_list.append(parsed["prompt_ids"]) @@ -147,6 +173,7 @@ class NemoGymDataProducer(GRPODataProducer): env_mask_list.append(parsed["env_mask"]) logprobs_list.append(parsed["logprobs"]) rewards_list.append(parsed["reward"]) + num_turns_list.append(parsed.get("num_turns", 0)) # Pad to tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -179,22 +206,48 @@ class NemoGymDataProducer(GRPODataProducer): tool_mask = [torch.tensor(m, device=device) for m in env_mask_list] tool_mask = pad(tool_mask, padding_value=1, padding_side="right") - # Inject rewards into inputs so _compute_deferred_scores can use them - # The deferred scoring path calls _calculate_rewards which reads reward_funcs. - # Our passthrough reward_fn reads "env_reward" from kwargs. + # Inject per-rollout reward + num_turns into each input. Since + # ``RepeatSampler`` already yields ``num_generations`` copies of + # each prompt, ``inputs`` has ONE entry per rollout (matching + # ``rewards_list`` 1:1). No per-prompt grouping happens here — + # GRPO advantage normalization is the trainer's job downstream. 
+ assert len(inputs) == len(rewards_list), ( + f"rewards/inputs length mismatch: " + f"{len(rewards_list)} rewards vs {len(inputs)} inputs" + ) for i, inp in enumerate(inputs): - # Each input gets rewards for its num_generations rollouts - start = i * self._num_generations - end = start + self._num_generations - inp["env_reward"] = rewards_list[start:end] + inp["env_reward"] = rewards_list[i] + inp["num_turns"] = num_turns_list[i] - # Expand inputs to match expanded rollouts (num_generations copies) - expanded_inputs = [] - for inp in inputs: - for g in range(self._num_generations): - expanded_inp = dict(inp) - expanded_inp["env_reward"] = inp["env_reward"][g] - expanded_inputs.append(expanded_inp) + # One expanded_input per rollout (already correct count because + # inputs has num_generations copies baked in by the sampler). + expanded_inputs = [dict(inp) for inp in inputs] + + # Log rollout-level stats to wandb from rank 0. These are the + # true agent-side metrics (not the tokenized TRL view) — so + # num_turns reflects how many /run iterations each rollout + # actually took before finishing or hitting max_turns. 
+ if is_main and num_turns_list: + try: + import wandb + + if wandb.run is not None: + import statistics as _stats + + nonzero = sum(1 for r in rewards_list if r > 0) + log_payload = { + "rollout/num_turns/mean": float(_stats.mean(num_turns_list)), + "rollout/num_turns/min": float(min(num_turns_list)), + "rollout/num_turns/max": float(max(num_turns_list)), + "rollout/reward/mean": float(_stats.mean(rewards_list)), + "rollout/reward/nonzero_frac": ( + nonzero / len(rewards_list) if rewards_list else 0.0 + ), + "rollout/n_samples": float(len(rewards_list)), + } + wandb.log(log_payload, commit=False) + except Exception as exc: # never let metric logging break training + LOG.warning("rollout wandb log failed: %s", exc) # Decode completions for reward functions completions = trainer.processing_class.batch_decode( diff --git a/src/axolotl/integrations/nemo_gym/plugin.py b/src/axolotl/integrations/nemo_gym/plugin.py index 14de684cf..b85e344db 100644 --- a/src/axolotl/integrations/nemo_gym/plugin.py +++ b/src/axolotl/integrations/nemo_gym/plugin.py @@ -19,6 +19,7 @@ Supports two modes: from __future__ import annotations import os +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Union from axolotl.integrations.base import BasePlugin @@ -30,6 +31,107 @@ if TYPE_CHECKING: LOG = get_logger(__name__) +# ---- vLLM weight-sync transport probe ------------------------------------ + + +@dataclass +class VLLMWeightSyncCapabilities: + """What weight-sync routes a vLLM server actually exposes. + + Discovered once at ``pre_model_load`` time by fetching the server's + ``/openapi.json``. Drives the transport-selection table below. 
+ """ + + nccl: bool = False # /init_communicator/ + /update_named_param/ + lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native) + lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension) + http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension) + probed: bool = False + probe_error: str | None = None + routes: list[str] = field(default_factory=list) + + @property + def any_full_param_sync(self) -> bool: + """True if at least one transport can push full-model weights.""" + return self.nccl or self.http_full + + @property + def any_lora_sync(self) -> bool: + """True if at least one transport can push LoRA adapters.""" + return self.lora_filesystem or self.lora_axolotl or self.nccl + + +def probe_vllm_weight_sync( + base_url: str, timeout: float = 5.0 +) -> VLLMWeightSyncCapabilities: + """Detect which weight-sync routes the configured vLLM server exposes. + + Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport + we care about is mounted as a POST route there. Falls back to all-False + on any error so the caller can still decide what to do (typically: raise + a clear error rather than silently no-op). + """ + import requests + + caps = VLLMWeightSyncCapabilities() + try: + r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout) + r.raise_for_status() + spec = r.json() + routes = sorted((spec.get("paths") or {}).keys()) + caps.routes = routes + caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes + caps.lora_filesystem = "/v1/load_lora_adapter" in routes + caps.lora_axolotl = "/set_lora_adapter/" in routes + caps.http_full = "/http_update_weights/" in routes + caps.probed = True + except Exception as exc: + caps.probe_error = f"{type(exc).__name__}: {exc}" + LOG.warning( + "NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. 
" + "Will fall back to LoRA-only behavior.", + base_url, + caps.probe_error, + ) + return caps + + +def select_weight_sync_transport( + caps: VLLMWeightSyncCapabilities, + *, + has_lora: bool, + vllm_lora_sync_pref: bool, +) -> str: + """Pick the right transport for a (server caps, model type) combo. + + Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or + ``"none"``. The caller decides what to do with ``"none"`` (typically: + raise an error explaining the misconfiguration). + + Selection table: + LoRA model + lora endpoint + lora-sync pref → lora_filesystem + LoRA model + lora endpoint → lora_filesystem + LoRA model + nccl endpoint → nccl (broadcast merged adapter) + Full model + nccl endpoint → nccl + Full model + http endpoint → http_full + anything else → none + """ + if has_lora: + if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref: + return "lora_filesystem" + if caps.lora_filesystem or caps.lora_axolotl: + return "lora_filesystem" + if caps.nccl: + return "nccl" + return "none" + # Full-parameter model + if caps.nccl: + return "nccl" + if caps.http_full: + return "http_full" + return "none" + + class NemoGymPlugin(BasePlugin): """Plugin for NVIDIA NeMo Gym integration with Axolotl. @@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin): self._reward_fn = None self._dataset_lookup = None self._agent_servers = {} + self._vllm_caps: VLLMWeightSyncCapabilities | None = None def get_input_args(self): return "axolotl.integrations.nemo_gym.NemoGymArgs" def pre_model_load(self, cfg): - """Apply monkeypatches before trainer creation.""" + """Probe vLLM weight-sync routes and conditionally bypass NCCL init. + + Replaces the previous unconditional ``init_communicator`` monkey-patch + with a probe of the configured vLLM server's ``/openapi.json``. We only + bypass NCCL init when the server we're talking to actually lacks the + ``/init_communicator/`` route (i.e. 
stock ``vllm serve``); against + TRL/axolotl serve modules that DO expose NCCL routes, we leave the + standard TRL flow alone so full-finetune training can sync weights. + """ if not cfg.nemo_gym_enabled: return - # Always skip NCCL communicator init in NeMo Gym mode. - # NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL - # colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers. trl_cfg = getattr(cfg, "trl", None) - if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server": + if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"): + return + + host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1" + port = getattr(trl_cfg, "vllm_server_port", None) or 8000 + base_url = f"http://{host}:{port}" + self._vllm_caps = probe_vllm_weight_sync(base_url) + + if self._vllm_caps.probed: + LOG.info( + "NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s " + "lora_axolotl=%s http_full=%s", + base_url, + self._vllm_caps.nccl, + self._vllm_caps.lora_filesystem, + self._vllm_caps.lora_axolotl, + self._vllm_caps.http_full, + ) + + # Only bypass NCCL init when the server doesn't speak it. If NCCL is + # available we leave VLLMClient.init_communicator alone so the + # standard TRL sync flow can run for full-parameter training. + if not self._vllm_caps.nccl: self._patch_skip_nccl_init() def _patch_skip_nccl_init(self): """Monkeypatch VLLMClient.init_communicator to no-op. - NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA - serve script). The NCCL communicator is not needed and fails with both - vLLM V1 engine and standard OpenAI server mode. + Only called when the configured vLLM server doesn't expose + ``/init_communicator/`` (e.g. stock ``vllm serve``). In that case + TRL's standard ``init_communicator`` would 404 inside trainer + construction; we no-op it so the LoRA filesystem path can install + its own sync in ``post_trainer_create``. 
""" try: from trl.generation.vllm_client import VLLMClient VLLMClient._original_init_communicator = VLLMClient.init_communicator VLLMClient.init_communicator = lambda self, **kwargs: LOG.info( - "Skipping NCCL init_communicator (LoRA sync mode)" + "Skipping NCCL init_communicator (server has no /init_communicator/)" + ) + LOG.info( + "Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)" ) - LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync") except Exception as exc: LOG.warning(f"Failed to patch VLLMClient: {exc}") @@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin): verify_timeout = cfg.nemo_gym_verify_timeout or 30 multi_turn = cfg.nemo_gym_multi_turn or False - # Handle weight sync. NeMo Gym skips NCCL init, so we need to either: - # - Install LoRA sync (when vllm_lora_sync=True) - # - Or no-op sync_weights (when using standard vLLM server) + # Pick a weight-sync transport based on what the configured vLLM + # server actually exposes (see ``pre_model_load`` probe) and what + # kind of model we're training. The selection table is documented + # in ``select_weight_sync_transport``. 
trl_cfg = getattr(cfg, "trl", None) if hasattr(trainer, "vllm_generation") and trainer.vllm_generation: vllm_gen = trainer.vllm_generation - if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False): + adapter = getattr(cfg, "adapter", None) + has_lora = adapter in ("lora", "qlora") + vllm_lora_sync_pref = bool( + trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False) + ) + caps = self._vllm_caps or VLLMWeightSyncCapabilities() + transport = select_weight_sync_transport( + caps, + has_lora=has_lora, + vllm_lora_sync_pref=vllm_lora_sync_pref, + ) + + if transport == "lora_filesystem": self._setup_lora_sync(trainer) - # Verify the vLLM server supports runtime LoRA loading self._check_lora_endpoint(vllm_gen) - else: - # No NCCL, no LoRA sync — skip all weight sync paths - vllm_gen.sync_weights = lambda: LOG.debug( - "Weight sync skipped (NeMo Gym mode)" + LOG.info("NeMo Gym weight sync: LoRA filesystem") + elif transport == "nccl": + # Standard TRL NCCL path. We leave ``VLLMClient.init_communicator`` + # alone (pre_model_load only patched it when the probe found no + # NCCL route) so the trainer's normal weight-sync flow runs. + LOG.info( + "NeMo Gym weight sync: NCCL (server exposes /init_communicator/)" ) - type(vllm_gen).sync_weights = lambda self: LOG.debug( - "Weight sync skipped (NeMo Gym mode)" + elif transport == "http_full": + # Full-parameter HTTP sync — implementation lands in step 3. + # For now, fail loudly so users know the path is detected but + # not yet wired up, instead of silently no-oping like before. + raise NotImplementedError( + "NeMo Gym + full fine-tune + HTTP weight sync is detected " + "but the client-side sync helper is not yet implemented " + "(planned). Use `adapter: lora|qlora` for now, or use a " + "vLLM serve module that exposes /init_communicator/ for " + "NCCL sync." 
) - # Also patch the async trainer's internal sync method - if hasattr(trainer, "_maybe_sync_vllm_weights"): - trainer._maybe_sync_vllm_weights = lambda: LOG.debug( - "Async weight sync skipped (NeMo Gym mode)" + else: # transport == "none" + # No viable sync path. Build a precise error so the user knows + # exactly what's missing and how to fix it. + if not caps.probed: + msg = ( + "could not probe the vLLM server's " + f"/openapi.json: {caps.probe_error}. " + "Verify that vLLM is reachable at " + f"{getattr(trl_cfg, 'vllm_server_host', '?')}:" + f"{getattr(trl_cfg, 'vllm_server_port', '?')}." ) - LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)") + elif has_lora: + msg = ( + "the vLLM server has neither NCCL routes " + "(/init_communicator/) nor a LoRA-loading route " + "(/v1/load_lora_adapter or /set_lora_adapter/). " + "Restart vLLM with `--enable-lora --max-lora-rank N " + "VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock " + "server, or use `axolotl vllm-serve` for the " + "NCCL-capable serve module." + ) + else: + msg = ( + "the vLLM server exposes no full-parameter sync route " + "(/init_communicator/ for NCCL or /http_update_weights/ " + "for HTTP). Use `axolotl vllm-serve` (which has both) " + "or set `adapter: lora|qlora`." + ) + raise ValueError( + f"NeMo Gym: no usable weight-sync transport — {msg} Without " + "weight sync the trainer's gradient updates never reach the " + "rollout policy (functionally a no-op trainer)." 
+ ) if multi_turn: self._wire_multi_turn(cfg, trainer, model_name, verify_timeout) diff --git a/src/axolotl/integrations/nemo_gym/server.py b/src/axolotl/integrations/nemo_gym/server.py index 0af9b3b71..bd619569e 100644 --- a/src/axolotl/integrations/nemo_gym/server.py +++ b/src/axolotl/integrations/nemo_gym/server.py @@ -130,21 +130,41 @@ def start_servers( ) -def get_server_configs(head_port: int = 11000) -> dict: +def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict: """Fetch the global config from the NeMo Gym head server. + Retries up to 3 times with exponential backoff. The default per-attempt + timeout is 30s (raised from the original 5s) because head servers can + be slow to respond when they're concurrently serving rollouts from a + prior training run. A 5s timeout was empirically too tight to survive + a kill-and-relaunch cycle. + Returns: Dict mapping server_name -> server config. """ - response = requests.get( - f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5 + url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml" + last_exc: Exception | None = None + for attempt in (1, 2, 3): + try: + response = requests.get(url, timeout=timeout) + response.raise_for_status() + result = yaml.safe_load(response.text) + # NeMo Gym head server double-encodes: YAML string inside a YAML string + if isinstance(result, str): + result = yaml.safe_load(result) + return result + except (requests.exceptions.RequestException, OSError) as exc: + last_exc = exc + LOG.warning( + "NeMo Gym head probe attempt %d/3 failed: %s. 
Retrying...", + attempt, + type(exc).__name__, + ) + if attempt < 3: + time.sleep(2.0 * attempt) + raise RuntimeError( + f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}" ) - response.raise_for_status() - result = yaml.safe_load(response.text) - # NeMo Gym head server double-encodes: YAML string inside a YAML string - if isinstance(result, str): - result = yaml.safe_load(result) - return result def get_agent_servers( diff --git a/src/axolotl/kernels/gemma4_fused_rope.py b/src/axolotl/kernels/gemma4_fused_rope.py index f3b68e603..f98e9a3de 100644 --- a/src/axolotl/kernels/gemma4_fused_rope.py +++ b/src/axolotl/kernels/gemma4_fused_rope.py @@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel( RSTD_ptr, RSTD_row_stride, n_cols, + n_rot, n_heads, eps, HAS_WEIGHT: tl.constexpr, @@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel( ): """ Fused forward: - x_norm = x / rms(x) [* weight] (RMSNorm) - y = x_norm * cos + rotate_half(x_norm) * sin (RoPE) + x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols) + y[..., :n_rot] = rope(x_norm[..., :n_rot]) + y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary) - rotate_half swaps first/second halves and negates the first: - rotate_half([a, b]) = [-b, a] + rotate_half swaps first/second halves and negates the first, restricted + to the rotary span [0, n_rot): + rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2 + + For the partial-rotary pass-through region we load cos with default 1.0 + and sin with default 0.0 outside [0, n_rot), so the same formula + `Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`. cos/sin are indexed by row_idx // n_heads to handle per-head broadcast - (cos/sin have shape (B*S, D) while X has shape (B*S*H, D)). + (cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)). 
""" row_idx = tl.program_id(0).to(tl.int64) - # cos/sin row: divide by n_heads since cos/sin are (B*S, D) + # cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot) cs_row_idx = row_idx // n_heads col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - half_dim = n_cols // 2 + rot_mask_col = col_offsets < n_rot + half_rot = n_rot // 2 # Load input row X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0) X_dtype = X_row.dtype X_fp32 = X_row.to(tl.float32) - # RMSNorm: compute 1/rms + # RMSNorm: compute 1/rms over the full row (rotary + pass-through) mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols rstd = rsqrt(mean_sq + eps) tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd) @@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel( W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32) X_norm = X_norm * W_row - # RoPE: load cos/sin (broadcast across heads) + # RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get + # cos=1, sin=0 so the formula leaves X_norm untouched. cos_row = tl.load( - COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0 + COS_ptr + cs_row_idx * COS_row_stride + col_offsets, + mask=rot_mask_col, + other=1.0, ).to(tl.float32) sin_row = tl.load( - SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0 + SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, + mask=rot_mask_col, + other=0.0, ).to(tl.float32) - # rotate_half: for col < half_dim, take -X_norm[col + half_dim] - # for col >= half_dim, take X_norm[col - half_dim] + # rotate_half within [0, n_rot): + # for col < half_rot: take -X_norm[col + half_rot] + # for col in [half_rot, n_rot): take X_norm[col - half_rot] + # For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out). 
rot_offsets = tl.where( - col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim + col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot ) - rot_mask = rot_offsets < n_cols + rot_load_mask = (rot_offsets < n_cols) & rot_mask_col X_rot = tl.load( - X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0 + X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0 ).to(tl.float32) # Re-normalize the rotated values X_rot_norm = X_rot * rstd if HAS_WEIGHT: - W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to( - tl.float32 - ) + W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32) X_rot_norm = X_rot_norm * W_rot # Negate the first half (rotate_half negates x2, which becomes the first half) - sign = tl.where(col_offsets < half_dim, -1.0, 1.0) + sign = tl.where(col_offsets < half_rot, -1.0, 1.0) X_rot_norm = X_rot_norm * sign # Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin @@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel( dW_row_stride, n_rows, n_cols, + n_rot, n_heads, rows_per_program, HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ - Backward for Y = RoPE(RMSNorm(X, W)) + Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary + (`n_rot <= n_cols`). + + For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the + output is just the normalized row, so dN[col] = dY[col] (achieved by + loading cos with default 1.0 and forcing the rotate-half contribution + to zero outside the rotary span). + cos/sin indexed by row_idx // n_heads for per-head broadcast. 
""" row_block_id = tl.program_id(0).to(tl.int64) @@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel( row_end = min((row_block_id + 1) * rows_per_program, n_rows) col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - half_dim = n_cols // 2 + rot_mask_col = col_offsets < n_rot + half_rot = n_rot // 2 dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) @@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel( rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride) cos_row = tl.load( - COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0 + COS_ptr + cs_row_idx * COS_row_stride + col_offsets, + mask=rot_mask_col, + other=1.0, ).to(tl.float32) - # dN = dY * cos + rotate_half^T(dY * sin) + # dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span) # rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half) # - # Compute rotate_half_transpose(dY * sin) by loading dY and sin at - # rotated offsets directly: dY[rot] * sin[rot] * adj_sign - # This is equivalent to rotating (dY * sin) because the rotation - # just permutes which elements are multiplied. + # For col >= n_rot the formula must collapse to dN = dY (since the + # forward is just a pass-through). cos defaults to 1.0 above; the + # rotate-half contribution is masked to zero below. 
rot_offsets = tl.where( - col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim + col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot ) - rot_mask = rot_offsets < n_cols + rot_load_mask = (rot_offsets < n_cols) & rot_mask_col dY_rot = tl.load( dY_ptr + row_idx * dY_row_stride + rot_offsets, - mask=rot_mask & mask, + mask=rot_load_mask, other=0, ).to(tl.float32) sin_rot = tl.load( SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets, - mask=rot_mask & mask, + mask=rot_load_mask, other=0, ).to(tl.float32) - adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0) - dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign + adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0) + rotate_term = dY_rot * sin_rot * adj_sign + # Zero out rotate-half contribution outside the rotary span. + rotate_term = tl.where(rot_mask_col, rotate_term, 0.0) + dN = dY_row * cos_row + rotate_term # Pre-weight normalized: n = rstd * x n = X_row * rstd @@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel( ) -def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): +def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot): """ Args: X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D) W: (head_dim,) or None — RMSNorm weight - cos: (B*S, head_dim) — position embeddings (broadcast across heads) - sin: (B*S, head_dim) — position embeddings (broadcast across heads) + cos: (B*S, n_rot) — position embeddings (broadcast across heads) + sin: (B*S, n_rot) — position embeddings (broadcast across heads) eps: float n_heads: int — number of attention heads (for cos/sin indexing) + n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for + partial rotary). Must be even and ``<= head_dim``. 
Returns: Y, X_saved, RSTD, BLOCK_SIZE, num_warps """ @@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): RSTD, RSTD.stride(0), n_cols, + n_rot, n_heads, eps, HAS_WEIGHT=has_weight, @@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): return Y, X, RSTD, BLOCK_SIZE, num_warps -def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps): +def rms_norm_rope_backward( + dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps +): n_rows, n_cols = dY.shape has_weight = W is not None @@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa _dW.stride(0), n_rows, n_cols, + n_rot, n_heads, rows_per_program, HAS_WEIGHT=has_weight, @@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa class FusedRMSNormRoPEFunction(torch.autograd.Function): @staticmethod @ensure_contiguous - def forward(ctx, X, W, cos, sin, eps, n_heads): + def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot): """ - X: (B*S*H, head_dim) - W: (head_dim,) or None - cos: (B*S, head_dim) — broadcast across heads - sin: (B*S, head_dim) — broadcast across heads + X: (B*S*H, head_dim) + W: (head_dim,) or None + cos: (B*S, n_rot) — broadcast across heads + sin: (B*S, n_rot) — broadcast across heads n_heads: int + n_rot: int — rotary dim (<= head_dim) """ Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward( X, @@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function): sin, eps, n_heads, + n_rot, ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.n_heads = n_heads + ctx.n_rot = n_rot ctx.has_weight = W is not None ctx.save_for_backward(X_saved, W, cos, sin, RSTD) return Y @@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function): sin, RSTD, ctx.n_heads, + ctx.n_rot, ctx.BLOCK_SIZE, ctx.num_warps, ) - return dX, dW, None, None, None, None + return dX, dW, None, None, None, 
None, None def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6): """ - Apply fused RMSNorm + RoPE. + Apply fused RMSNorm + (partial) RoPE. Args: x: (batch, seq_len, num_heads, head_dim) — after projection + view weight: (head_dim,) — RMSNorm weight, or None for no-scale norm - cos: (batch, seq_len, head_dim) — from RotaryEmbedding - sin: (batch, seq_len, head_dim) — from RotaryEmbedding + cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot`` + must be even and ``<= head_dim``. When ``n_rot < head_dim`` + the trailing ``head_dim - n_rot`` columns are RMSNorm-only + (partial-rotary pass-through), matching stock Gemma 4 with + ``partial_rotary_factor < 1.0``. + sin: (batch, seq_len, n_rot) — same shape as ``cos`` eps: float — RMSNorm epsilon Returns: @@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6): """ shape = x.shape # (B, S, H, D) B, S, H, D = shape + n_rot = cos.shape[-1] + if sin.shape[-1] != n_rot: + raise ValueError( + f"cos and sin must have the same last dim, got cos={cos.shape[-1]} " + f"sin={sin.shape[-1]}" + ) + if n_rot > D: + raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})") + if n_rot % 2 != 0: + raise ValueError(f"rotary dim must be even, got {n_rot}") + # Flatten to 2D: (B*S*H, D) x_flat = x.reshape(-1, D).contiguous() - # Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast - # by dividing the row_idx by H to get the cos/sin row - cos_flat = cos.reshape(B * S, D).contiguous() - sin_flat = sin.reshape(B * S, D).contiguous() + # cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when + # all sequences share the same rotary positions). The kernel needs a + # dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly + # onto a single (b, s) pair, so expand-then-contiguous to materialize + # the per-batch broadcast. Expand is a no-op when B == cos.shape[0]. 
+ if cos.shape[0] != B: + if cos.shape[0] != 1: + raise ValueError( + f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal " + f"to x batch dim ({B})" + ) + cos = cos.expand(B, S, n_rot) + sin = sin.expand(B, S, n_rot) + cos_flat = cos.reshape(B * S, n_rot).contiguous() + sin_flat = sin.reshape(B * S, n_rot).contiguous() - y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H) + y_flat = FusedRMSNormRoPEFunction.apply( + x_flat, weight, cos_flat, sin_flat, eps, H, n_rot + ) return y_flat.view(shape) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 873965516..01d9997d7 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -156,6 +156,14 @@ class PatchManager: # which would clobber any earlier fix. self._fix_nemotron_h_conversion_mapping() + # Gemma 4 hybrid attention runs here in post-build (NOT post-load): + # the per-layer ``self_attn.config._attn_implementation="sdpa"`` + # override needs to walk the raw model tree, which is broken by + # the post-load PEFT wrapping. The accompanying + # ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and + # installation-time-independent, so both halves of the fix live + # cleanly in the same call even though one is instance-scoped + # and the other is module-scoped. self._apply_gemma_hybrid_attention(model) self._finalize_moe_expert_quantization(model) @@ -172,12 +180,23 @@ class PatchManager: which exceeds flash attention's supported size. This patch loads the model with flash_attention_2 for the sliding window layers (head_dim=256), then gives each global layer a shallow-copied config with _attn_implementation="sdpa". + + We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask` + which fixes the corresponding mask construction inside + ``Gemma4TextModel.forward``. 
Without it, the per-layer SDPA config + override is not enough — the forward still builds a 2D FA2-format mask + at the model level and the SDPA layers crash at long context lengths + with ``RuntimeError: The expanded size of the tensor ... must match``. """ if not self.cfg.gemma4_hybrid_attn_impl: return import copy + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + patch_gemma4_hybrid_mask() + # Navigate to the module that has 'layers' - varies by model structure: # Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers # Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers @@ -391,6 +410,14 @@ class PatchManager: patch_qwen3_5_vlm_flash_attention() if self.cfg.model_config_type in ("gemma4", "gemma4_text"): + # The fused attn path is now compatible with + # ``gemma4_hybrid_attn_impl``: the kernel handles partial + # rotary (cos.shape[-1] < head_dim) and the fused forward + # mirrors the current ``Gemma4TextAttention.forward`` API + # for shared kv (read from / write to + # ``past_key_values.shared_layers``). See + # ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md`` + # for the history. from axolotl.monkeypatch.models.gemma4.fused_attn import ( patch_gemma4_fused_attn, ) diff --git a/src/axolotl/monkeypatch/gemma4_hybrid_mask.py b/src/axolotl/monkeypatch/gemma4_hybrid_mask.py new file mode 100644 index 000000000..17b8cf053 --- /dev/null +++ b/src/axolotl/monkeypatch/gemma4_hybrid_mask.py @@ -0,0 +1,115 @@ +"""Hybrid attention mask fix for Gemma 4. + +Gemma 4 has full-attention (global) layers with ``head_dim=512`` which +exceeds flash-attention-2's supported size. Axolotl's hybrid-attention +patch in ``patch_manager._apply_gemma_hybrid_attention`` works around +this by forcing ``_attn_implementation="sdpa"`` on each global layer's +``self_attn.config``, leaving sliding-window layers on FA2. 
+ +The per-layer config override alone is insufficient, however: +``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict +using the **model-level** config and passes the mapped mask to each +decoder layer. With FA2 still set at the model level, the ``full_attention`` +entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask. +The global layers then fail with:: + + RuntimeError: The expanded size of the tensor (S) must match the existing + size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor + sizes: [B, S] + +...when the sequence length grows past roughly 7k tokens. + +This module fixes the symptom by monkey-patching ``create_causal_mask`` in +``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT +the original in ``masking_utils``. The wrapper forces +``_attn_implementation="sdpa"`` on a shallow-copied config before calling +through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward`` +is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left +alone, so sliding-window layers continue to receive FA2-format masks. + +The patch is idempotent. Install once per process, before any Gemma 4 +forward pass runs. +""" + +from __future__ import annotations + +import copy +from typing import Any + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +_PATCH_APPLIED = False + + +def patch_gemma4_hybrid_mask() -> bool: + """Install the Gemma 4 hybrid-attention mask fix. + + Returns ``True`` if the patch was installed (or was already installed), + ``False`` if the target module could not be imported (e.g. transformers + version predates Gemma 4) — in which case nothing is done and the + caller can continue unaffected. 
+ """ + global _PATCH_APPLIED + if _PATCH_APPLIED: + return True + + try: + from transformers.models.gemma4 import modeling_gemma4 + except ImportError: + LOG.debug( + "gemma4_hybrid_mask: transformers.models.gemma4 not importable, " + "skipping. This is fine for non-Gemma4 training." + ) + return False + + if not hasattr(modeling_gemma4, "create_causal_mask"): + LOG.warning( + "gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' " + "binding, skipping. Transformers API may have changed." + ) + return False + + original = modeling_gemma4.create_causal_mask + + def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any): + """Wrapper that forces SDPA format for the full-attention mask. + + The global layers were patched to SDPA by + ``_apply_gemma_hybrid_attention``, so their mask must be 4D. The + original ``create_causal_mask`` dispatches on + ``config._attn_implementation``; we shadow that with a local + override. + """ + sdpa_config = copy.copy(config) + sdpa_config._attn_implementation = "sdpa" + return original(sdpa_config, *args, **kwargs) + + # Preserve the original reference on the wrapper for tests / teardown. + hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined] + + modeling_gemma4.create_causal_mask = hybrid_create_causal_mask + _PATCH_APPLIED = True + LOG.info( + "gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to " + "force SDPA-format masks for full-attention layers" + ) + return True + + +def unpatch_gemma4_hybrid_mask() -> None: + """Restore the original ``create_causal_mask``. 
Useful for tests.""" + global _PATCH_APPLIED + if not _PATCH_APPLIED: + return + try: + from transformers.models.gemma4 import modeling_gemma4 + except ImportError: + _PATCH_APPLIED = False + return + current = modeling_gemma4.create_causal_mask + original = getattr(current, "_axolotl_original", None) + if original is not None: + modeling_gemma4.create_causal_mask = original + _PATCH_APPLIED = False diff --git a/src/axolotl/monkeypatch/tiled_mlp/patch.py b/src/axolotl/monkeypatch/tiled_mlp/patch.py index 65885396b..23f48a101 100644 --- a/src/axolotl/monkeypatch/tiled_mlp/patch.py +++ b/src/axolotl/monkeypatch/tiled_mlp/patch.py @@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None): module_path = f"transformers.models.{model_type}.modeling_{model_type}" model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"]) - mlp_cls = getattr(module, f"{model_cls_prefix}MLP") + # Some multimodal wrappers (e.g. Gemma 4) name the MLP class + # ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the + # language-side module is separated from the vision tower. Try + # both names before giving up. + mlp_cls = getattr( + module, + f"{model_cls_prefix}MLP", + None, + ) or getattr(module, f"{model_cls_prefix}TextMLP") if use_original_mlp: mlp_forward = mlp_cls.forward diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index 344c4327f..ca2f743fc 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -320,6 +320,15 @@ def main(script_args: ScriptArguments): # --- Active LoRA state (shared across endpoints via closure) --- active_lora: dict = {"request": None} + # Serializes access to the worker pipe. 
The underlying + # multiprocessing.Connection is a single full-duplex stream shared + # across all HTTP handlers; concurrent requests interleave bytes on + # the wire and corrupt the pickle framing (seen as + # ``UnpicklingError: pickle data was truncated``). Any endpoint that + # does ``conn.send(...); conn.recv()`` MUST hold this lock across + # the round-trip so only one inflight call at a time per pipe. + worker_pipe_lock = asyncio.Lock() + # ------------------------------------------------------------------ # LoRA-specific endpoints # ------------------------------------------------------------------ @@ -631,6 +640,150 @@ def main(script_args: ScriptArguments): }, } + @app.post("/v1/completions") + async def openai_completions(request_body: dict): + """OpenAI-compatible text-completions endpoint. + + Accepts either a string ``prompt`` or a list-of-int + ``prompt_token_ids`` (as the text-completions spec allows). Routes + to the internal vLLM generate method with the active LoRA adapter + and returns an OpenAI /v1/completions-shaped response including + per-choice ``prompt_token_ids``, ``generation_token_ids``, and + ``generation_log_probs`` for NeMo Gym agents that need raw + tokens + logprobs. + """ + import uuid + + prompt_raw = request_body.get("prompt") + temperature = request_body.get("temperature", 1.0) + max_tokens = request_body.get("max_tokens", 512) + top_p = request_body.get("top_p", 1.0) + n = request_body.get("n", 1) + logprobs = request_body.get("logprobs") or 0 + stop_token_ids = request_body.get("stop_token_ids") or None + + # Accept either a string or a list[int] token id prompt. Lists + # must contain ints only (raise on lists of strings so callers get + # a clear error). Also accept [[int, int, ...]] nesting for the + # rare case callers pass a single-prompt batch. 
+ if ( + isinstance(prompt_raw, list) + and prompt_raw + and isinstance(prompt_raw[0], list) + ): + prompt_raw = prompt_raw[0] + + prompt_dict: dict[str, Any] = {} + if isinstance(prompt_raw, list): + prompt_dict = {"prompt_token_ids": prompt_raw} + elif isinstance(prompt_raw, str): + prompt_dict = {"prompt": prompt_raw} + else: + return { + "error": { + "message": ("prompt must be a string or a list of token ids"), + "type": "invalid_request", + } + } + + generation_kwargs: dict[str, Any] = { + "n": n, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "logprobs": logprobs, + } + if stop_token_ids: + generation_kwargs["stop_token_ids"] = stop_token_ids + sampling_params = SamplingParams( + **{k: v for k, v in generation_kwargs.items() if v is not None} + ) + + chunked = chunk_list([prompt_dict], script_args.data_parallel_size) + + # Hold the pipe lock across send+recv — concurrent requests would + # otherwise interleave pickle frames on the worker connection. 
+ async with worker_pipe_lock: + for conn, chunk in zip(connections, chunked, strict=True): + if not chunk: + chunk = [{"prompt": ""}] + kwargs = { + "prompts": chunk, + "sampling_params": sampling_params, + "lora_request": active_lora["request"], + } + conn.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + loop = asyncio.get_running_loop() + all_outputs = await asyncio.gather( + *(loop.run_in_executor(None, safe_recv, conn) for conn in connections) + ) + + all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] + for o in all_outputs: + if isinstance(o, dict) and "error" in o: + raise RuntimeError(f"vLLM worker error: {o['error']}") + all_outputs = list(chain.from_iterable(all_outputs)) + + if not all_outputs: + return {"choices": [], "model": script_args.model} + + choices = [] + for i, output in enumerate(all_outputs): + for j, out in enumerate(output.outputs): + text = out.text + # OpenAI-style `logprobs` block for text-completions: + # { "tokens": [...], "token_logprobs": [...] } + lp_block = None + if out.logprobs: + tokens_str: list[str] = [] + token_lps: list[float] = [] + for step in out.logprobs: + chosen = next(iter(step.values())) + tokens_str.append(getattr(chosen, "decoded_token", "") or "") + token_lps.append(float(chosen.logprob)) + lp_block = { + "tokens": tokens_str, + "token_logprobs": token_lps, + } + + choice = { + "index": i * n + j, + "text": text, + "finish_reason": "stop" + if out.finish_reason == "stop" + else "length", + "logprobs": lp_block, + # NeMo-Gym / retrace agent extras — preserved on the + # choice so callers with raw-token pipelines don't + # have to re-tokenize. 
+ "prompt_token_ids": output.prompt_token_ids, + "generation_token_ids": list(out.token_ids), + "generation_log_probs": ( + [float(next(iter(lp.values())).logprob) for lp in out.logprobs] + if out.logprobs + else [] + ), + } + choices.append(choice) + + prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0 + completion_tokens = sum( + len(out.token_ids) for o in all_outputs for out in o.outputs + ) + + return { + "id": f"cmpl-{uuid.uuid4().hex[:8]}", + "object": "text_completion", + "model": script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + # --- Weight sync endpoints (legacy fallback, same as TRL) --- @app.post("/init_communicator/") diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 1780a9cc8..484a1fb47 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -760,6 +760,88 @@ class RLValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_grpo_batch_size_divisibility(cls, data): + """Surface GRPO batch-shape mismatches at config-parse time. + + TRL's GRPOTrainer requires that the per-step generation batch size be + evenly divisible by ``num_generations`` so that every prompt can be + replicated exactly ``num_generations`` times. The runtime check inside + ``GRPOTrainer.__init__`` only fires after the model has been loaded — + too late and too cryptic for the user. We replicate the check here so + the failure is immediate and actionable. 
+ + Also enforces: + - ``num_generations >= 2`` (group-relative advantage needs variance) + - ``effective_gbs >= num_generations * world_size`` when capabilities + indicate multiple ranks (each rank needs at least one full group) + """ + if data.get("rl") != "grpo": + return data + + trl_cfg = data.get("trl") or {} + num_gen = trl_cfg.get("num_generations") + if num_gen is None: + # TRL's own default is 8 — but if the user didn't set it, we + # don't have enough info to validate anything. Let TRL's own + # init handle the default-vs-batch interaction. + return data + if num_gen < 2: + raise ValueError( + f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). " + "With num_generations=1, every group has zero advantage and " + "the policy never updates." + ) + + explicit_gbs = trl_cfg.get("generation_batch_size") + if explicit_gbs is not None: + effective_gbs = int(explicit_gbs) + gbs_source = "trl.generation_batch_size" + else: + mb = data.get("micro_batch_size") or 1 + ga = data.get("gradient_accumulation_steps") or 1 + effective_gbs = int(mb) * int(ga) + gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})" + + if effective_gbs % num_gen != 0: + # Suggest the smallest GA bump that fixes it for the common case + # where the user hasn't set generation_batch_size explicitly. + hint = "" + if explicit_gbs is None: + from math import gcd + + mb_val = int(data.get("micro_batch_size") or 1) + # smallest GA such that mb*GA is a multiple of num_gen + lcm = num_gen * mb_val // gcd(num_gen, mb_val) + suggested_ga = lcm // mb_val + hint = ( + f" Smallest fix: set `gradient_accumulation_steps: " + f"{suggested_ga}` (so micro_batch_size * GA = " + f"{mb_val * suggested_ga} is a multiple of {num_gen})." + ) + raise ValueError( + f"GRPO: generation batch size must be divisible by " + f"`trl.num_generations`. 
Got effective_gbs={effective_gbs} " + f"(from {gbs_source}) and num_generations={num_gen}.{hint}" + ) + + # Multi-rank check: each rank must receive at least one full group + # per step. Without `capabilities` populated yet (mode='before'), we + # fall back to user-set distributed fields. + world_size = ( + (data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1 + ) + if world_size and world_size > 1 and effective_gbs < num_gen * world_size: + raise ValueError( + f"GRPO with world_size={world_size} requires effective_gbs " + f">= num_generations * world_size = {num_gen * world_size}, " + f"got {effective_gbs}. Increase gradient_accumulation_steps " + f"or micro_batch_size." + ) + + return data + class OptimizationValidationMixin: """Validation methods related to optimization and performance.""" diff --git a/tests/core/test_async_grpo.py b/tests/core/test_async_grpo.py index 14c38df29..3a4c188bc 100644 --- a/tests/core/test_async_grpo.py +++ b/tests/core/test_async_grpo.py @@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase): self.assertIs(_trainer_module.validate_quantization_for_training, original) +class TestVllmLoraSyncPatch(unittest.TestCase): + """The ``_generate_single_turn`` patch wires sync_weights to the right place. + + These tests exercise the patch-installation branch in isolation. They build + a stub trainer with just enough attributes to look like + ``AsyncGRPOTrainer`` for the duration of the relevant code path. + + Background — there are two correct behaviors and we historically had a bug + where both modes used the same one: + + - Async prefetch ON: the BG generation thread can't safely call + sync_weights mid-rollout. We no-op the stock hook and drive sync from + the main thread via ``_maybe_sync_vllm_weights``. + - Async prefetch OFF: TRL's stock ``_generate_single_turn`` already + calls ``sync_weights`` once per step boundary on the main thread. 
We + wire that hook directly to ``_sync_lora_adapter`` because + ``_maybe_sync_vllm_weights`` short-circuits when async is off. + + Before the fix, both modes installed ``lambda: None``, so sync mode never + pushed any LoRA adapter to vLLM and the trainer was a no-op. + """ + + @staticmethod + def _make_stub_trainer(*, vllm_lora_sync, async_prefetch): + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + class FakeArgs: + pass + + args = FakeArgs() + args.vllm_lora_sync = vllm_lora_sync + args.async_prefetch = async_prefetch + + class FakeVllmGen: + sync_weights = staticmethod(lambda: None) + model = MagicMock() + + # Use object.__new__ so we don't run __init__ (which needs a real + # model, dataset, etc.). We only need the `_generate_single_turn` + # method's patch branch to run, so we set up the minimum state. + trainer = object.__new__(AsyncGRPOTrainer) + trainer.args = args + trainer.use_vllm = True + trainer.vllm_generation = FakeVllmGen() + trainer._patched_sync_weights = False + # Spy on _sync_lora_adapter so we can assert it's the function the + # hook delegates to in sync mode. + trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy") + trainer._sync_peft_weights_no_merge = MagicMock( + name="_sync_peft_weights_no_merge_spy" + ) + return trainer + + @staticmethod + def _run_patch_branch(trainer): + """Execute just the sync_weights-patching branch in isolation. + + We can't easily call the real ``_generate_single_turn`` because it + does a full vLLM generate. Instead we copy the exact branch out of + the source so the test verifies the same logic the trainer runs. 
+ """ + if not getattr(trainer, "_patched_sync_weights", False): + if trainer.use_vllm and hasattr(trainer, "vllm_generation"): + if getattr(trainer.args, "vllm_lora_sync", False): + if getattr(trainer.args, "async_prefetch", False): + trainer.vllm_generation.sync_weights = lambda: None + else: + sync_helper = trainer._sync_lora_adapter + + def _lora_filesystem_sync(): + sync_helper() + + trainer.vllm_generation.sync_weights = _lora_filesystem_sync + trainer._patched_sync_weights = True + + def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + + assert trainer._patched_sync_weights is True + # Trigger the patched hook — it must call _sync_lora_adapter. + trainer.vllm_generation.sync_weights() + trainer._sync_lora_adapter.assert_called_once() + + def test_async_mode_with_lora_sync_installs_noop_hook(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True) + self._run_patch_branch(trainer) + + assert trainer._patched_sync_weights is True + # Hook must be a no-op so BG-thread generation doesn't fight the + # main-thread optimizer step over the model weights. + trainer.vllm_generation.sync_weights() + trainer._sync_lora_adapter.assert_not_called() + + def test_sync_mode_with_lora_sync_does_not_call_during_install(self): + """Installing the patch should not pre-emptively sync.""" + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + # _sync_lora_adapter should only be called when the patched hook + # itself is invoked (e.g., from TRL's _generate_single_turn). 
+ trainer._sync_lora_adapter.assert_not_called() + + def test_patch_is_idempotent(self): + trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False) + self._run_patch_branch(trainer) + first_hook = trainer.vllm_generation.sync_weights + # Second call must not re-patch (otherwise we'd lose the original). + self._run_patch_branch(trainer) + assert trainer.vllm_generation.sync_weights is first_hook + + +class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase): + """``_maybe_sync_vllm_weights`` must not crash when interval is unset. + + Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError + on the very first call when ``vllm_sync_interval`` was ``None`` (which + is the default for any config that doesn't explicitly set it). We now + fall back to interval=1 so unset means "sync every step", matching the + behavior of TRL's own ``_generate_single_turn``. + """ + + @staticmethod + def _make_stub_trainer(interval, async_prefetch): + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + class FakeArgs: + pass + + args = FakeArgs() + args.async_prefetch = async_prefetch + args.vllm_sync_interval = interval + args.vllm_lora_sync = True + + class FakeState: + global_step = 1 + + trainer = object.__new__(AsyncGRPOTrainer) + trainer.args = args + trainer.use_vllm = True + trainer.state = FakeState() + trainer._last_synced_step = 0 + trainer._sync_lora_adapter = MagicMock(name="sync_spy") + return trainer + + def test_interval_none_in_async_mode_does_not_crash(self): + trainer = self._make_stub_trainer(interval=None, async_prefetch=True) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + # Should not raise TypeError — defaults to every-step sync + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + def test_sync_mode_drives_sync(self): + """Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``. 
+ + The previous behavior (early return when ``not async_prefetch``) + assumed TRL's stock ``_generate_single_turn`` would handle sync. + That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn + where the data producer bypasses ``_generate_single_turn`` + entirely. Without this trigger no sync ever happens and the + trainer becomes a no-op. + """ + trainer = self._make_stub_trainer(interval=1, async_prefetch=False) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + def test_async_mode_with_explicit_interval_respects_modulo(self): + trainer = self._make_stub_trainer(interval=4, async_prefetch=True) + from axolotl.core.trainers.grpo.async_trainer import ( + AsyncGRPOTrainer, + ) + + # global_step=1, interval=4 → 1 % 4 != 0 → no sync + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_not_called() + + # global_step=4 → 4 % 4 == 0 → sync + trainer.state.global_step = 4 + AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer) + trainer._sync_lora_adapter.assert_called_once() + + if __name__ == "__main__": unittest.main() diff --git a/tests/integrations/test_nemo_gym.py b/tests/integrations/test_nemo_gym.py index 7fd53cee0..83206043c 100644 --- a/tests/integrations/test_nemo_gym.py +++ b/tests/integrations/test_nemo_gym.py @@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase): assert cfg.dataloader_num_workers == 0 +class TestSelectWeightSyncTransport(unittest.TestCase): + """Pure-logic table tests for ``select_weight_sync_transport``.""" + + def _caps(self, **kwargs): + from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities + + c = VLLMWeightSyncCapabilities(probed=True) + for k, v in kwargs.items(): + setattr(c, k, v) + return c + + def test_lora_with_native_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + 
caps = self._caps(lora_filesystem=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True) + == "lora_filesystem" + ) + + def test_lora_with_axolotl_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(lora_axolotl=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False) + == "lora_filesystem" + ) + + def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(nccl=True) + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False) + == "nccl" + ) + + def test_full_param_prefers_nccl(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(nccl=True, http_full=True) + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "nccl" + ) + + def test_full_param_falls_back_to_http(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps(http_full=True) + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "http_full" + ) + + def test_full_param_no_routes_returns_none(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps() # all False + assert ( + select_weight_sync_transport( + caps, has_lora=False, vllm_lora_sync_pref=False + ) + == "none" + ) + + def test_lora_no_routes_returns_none(self): + from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport + + caps = self._caps() + assert ( + select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True) + == "none" + ) + + +class TestProbeVllmWeightSync(unittest.TestCase): + """``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps.""" + + def 
test_stock_vllm_with_lora_enabled(self): + """Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/v1/models": {"get": {}}, + "/v1/load_lora_adapter": {"post": {}}, + "/v1/unload_lora_adapter": {"post": {}}, + "/v1/completions": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.lora_filesystem is True + assert caps.lora_axolotl is False + assert caps.nccl is False + assert caps.http_full is False + + def test_axolotl_serve_lora_full_capabilities(self): + """``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/init_communicator/": {"post": {}}, + "/update_named_param/": {"post": {}}, + "/batch_update_named_params/": {"post": {}}, + "/set_lora_adapter/": {"post": {}}, + "/clear_lora_adapter/": {"post": {}}, + "/http_update_weights/": {"post": {}}, + "/v1/load_lora_adapter": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.nccl is True + assert caps.lora_axolotl is True + assert caps.lora_filesystem is True + assert caps.http_full is True + + def test_trl_vllm_serve_nccl_only(self): + """``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem.""" + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + spec = { + "paths": { + "/init_communicator/": {"post": {}}, + 
"/update_named_param/": {"post": {}}, + "/batch_update_named_params/": {"post": {}}, + "/close_communicator/": {"post": {}}, + "/generate/": {"post": {}}, + } + } + with patch("requests.get") as mock_get: + mock_get.return_value.raise_for_status = lambda: None + mock_get.return_value.json = lambda: spec + caps = probe_vllm_weight_sync("http://localhost:8000") + + assert caps.probed is True + assert caps.nccl is True + assert caps.lora_filesystem is False + assert caps.lora_axolotl is False + assert caps.http_full is False + + def test_unreachable_server_records_error(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync + + with patch("requests.get") as mock_get: + mock_get.side_effect = ConnectionError("Connection refused") + caps = probe_vllm_weight_sync("http://localhost:9999") + + assert caps.probed is False + assert caps.probe_error is not None + assert "ConnectionError" in caps.probe_error + assert caps.nccl is False + assert caps.lora_filesystem is False + + +class TestPluginWeightSyncEnforcement(unittest.TestCase): + """End-to-end test of post_trainer_create's transport-selection branch. + + The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``, + leaving the trainer learning in isolation while vLLM kept serving the + unmodified base model. 
After the fix: + + - LoRA + LoRA-loading endpoint → installs filesystem LoRA sync + - LoRA + only NCCL endpoint → uses NCCL broadcast + - Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow) + - Full FT + HTTP endpoint → raises NotImplementedError (step 3) + - No usable transport → raises ValueError with a precise diagnosis + """ + + @staticmethod + def _fake_cfg(adapter, vllm_lora_sync): + class FakeTRL: + pass + + class FakeCfg: + pass + + trl = FakeTRL() + trl.vllm_lora_sync = vllm_lora_sync + trl.vllm_server_host = "127.0.0.1" + trl.vllm_server_port = 8000 + + cfg = FakeCfg() + cfg.nemo_gym_enabled = True + cfg.nemo_gym_model_name = None + cfg.base_model = "test/model" + cfg.nemo_gym_verify_timeout = 30 + cfg.nemo_gym_multi_turn = True + cfg.adapter = adapter + cfg.trl = trl + return cfg + + @staticmethod + def _fake_trainer(): + class FakeVLLMGen: + sync_weights = staticmethod(lambda: None) + + class FakeTrainer: + vllm_generation = FakeVLLMGen() + + return FakeTrainer() + + @staticmethod + def _caps(**kwargs): + from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities + + c = VLLMWeightSyncCapabilities(probed=True) + for k, v in kwargs.items(): + setattr(c, k, v) + return c + + def test_lora_with_lora_endpoint_installs_filesystem_sync(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(lora_filesystem=True) + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True) + trainer = self._fake_trainer() + + with ( + patch.object(plugin, "_setup_lora_sync") as setup, + patch.object(plugin, "_check_lora_endpoint") as check, + patch.object(plugin, "_wire_multi_turn") as wire, + ): + plugin.post_trainer_create(cfg, trainer) + setup.assert_called_once() + check.assert_called_once() + wire.assert_called_once() + + def test_lora_with_no_routes_raises_with_lora_specific_message(self): + from 
axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps() # all False, but probed + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False) + trainer = self._fake_trainer() + + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + msg = str(ctx.exception) + assert "no-op trainer" in msg + assert "load_lora_adapter" in msg + assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg + + def test_full_finetune_with_nccl_endpoint_uses_nccl(self): + from unittest.mock import patch + + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(nccl=True) + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + + with patch.object(plugin, "_wire_multi_turn") as wire: + plugin.post_trainer_create(cfg, trainer) + wire.assert_called_once() + + def test_full_finetune_with_http_endpoint_not_implemented_yet(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps(http_full=True) + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + with self.assertRaises(NotImplementedError) as ctx: + plugin.post_trainer_create(cfg, trainer) + assert "HTTP weight sync" in str(ctx.exception) + + def test_full_finetune_with_no_routes_raises_with_full_param_message(self): + from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + plugin._vllm_caps = self._caps() + cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False) + trainer = self._fake_trainer() + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + msg = str(ctx.exception) + assert "no-op trainer" in msg + assert "init_communicator" in msg + assert "http_update_weights" in msg + + def test_unprobed_caps_raises_with_probe_failure_message(self): + from 
axolotl.integrations.nemo_gym.plugin import NemoGymPlugin + + plugin = NemoGymPlugin() + # Plugin._vllm_caps left as default-None: the post_trainer_create + # branch falls back to a fresh VLLMWeightSyncCapabilities() with + # probed=False, so the error path should mention probing. + cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True) + trainer = self._fake_trainer() + with self.assertRaises(ValueError) as ctx: + plugin.post_trainer_create(cfg, trainer) + assert "could not probe" in str(ctx.exception) + + class TestNemoGymE2E(unittest.TestCase): """End-to-end test: data producer → agent (mocked) → parse → tensors → rewards. @@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase): trainer = self._make_mock_trainer() producer._trainer = trainer - # Mock the prompt iterator (returns a batch of 1 input) - producer._prompt_iter = iter( - [ - [ - { - "prompt": [{"role": "user", "content": "Play Wordle!"}], - } - ] - ] - ) - producer._prompt_dl = [ - [{"prompt": [{"role": "user", "content": "Play Wordle!"}]}] + # Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations) + # pre-expands prompts, so the iterator yields num_generations=2 consecutive + # copies of each unique prompt — one entry per rollout. + _prompt_batch = [ + {"prompt": [{"role": "user", "content": "Play Wordle!"}]}, + {"prompt": [{"role": "user", "content": "Play Wordle!"}]}, ] + producer._prompt_iter = iter([_prompt_batch]) + producer._prompt_dl = [_prompt_batch] # Call produce result = producer.produce(model=MagicMock(), global_step=1) @@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase): producer._request_timeout = 30 producer._num_generations = 2 producer._trainer = self._make_mock_trainer() - producer._prompt_iter = iter( - [[{"prompt": [{"role": "user", "content": "Play!"}]}]] - ) - producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]] + # RepeatSampler pre-expands by num_generations=2. 
+ _prompt_batch = [ + {"prompt": [{"role": "user", "content": "Play!"}]}, + {"prompt": [{"role": "user", "content": "Play!"}]}, + ] + producer._prompt_iter = iter([_prompt_batch]) + producer._prompt_dl = [_prompt_batch] result = producer.produce(model=MagicMock(), global_step=1) diff --git a/tests/kernels/test_gemma4_fused_rope.py b/tests/kernels/test_gemma4_fused_rope.py index 7daedd612..297bb2527 100644 --- a/tests/kernels/test_gemma4_fused_rope.py +++ b/tests/kernels/test_gemma4_fused_rope.py @@ -38,6 +38,30 @@ def _reference_norm_noscale(x, eps): return norm(x) +def _reference_partial_norm_rope(x, weight, cos, sin, eps): + """Reference: Gemma4RMSNorm over the full head_dim, then stock + ``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with + the trailing columns passed through unchanged. Mirrors how Llama-style + partial rotary is layered on top of the stock RMSNorm + RoPE primitives. + """ + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4RMSNorm, + apply_rotary_pos_emb, + ) + + D = x.shape[-1] + n_rot = cos.shape[-1] + norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype) + norm.weight.data.copy_(weight) + normed = norm(x) + if n_rot == D: + return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2) + x_rot = normed[..., :n_rot] + x_pass = normed[..., n_rot:] + rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2) + return torch.cat([rotated, x_pass], dim=-1) + + @pytest.fixture( params=[ (2, 64, 32, 256), # sliding window layer shape @@ -194,6 +218,172 @@ class TestFusedRMSNormRoPEBackward: assert w.grad.abs().sum() > 0, "w.grad is all zeros" +class TestFusedRMSNormRoPEPartialRotary: + """Partial-rotary: cos/sin last dim is smaller than head_dim. + + Compares against the original primitives (`Gemma4RMSNorm` + + `apply_rotary_pos_emb`) applied to the rotated slice with the trailing + columns passed through. 
Without the kernel fix this used to crash with + `RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`. + """ + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [ + (2, 16, 4, 64, 32), # half rotary (Llama-style 0.5) + (2, 16, 4, 64, 16), # quarter rotary + (2, 32, 8, 128, 64), # half rotary, larger heads + (1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial + (1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path + ], + ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"], + ) + def test_forward_matches_reference(self, B, S, H, D, n_rot): + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(D, device="cuda", dtype=torch.bfloat16) + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + + y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps) + y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps) + + assert y_fused.shape == y_ref.shape == (B, S, H, D) + cos_sim = torch.nn.functional.cosine_similarity( + y_ref.flatten().float(), y_fused.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, ( + f"partial rotary forward cosine_sim={cos_sim:.6f} " + f"(B={B},S={S},H={H},D={D},n_rot={n_rot})" + ) + + # The pass-through tail must equal the reference RMSNorm output bit- + # for-bit (any deviation would mean the kernel is touching it with a + # spurious rotation, which is the original bug class). 
+ torch.testing.assert_close( + y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2 + ) + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)], + ids=["half_64", "quarter_256"], + ) + def test_x_grad_matches_reference(self, B, S, H, D, n_rot): + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16) + x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + + # Reference backward via the original primitives + x_ref = x_data.clone().requires_grad_(True) + w_ref = weight_init.clone() + y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps) + y_ref.sum().backward() + + # Fused backward + x_fused = x_data.clone().requires_grad_(True) + w_fused = weight_init.clone().requires_grad_(True) + y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps) + y_fused.sum().backward() + + cos_sim_x = torch.nn.functional.cosine_similarity( + x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0 + ) + assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}" + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)], + ids=["half_64", "quarter_256"], + ) + def test_weight_grad_matches_reference(self, B, S, H, D, n_rot): + from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm + + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16) + x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + + # Reference: Gemma4RMSNorm whose .weight collects grads, 
then partial + # rotary applied to the rotated slice. + norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16) + norm_ref.weight = torch.nn.Parameter(weight_init.clone()) + normed = norm_ref(x_data) + from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb + + rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2) + y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1) + y_ref.sum().backward() + + w_fused = weight_init.clone().requires_grad_(True) + fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward() + + cos_sim_w = torch.nn.functional.cosine_similarity( + w_fused.grad.flatten().float(), + norm_ref.weight.grad.flatten().float(), + dim=0, + ) + assert cos_sim_w > 0.995, ( + f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}" + ) + + def test_full_rotary_unchanged_when_n_rot_equals_d(self): + """Regression: passing cos/sin with shape == head_dim must still + match the full-rotary reference (the partial-rotary code path must + not perturb the existing full-rotary output).""" + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + B, S, H, D = 2, 16, 4, 64 + eps = 1e-6 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(D, device="cuda", dtype=torch.bfloat16) + cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16) + + y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps) + y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps) + cos_sim = torch.nn.functional.cosine_similarity( + y_ref.flatten().float(), y_fused.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}" + + def test_validation_errors(self): + """Wrapper rejects misshaped inputs cleanly (instead of a cryptic + Triton crash deeper in the kernel).""" + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + B, 
S, H, D = 1, 4, 2, 64 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + w = torch.randn(D, device="cuda", dtype=torch.bfloat16) + + # n_rot > head_dim + cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16) + sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="cannot exceed head_dim"): + fused_rms_norm_rope(x, w, cos_big, sin_big) + + # cos/sin last-dim mismatch + cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="same last dim"): + fused_rms_norm_rope(x, w, cos, sin) + + # odd rotary dim + cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16) + sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="must be even"): + fused_rms_norm_rope(x, w, cos_odd, sin_odd) + + class TestFusedRMSNormNoScale: """Tests for v_norm (RMSNorm without learnable scale).""" diff --git a/tests/monkeypatch/test_gemma4_fused_attn.py b/tests/monkeypatch/test_gemma4_fused_attn.py new file mode 100644 index 000000000..0530d0ee8 --- /dev/null +++ b/tests/monkeypatch/test_gemma4_fused_attn.py @@ -0,0 +1,219 @@ +"""Tests for the Gemma 4 fused-attention monkey-patch. + +These tests exercise the patched ``Gemma4TextAttention.forward`` against +the stock implementation it replaces. The hybrid Gemma 4 model intentionally +mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope +layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the +partial-rotary RMSNorm+RoPE path through the fused Triton kernel is +exercised end-to-end (this is the bug originally documented in +``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``). 
+ +The full-model forward also pins that the fused forward keeps accepting +whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the +installed transformers version — so any future signature drift on +upstream's side trips a clear failure here instead of a confusing +TypeError deep in a training run. +""" + +import pytest +import torch + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + +pytest.importorskip( + "transformers.models.gemma4", + reason="fused_attn patch only matters when Gemma 4 is available", +) + + +@pytest.fixture +def restore_gemma4_attention(): + """Snapshot ``Gemma4TextAttention.forward`` and restore after the test + so the monkey-patch does not leak across the suite.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention + + saved = Gemma4TextAttention.forward + yield Gemma4TextAttention + Gemma4TextAttention.forward = saved + + +def _build_hybrid_config(): + """Tiny hybrid Gemma 4 config: one sliding layer + one full-attention + layer with proportional rope and partial_rotary_factor=0.25. 
This is + the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small + enough to fit on any GPU.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + cfg = Gemma4TextConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + global_head_dim=64, + layer_types=["sliding_attention", "full_attention"], + sliding_window=64, + max_position_embeddings=2048, + hidden_size_per_layer_input=16, + vocab_size_per_layer_input=128, + rope_parameters={ + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + }, + ) + cfg._attn_implementation = "sdpa" + return cfg + + +def _build_model(seed=0): + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + torch.manual_seed(seed) + cfg = _build_hybrid_config() + return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval() + + +class TestFusedAttnSignature: + """The fused forward must accept the same call shape as + ``Gemma4TextDecoderLayer`` produces in the installed transformers + version. 
Any signature drift surfaces here as a TypeError.""" + + def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention): + """Run a model forward that exercises the real + ``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with + the fused patch installed.""" + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model() + ids = torch.randint(0, 128, (2, 16), device="cuda") + mask = torch.ones(2, 16, dtype=torch.long, device="cuda") + + patch_gemma4_fused_attn() + with torch.no_grad(): + out = m(input_ids=ids, attention_mask=mask).last_hidden_state + + assert out.shape == (2, 16, 64) + assert torch.isfinite(out).all() + + +class TestFusedAttnPerLayerCorrectness: + """Compare the patched attention layer to the stock implementation + on a single forward call. This isolates the fused kernel correctness + from cross-layer numerical drift.""" + + def _run_attention(self, model, layer_idx, hidden_states, position_ids): + """Call ``Gemma4TextAttention.forward`` (whatever is currently + installed) for one layer and return the output.""" + attn = model.layers[layer_idx].self_attn + layer_type = model.config.layer_types[layer_idx] + cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type) + out, _ = attn( + hidden_states=hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + shared_kv_states={}, + ) + return out + + @pytest.mark.parametrize( + "layer_idx", + [0, 1], + ids=["sliding_head32", "global_head64_proportional"], + ) + def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx): + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=1) + hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16) + pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1) + + with torch.no_grad(): + ref = self._run_attention(m, layer_idx, hs, pos) + + patch_gemma4_fused_attn() + with 
torch.no_grad(): + got = self._run_attention(m, layer_idx, hs, pos) + + assert got.shape == ref.shape + assert torch.isfinite(got).all() + cos_sim = torch.nn.functional.cosine_similarity( + ref.flatten().float(), got.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, ( + f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}" + ) + # bf16 precision: a few millis of absolute drift per element is + # acceptable for a Q/K/V projection pipeline. Anything larger is + # a real bug. + torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2) + + +class TestFusedAttnFullModel: + """End-to-end model forward + backward through both layer types.""" + + def test_full_forward_matches_stock(self, restore_gemma4_attention): + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=2) + ids = torch.randint(0, 128, (2, 32), device="cuda") + mask = torch.ones(2, 32, dtype=torch.long, device="cuda") + + with torch.no_grad(): + ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone() + + patch_gemma4_fused_attn() + with torch.no_grad(): + got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone() + + assert got.shape == ref.shape + assert torch.isfinite(got).all() + cos_sim = torch.nn.functional.cosine_similarity( + ref.flatten().float(), got.flatten().float(), dim=0 + ) + # End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16 + # accumulates a small amount of numerical drift; we just want to + # pin that the two paths are computing the same function. 
+ assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}" + + def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention): + """Gradients must propagate through the fused RMSNorm+RoPE kernels + for both the sliding and proportional-rope layers.""" + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=3).train() + patch_gemma4_fused_attn() + + ids = torch.randint(0, 128, (2, 16), device="cuda") + mask = torch.ones(2, 16, dtype=torch.long, device="cuda") + out = m(input_ids=ids, attention_mask=mask).last_hidden_state + out.sum().backward() + + # Both layers must accumulate gradients on q_norm.weight and + # k_norm.weight — that proves the fused kernel ran the backward. + for i, layer in enumerate(m.layers[:2]): + attn = layer.self_attn + assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad" + assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad" + assert attn.q_norm.weight.grad.isfinite().all() + assert attn.k_norm.weight.grad.isfinite().all() + assert attn.q_norm.weight.grad.abs().sum() > 0 + assert attn.k_norm.weight.grad.abs().sum() > 0 diff --git a/tests/monkeypatch/test_gemma4_hybrid_mask.py b/tests/monkeypatch/test_gemma4_hybrid_mask.py new file mode 100644 index 000000000..66d56bcf1 --- /dev/null +++ b/tests/monkeypatch/test_gemma4_hybrid_mask.py @@ -0,0 +1,343 @@ +"""Tests for the Gemma 4 hybrid-attention mask fix. + +These tests pin the single critical behavior: after installing the patch, +``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to +the underlying mask builder regardless of what the caller's config says. +This is what keeps full-attention (head_dim=512) global layers from +crashing at long sequence lengths — they need a 4D SDPA-format mask, not +the 2D FA2 mask that would be built from the model-level config. 
+
+The tests use a mocked ``create_causal_mask`` so they don't have to load
+a real 26B Gemma 4 model or even have access to its weights. What matters
+for the bug fix is which config is handed to the mask factory, not the
+factory's actual output.
+"""
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+pytest.importorskip(
+    "transformers.models.gemma4",
+    reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
+)
+
+
+@pytest.fixture
+def restore_gemma4_module():
+    """Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
+    each test so patch state doesn't leak across the suite."""
+    from transformers.models.gemma4 import modeling_gemma4
+
+    saved = modeling_gemma4.create_causal_mask  # pre-test binding, restored in teardown
+    yield modeling_gemma4  # tests receive the live module object
+    modeling_gemma4.create_causal_mask = saved
+    # Reset the module-level flag so the next test can re-install cleanly.
+    from axolotl.monkeypatch import gemma4_hybrid_mask
+
+    gemma4_hybrid_mask._PATCH_APPLIED = False
+
+
+def test_patch_replaces_create_causal_mask(restore_gemma4_module):
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    original = modeling_gemma4.create_causal_mask
+    assert patch_gemma4_hybrid_mask() is True  # install reports success
+
+    assert modeling_gemma4.create_causal_mask is not original
+    assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
+        "patched wrapper must expose the original reference for teardown"
+    )
+
+
+def test_patch_is_idempotent(restore_gemma4_module):
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    patch_gemma4_hybrid_mask()
+    wrapper_first = modeling_gemma4.create_causal_mask
+
+    # Second call must not re-wrap the already-wrapped function (which
+    # would leak the original reference through a chain of wrappers).
+    patch_gemma4_hybrid_mask()
+    wrapper_second = modeling_gemma4.create_causal_mask
+
+    assert wrapper_first is wrapper_second
+
+
+def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
+    """Core invariant: when the patched wrapper is called with a config
+    that says ``flash_attention_2``, the underlying mask factory receives
+    a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
+
+    Without this, the full-attention global layers get a 2D FA2 mask and
+    crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
+    """
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    # Swap in a mock BEFORE installing the patch so the wrapper captures
+    # it as the "original". The mock records every call so we can inspect
+    # what config got passed through.
+    mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
+    modeling_gemma4.create_causal_mask = mock_factory
+    patch_gemma4_hybrid_mask()
+
+    # Caller-supplied config says FA2 (that's the model-level setting).
+    caller_config = SimpleNamespace(
+        _attn_implementation="flash_attention_2",
+        head_dim=512,
+        some_other_attr="preserved",
+    )
+    result = modeling_gemma4.create_causal_mask(
+        caller_config,
+        inputs_embeds=None,
+        attention_mask=None,
+        past_key_values=None,
+        position_ids=None,
+    )
+
+    # Wrapper returned whatever the mock returned — no transformation of
+    # the result itself.
+    assert result == "mask_4d"
+
+    # The mock was called exactly once with a config whose
+    # ``_attn_implementation`` is sdpa, NOT the caller's fa2.
+    assert mock_factory.call_count == 1
+    (passed_config, *_), passed_kwargs = mock_factory.call_args
+    assert passed_config._attn_implementation == "sdpa"
+
+    # The wrapper must NOT mutate the caller's config in place — other
+    # mask builders (e.g. create_sliding_window_causal_mask) read from
+    # the same config and must still see fa2.
+    assert caller_config._attn_implementation == "flash_attention_2"
+
+    # Other attributes on the config must be preserved so the underlying
+    # factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
+    assert passed_config.head_dim == 512
+    assert passed_config.some_other_attr == "preserved"
+
+
+def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
+    """The wrapper must forward positional + keyword args to the original
+    unchanged, so transformers' own call-site in Gemma4TextModel.forward
+    keeps working across minor transformers-version signature drift."""
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    mock_factory = MagicMock(return_value="mask")
+    modeling_gemma4.create_causal_mask = mock_factory
+    patch_gemma4_hybrid_mask()
+
+    caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
+    modeling_gemma4.create_causal_mask(
+        caller_config,
+        "positional_arg",
+        inputs_embeds="embeds",
+        attention_mask="mask_2d",
+        past_key_values="cache",
+        position_ids="positions",
+        or_mask_function="or_fn",
+    )
+
+    args, kwargs = mock_factory.call_args
+    # First positional (after config override) is preserved.
+    assert args[1] == "positional_arg"
+    # All kwargs are forwarded untouched.
+    assert kwargs["inputs_embeds"] == "embeds"
+    assert kwargs["attention_mask"] == "mask_2d"
+    assert kwargs["past_key_values"] == "cache"
+    assert kwargs["position_ids"] == "positions"
+    assert kwargs["or_mask_function"] == "or_fn"
+
+
+def test_unpatch_restores_original(restore_gemma4_module):
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import (
+        patch_gemma4_hybrid_mask,
+        unpatch_gemma4_hybrid_mask,
+    )
+
+    sentinel = MagicMock(name="original")  # stands in for the pristine factory
+    modeling_gemma4.create_causal_mask = sentinel
+    patch_gemma4_hybrid_mask()
+    assert modeling_gemma4.create_causal_mask is not sentinel
+
+    unpatch_gemma4_hybrid_mask()
+    assert modeling_gemma4.create_causal_mask is sentinel
+
+
+def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
+    from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
+
+    # Should be a no-op, no exception.
+    unpatch_gemma4_hybrid_mask()
+
+
+def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
+    """Only ``create_causal_mask`` is overridden — the sliding-window
+    factory must remain bound to its original to preserve FA2 masks for
+    the sliding-attention layers. If we accidentally patch both, the
+    sliding layers get SDPA format and lose the FA2 speedup."""
+    modeling_gemma4 = restore_gemma4_module
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
+        pytest.skip("transformers version has no create_sliding_window_causal_mask")
+
+    sliding_before = modeling_gemma4.create_sliding_window_causal_mask
+    patch_gemma4_hybrid_mask()
+    sliding_after = modeling_gemma4.create_sliding_window_causal_mask
+    assert sliding_after is sliding_before
+
+
+# ---------------------------------------------------------------------------
+# Integration tests with a tiny randomly-initialized Gemma4TextModel.
+#
+# These do NOT load real 26B weights.
+# They build a ~350k-param Gemma 4 text
+# model with 2 layers (one sliding, one full_attention), apply the hybrid
+# attention path end-to-end, and run a forward pass with a padded
+# attention_mask at a long-ish seq len. The invariant we're pinning is that
+# the full_attention layer does not crash with the
+# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
+# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
+# tokens in the FSDP2 training run.
+# ---------------------------------------------------------------------------
+
+
+def _build_tiny_gemma4_text_model():
+    """Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
+    import torch
+    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
+
+    cfg = Gemma4TextConfig(
+        vocab_size=128,
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+        head_dim=32,
+        layer_types=["sliding_attention", "full_attention"],
+        sliding_window=64,
+        max_position_embeddings=2048,
+        hidden_size_per_layer_input=16,
+        vocab_size_per_layer_input=128,
+    )
+    # Caller-supplied attn impl simulates the pilot config (fa2 at model
+    # level). The hybrid patch is what makes this survive long context.
+    cfg._attn_implementation = "sdpa"  # start safe; the test toggles fa2 later
+    torch.manual_seed(42)  # deterministic weights so the forward is reproducible
+    model = Gemma4TextModel(cfg).eval()  # eval(): disable training-only behavior
+    return model, cfg
+
+
+def _apply_hybrid_attn_inline(model, cfg):
+    """Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
+    to a model, without needing a full PatchManager / pydantic cfg."""
+    import copy
+
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    for layer_idx, layer in enumerate(model.layers):
+        if cfg.layer_types[layer_idx] != "sliding_attention":
+            attn = getattr(layer, "self_attn", None)
+            if attn is not None and hasattr(attn, "config"):
+                sdpa_cfg = copy.copy(attn.config)  # shallow copy: don't touch shared config
+                sdpa_cfg._attn_implementation = "sdpa"
+                attn.config = sdpa_cfg
+    patch_gemma4_hybrid_mask()
+
+
+def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
+    """End-to-end invariant: with the hybrid attn patch applied, a tiny
+    Gemma4TextModel runs a forward at long context (1024 tokens) with
+    real padding in the attention mask, producing the expected output
+    shape. This exercises the actual code path that crashed the pilot
+    without needing a real 26B checkpoint or CUDA."""
+    import torch
+
+    model, cfg = _build_tiny_gemma4_text_model()
+    _apply_hybrid_attn_inline(model, cfg)
+
+    B, S = 2, 1024
+    input_ids = torch.randint(0, cfg.vocab_size, (B, S))
+    attn_mask = torch.ones(B, S, dtype=torch.long)
+    # Pad positions in the second row. Without padding, SDPA falls back to
+    # ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
+    # mask to exercise the actual bug site.
+    attn_mask[1, S // 2 :] = 0
+
+    with torch.no_grad():
+        out = model(input_ids=input_ids, attention_mask=attn_mask)
+
+    assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
+    assert torch.isfinite(out.last_hidden_state).all()
+
+
+def test_patched_create_causal_mask_returns_4d_for_real_config(
+    restore_gemma4_module,
+):
+    """Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
+    and verify the returned mask is a 4D tensor — which is the shape the
+    SDPA-patched global layers need. Without the patch and with a
+    caller-supplied FA2 config this would return a 2D mask and the layer
+    would crash at long context."""
+    import torch
+    from transformers.cache_utils import DynamicCache
+    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
+
+    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
+
+    patch_gemma4_hybrid_mask()
+    modeling_gemma4 = restore_gemma4_module
+
+    cfg = Gemma4TextConfig(
+        vocab_size=128,
+        hidden_size=64,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+        head_dim=32,
+        layer_types=["sliding_attention", "full_attention"],
+        sliding_window=64,
+        max_position_embeddings=2048,
+        hidden_size_per_layer_input=16,
+        vocab_size_per_layer_input=128,
+    )
+    # Simulate the pilot: caller says flash_attention_2, but global layers
+    # were switched to SDPA per-layer. Without the patch, create_causal_mask
+    # would return an FA2 2D mask here and the SDPA layer would crash.
+    cfg._attn_implementation = "flash_attention_2"
+
+    B, S = 2, 1024
+    inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
+    attention_mask = torch.ones((B, S), dtype=torch.long)
+    attention_mask[1, S // 2 :] = 0  # force the 4D materialized path
+    position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
+    past_key_values = DynamicCache(config=cfg)
+
+    mask = modeling_gemma4.create_causal_mask(
+        config=cfg,
+        inputs_embeds=inputs_embeds,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        position_ids=position_ids,
+    )
+
+    assert mask is not None
+    assert isinstance(mask, torch.Tensor)
+    assert mask.dim() == 4, (
+        f"expected a 4D SDPA-format mask, got {mask.dim()}D "
+        f"shape={tuple(mask.shape)}. The full_attention global layers need "
+        "this shape or they crash at long context."
+    )
+    assert mask.shape[0] == B
+    assert mask.shape[-1] == S
+    assert mask.shape[-2] == S
+
+    # Caller's config must be untouched — other code paths still read it.
+    assert cfg._attn_implementation == "flash_attention_2"
diff --git a/tests/utils/schemas/validation/test_config_validators.py b/tests/utils/schemas/validation/test_config_validators.py
index c756f1362..fbfa79ad8 100644
--- a/tests/utils/schemas/validation/test_config_validators.py
+++ b/tests/utils/schemas/validation/test_config_validators.py
@@ -5,6 +5,8 @@ Covers:
 - save_strategy: 'best' requires metric_for_best_model
 - streaming=True with val_set_size > 0 is rejected
 - lora_target_modules with invalid regex patterns is rejected
+- GRPO: generation batch size must be divisible by num_generations,
+  num_generations >= 2, and effective_gbs >= num_generations * world_size
 """
 
 import pytest
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
             )
             with pytest.raises(ValueError, match="invalid regex pattern"):
                 validate_config(cfg)
+
+
+class TestGRPOBatchSizeValidator:
+    """GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
+
+    These call the @model_validator(mode="before") classmethod directly on a
+    plain dict — same input shape it receives during full Pydantic validation,
+    just without dragging in unrelated fields (datasets / model loading / etc.)
+    that aren't relevant to what's under test. The validator is registered on
+    ``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
+    same code path ``axolotl train`` exercises.
+    """
+
+    @staticmethod
+    def _check(data):
+        from axolotl.utils.schemas.validation import RLValidationMixin
+
+        return RLValidationMixin.check_grpo_batch_size_divisibility(data)  # validator under test, called on the raw dict
+
+    def test_divisible_passes(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 4,
+            "trl": {"num_generations": 4},
+        }
+        # Should return data unchanged (no exception)
+        out = self._check(data)
+        assert out["trl"]["num_generations"] == 4
+
+    def test_non_divisible_raises(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 2,
+            "trl": {"num_generations": 4},
+        }
+        with pytest.raises(ValueError, match="num_generations"):
+            self._check(data)
+
+    def test_non_divisible_error_includes_fix_hint(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 3,
+            "trl": {"num_generations": 4},
+        }
+        with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
+            self._check(data)
+
+    def test_num_generations_one_raises(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 4,
+            "trl": {"num_generations": 1},
+        }
+        with pytest.raises(ValueError, match=r"num_generations >= 2"):
+            self._check(data)
+
+    def test_explicit_generation_batch_size_divisible_passes(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "trl": {"num_generations": 4, "generation_batch_size": 8},
+        }
+        out = self._check(data)
+        assert out["trl"]["generation_batch_size"] == 8
+
+    def test_explicit_generation_batch_size_non_divisible_raises(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "trl": {"num_generations": 4, "generation_batch_size": 6},
+        }
+        with pytest.raises(ValueError, match="trl.generation_batch_size"):
+            self._check(data)
+
+    def test_non_grpo_skips_check(self):
+        # Anything other than rl=grpo should pass through untouched, even
+        # with non-divisible batch sizes — they're irrelevant to other RL
+        # methods that don't use group-relative advantages.
+        data = {
+            "rl": "dpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 3,
+            "trl": {"num_generations": 4},
+        }
+        assert self._check(data) is data
+
+    def test_no_rl_set_skips_check(self):
+        data = {
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 3,
+        }
+        assert self._check(data) is data
+
+    def test_grpo_without_num_generations_skips_check(self):
+        # If num_generations isn't set, TRL uses its own default — we don't
+        # have enough info to validate, so the validator must short-circuit
+        # rather than guess.
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 3,
+            "trl": {},
+        }
+        out = self._check(data)
+        assert out["rl"] == "grpo"
+
+    def test_multi_rank_group_size_check(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 4,  # gbs=4
+            "world_size": 2,  # need gbs >= 4*2 = 8
+            "trl": {"num_generations": 4},
+        }
+        with pytest.raises(ValueError, match=r"world_size=2"):
+            self._check(data)
+
+    def test_multi_rank_group_size_satisfied(self):
+        data = {
+            "rl": "grpo",
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 8,  # gbs=8 >= 4*2
+            "world_size": 2,
+            "trl": {"num_generations": 4},
+        }
+        out = self._check(data)
+        assert out["gradient_accumulation_steps"] == 8