tiled mlp fix for gemma4
This commit is contained in:
@@ -276,9 +276,7 @@ class _GroupShardedSampler:
|
|||||||
if num_replicas < 1:
|
if num_replicas < 1:
|
||||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||||
if not (0 <= rank < num_replicas):
|
if not (0 <= rank < num_replicas):
|
||||||
raise ValueError(
|
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
||||||
f"rank must be in [0, {num_replicas}), got {rank}"
|
|
||||||
)
|
|
||||||
self.inner = inner
|
self.inner = inner
|
||||||
self.num_generations = num_generations
|
self.num_generations = num_generations
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
# Diagnostic: log what this rank is about to fire.
|
# Diagnostic: log what this rank is about to fire.
|
||||||
try:
|
try:
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
iid_counts = collections.Counter()
|
iid_counts = collections.Counter()
|
||||||
for it in dataset_items:
|
for it in dataset_items:
|
||||||
iid_counts[
|
iid_counts[
|
||||||
@@ -248,7 +249,6 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
except Exception as exc: # never let metric logging break training
|
except Exception as exc: # never let metric logging break training
|
||||||
LOG.warning("rollout wandb log failed: %s", exc)
|
LOG.warning("rollout wandb log failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
# Decode completions for reward functions
|
# Decode completions for reward functions
|
||||||
completions = trainer.processing_class.batch_decode(
|
completions = trainer.processing_class.batch_decode(
|
||||||
completion_ids, skip_special_tokens=True
|
completion_ids, skip_special_tokens=True
|
||||||
|
|||||||
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
||||||
|
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
||||||
|
# language-side module is separated from the vision tower. Try
|
||||||
|
# both names before giving up.
|
||||||
|
mlp_cls = getattr(
|
||||||
|
module,
|
||||||
|
f"{model_cls_prefix}MLP",
|
||||||
|
None,
|
||||||
|
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
||||||
|
|
||||||
if use_original_mlp:
|
if use_original_mlp:
|
||||||
mlp_forward = mlp_cls.forward
|
mlp_forward = mlp_cls.forward
|
||||||
|
|||||||
Reference in New Issue
Block a user