From 78de2919a69e82ab89ea6a2f5e1e273d86b53d0d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 16 Apr 2026 13:24:41 +0000
Subject: [PATCH] tiled mlp fix for gemma4

---
 src/axolotl/core/trainers/grpo/async_trainer.py    |  4 +---
 src/axolotl/integrations/nemo_gym/data_producer.py |  2 +-
 src/axolotl/monkeypatch/tiled_mlp/patch.py         | 10 +++++++++-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py
index 334db8bb2..4759a30b0 100644
--- a/src/axolotl/core/trainers/grpo/async_trainer.py
+++ b/src/axolotl/core/trainers/grpo/async_trainer.py
@@ -276,9 +276,7 @@ class _GroupShardedSampler:
         if num_replicas < 1:
             raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
         if not (0 <= rank < num_replicas):
-            raise ValueError(
-                f"rank must be in [0, {num_replicas}), got {rank}"
-            )
+            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
         self.inner = inner
         self.num_generations = num_generations
         self.rank = rank
diff --git a/src/axolotl/integrations/nemo_gym/data_producer.py b/src/axolotl/integrations/nemo_gym/data_producer.py
index 1cbe5ad71..ca31e93ef 100644
--- a/src/axolotl/integrations/nemo_gym/data_producer.py
+++ b/src/axolotl/integrations/nemo_gym/data_producer.py
@@ -123,6 +123,7 @@ class NemoGymDataProducer(GRPODataProducer):
         # Diagnostic: log what this rank is about to fire.
         try:
             import collections
+
             iid_counts = collections.Counter()
             for it in dataset_items:
                 iid_counts[
@@ -248,7 +249,6 @@
         except Exception as exc:
             # never let metric logging break training
             LOG.warning("rollout wandb log failed: %s", exc)
-
         # Decode completions for reward functions
         completions = trainer.processing_class.batch_decode(
             completion_ids, skip_special_tokens=True
diff --git a/src/axolotl/monkeypatch/tiled_mlp/patch.py b/src/axolotl/monkeypatch/tiled_mlp/patch.py
index 65885396b..23f48a101 100644
--- a/src/axolotl/monkeypatch/tiled_mlp/patch.py
+++ b/src/axolotl/monkeypatch/tiled_mlp/patch.py
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
     module_path = f"transformers.models.{model_type}.modeling_{model_type}"
     model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
     module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
-    mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
+    # Some multimodal wrappers (e.g. Gemma 4) name the MLP class
+    # ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
+    # language-side module is separated from the vision tower. Try
+    # both names before giving up.
+    mlp_cls = getattr(
+        module,
+        f"{model_cls_prefix}MLP",
+        None,
+    ) or getattr(module, f"{model_cls_prefix}TextMLP")

     if use_original_mlp:
         mlp_forward = mlp_cls.forward
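
Note (not part of the patch itself): below is a minimal, self-contained sketch of the two-step class lookup introduced in the last hunk. The module stand-ins (`FooMLP`, `BarTextMLP`, `resolve_mlp_cls`) are hypothetical placeholders used only to mimic a `modeling_*` module that exposes either `{prefix}MLP` or `{prefix}TextMLP`; they are not real transformers classes.

```python
# Illustrative sketch of the fallback lookup in patch_tiled_mlp.
# All names below are hypothetical stand-ins, not real transformers modules.
from types import SimpleNamespace


class FooMLP:  # text-only style: the module exposes {prefix}MLP
    pass


class BarTextMLP:  # multimodal style: the module exposes {prefix}TextMLP
    pass


def resolve_mlp_cls(module, model_cls_prefix):
    # Prefer the plain "{prefix}MLP" name; fall back to "{prefix}TextMLP",
    # which multimodal wrappers use for the language-model branch.
    return getattr(module, f"{model_cls_prefix}MLP", None) or getattr(
        module, f"{model_cls_prefix}TextMLP"
    )


text_only = SimpleNamespace(FooMLP=FooMLP)
multimodal = SimpleNamespace(BarTextMLP=BarTextMLP)

assert resolve_mlp_cls(text_only, "Foo") is FooMLP
assert resolve_mlp_cls(multimodal, "Bar") is BarTextMLP
```

If neither attribute exists, the second `getattr` (with no default) still raises `AttributeError`, so an unsupported model type fails loudly rather than silently skipping the tiled-MLP patch.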