gemma3 packing fixes (#2449)

* make gemma3 work with packing

* multi-gpu e2e for ci

* update gemma3 model namespace to use mirror

* add gradient checkpointing to multigpu e2e ci

* update gemma3 examples for use_reentrant and fix ddp find unused params (see the sketch after this list)

* fix tests for gemma3

* fix import for test utils

* set correct train loss for gemma3 e2e
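For reference, the use_reentrant fix mentioned in the bullets above corresponds to the standard transformers gradient-checkpointing knobs. A minimal sketch, assuming nothing beyond the public TrainingArguments API; the values are illustrative and not taken from this commit:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    gradient_checkpointing=True,
    # Non-reentrant checkpointing plays better with DDP: the reentrant
    # implementation can require ddp_find_unused_parameters=True and
    # trigger "expected to mark a variable ready only once" errors.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

In axolotl YAML configs this pair is typically expressed as gradient_checkpointing: true together with gradient_checkpointing_kwargs: use_reentrant: false.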
Wing Lian authored on 2025-03-31 17:15:23 -04:00 (committed by GitHub)
parent 4d36ecc724, commit 328d598114
8 changed files with 130 additions and 2 deletions


@@ -524,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             and self.cfg.eval_steps
             and self.cfg.save_steps % self.cfg.eval_steps == 0
         ) or False
+
+        # handle ddp
+        ddp_find_unused_parameters = None
+        if self.cfg.ddp:
+            ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
+
         training_arguments_kwargs["ddp_find_unused_parameters"] = (
-            False if self.cfg.ddp else None
+            ddp_find_unused_parameters
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = []
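The net effect of this hunk is that ddp_find_unused_parameters becomes user-configurable instead of being hard-coded to False whenever DDP is on. A standalone sketch of the new resolution logic, with a stand-in Cfg class defined only to make the example self-contained:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Cfg:  # stand-in for axolotl's config object
    ddp: bool = False
    ddp_find_unused_parameters: Optional[bool] = None

def resolve_ddp_find_unused_parameters(cfg: Cfg) -> Optional[bool]:
    # Outside DDP, keep the TrainingArguments default of None.
    if not cfg.ddp:
        return None
    # Under DDP, honor the user's setting; bool(None) coerces an unset
    # value to False, matching the previous hard-coded behavior.
    return bool(cfg.ddp_find_unused_parameters)

assert resolve_ddp_find_unused_parameters(Cfg()) is None
assert resolve_ddp_find_unused_parameters(Cfg(ddp=True)) is False
assert resolve_ddp_find_unused_parameters(Cfg(ddp=True, ddp_find_unused_parameters=True)) is True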


@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "phi3",
     "gemma",
     "gemma2",
+    "gemma3",
     "gemma3_text",
     "cohere",
     "cohere2",


@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
         self.local_world_size = dist.get_world_size(group=sp_group)

     def __call__(self, features, return_tensors=None):
+        has_attn_mask = "attention_mask" in features[0].keys()
         labels = None
         if return_tensors is None:
             return_tensors = self.return_tensors
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
             pad_to_multiple_of=self.pad_to_multiple_of,
             return_tensors=return_tensors,
         )
+        if not has_attn_mask:
+            del features["attention_mask"]

         # prepare decoder_input_ids
         if (
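This change makes attention_mask round-trip cleanly through the collator: tokenizer.pad always synthesizes an attention_mask, so when the incoming features had none (as with gemma3 packing, where the column is dropped upstream), the collator now deletes it again after padding. A self-contained sketch of that behavior, with the real padding replaced by a trivial stand-in:

def collate(features: list) -> dict:
    # Record whether the incoming features carried an attention mask.
    has_attn_mask = "attention_mask" in features[0].keys()

    # Stand-in for tokenizer.pad(...), which always emits attention_mask.
    batch = {
        "input_ids": [f["input_ids"] for f in features],
        "attention_mask": [[1] * len(f["input_ids"]) for f in features],
    }

    # Don't let padding reintroduce a mask the pipeline removed on purpose.
    if not has_attn_mask:
        del batch["attention_mask"]
    return batch

print(collate([{"input_ids": [1, 2, 3]}]).keys())  # dict_keys(['input_ids'])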


@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):

 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    if cfg.model_config_type == "mamba":
+    if cfg.model_config_type in ["mamba", "gemma3"]:
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
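This last hunk extends the existing mamba special case so gemma3 also trains on packed batches without an attention_mask column. A minimal sketch of the same column drop using the datasets library (the dataset contents are illustrative):

from datasets import Dataset

train_dataset = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "labels": [[1, 2, 3]],
    }
)

model_config_type = "gemma3"
if model_config_type in ["mamba", "gemma3"]:
    # Matches the LOG.info above: the attention_mask column is dropped
    # before packing so the collator doesn't pad or forward it.
    train_dataset = train_dataset.remove_columns("attention_mask")

print(train_dataset.column_names)  # ['input_ids', 'labels']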