tiled mlp fix for gemma4
This commit is contained in:
@@ -276,9 +276,7 @@ class _GroupShardedSampler:
|
||||
if num_replicas < 1:
|
||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||
if not (0 <= rank < num_replicas):
|
||||
raise ValueError(
|
||||
f"rank must be in [0, {num_replicas}), got {rank}"
|
||||
)
|
||||
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
||||
self.inner = inner
|
||||
self.num_generations = num_generations
|
||||
self.rank = rank
|
||||
|
||||
@@ -123,6 +123,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
# Diagnostic: log what this rank is about to fire.
|
||||
try:
|
||||
import collections
|
||||
|
||||
iid_counts = collections.Counter()
|
||||
for it in dataset_items:
|
||||
iid_counts[
|
||||
@@ -248,7 +249,6 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
except Exception as exc: # never let metric logging break training
|
||||
LOG.warning("rollout wandb log failed: %s", exc)
|
||||
|
||||
|
||||
# Decode completions for reward functions
|
||||
completions = trainer.processing_class.batch_decode(
|
||||
completion_ids, skip_special_tokens=True
|
||||
|
||||
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
||||
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
||||
# language-side module is separated from the vision tower. Try
|
||||
# both names before giving up.
|
||||
mlp_cls = getattr(
|
||||
module,
|
||||
f"{model_cls_prefix}MLP",
|
||||
None,
|
||||
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
||||
|
||||
if use_original_mlp:
|
||||
mlp_forward = mlp_cls.forward
|
||||
|
||||
Reference in New Issue
Block a user