tiled mlp fix for gemma4

This commit is contained in:
Wing Lian
2026-04-16 13:24:41 +00:00
parent 28283ff373
commit 78de2919a6
3 changed files with 11 additions and 5 deletions

View File

@@ -276,9 +276,7 @@ class _GroupShardedSampler:
if num_replicas < 1:
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
if not (0 <= rank < num_replicas):
raise ValueError(
f"rank must be in [0, {num_replicas}), got {rank}"
)
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
self.inner = inner
self.num_generations = num_generations
self.rank = rank

View File

@@ -123,6 +123,7 @@ class NemoGymDataProducer(GRPODataProducer):
# Diagnostic: log what this rank is about to fire.
try:
import collections
iid_counts = collections.Counter()
for it in dataset_items:
iid_counts[
@@ -248,7 +249,6 @@ class NemoGymDataProducer(GRPODataProducer):
except Exception as exc: # never let metric logging break training
LOG.warning("rollout wandb log failed: %s", exc)
# Decode completions for reward functions
completions = trainer.processing_class.batch_decode(
completion_ids, skip_special_tokens=True

View File

@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
# language-side module is separated from the vision tower. Try
# ``{prefix}MLP`` first, then fall back to ``{prefix}TextMLP``.
mlp_cls = getattr(
module,
f"{model_cls_prefix}MLP",
None,
) or getattr(module, f"{model_cls_prefix}TextMLP")
if use_original_mlp:
mlp_forward = mlp_cls.forward