gemma3 packing fixes (#2449)
* make gemma3 work with packing * multi-gpu e2e for ci * update gemma3 model namespace to use mirror * add gradient checkpointing to multigpu e2e ci * update gemma3 examples for use_reentrant and fix ddp find unused params * fix tests for gemma3 * fix import for test utils * set correct train loss for gemma3 e2e
This commit is contained in:
@@ -524,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
|
||||
# handle ddp
|
||||
ddp_find_unused_parameters = None
|
||||
if self.cfg.ddp:
|
||||
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = []
|
||||
|
||||
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"phi3",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3_text",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
|
||||
@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
has_attn_mask = "attention_mask" in features[0].keys()
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if not has_attn_mask:
|
||||
del features["attention_mask"]
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
|
||||
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type == "mamba":
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
|
||||
Reference in New Issue
Block a user