gemma3 packing fixes (#2449)

* make gemma3 work with packing

* multi-gpu e2e for ci

* update gemma3 model namespace to use mirror

* add gradient checkpointing to multigpu e2e ci

* update gemma3 examples for use_reentrant and fix ddp find unused params

* fix tests for gemma3

* fix import for test utils

* set correct train loss for gemma3 e2e
This commit is contained in:
Wing Lian
2025-03-31 17:15:23 -04:00
committed by GitHub
parent 4d36ecc724
commit 328d598114
8 changed files with 130 additions and 2 deletions

View File

@@ -5,6 +5,9 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -54,6 +57,8 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -7,6 +7,9 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
@@ -48,6 +51,8 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true

View File

@@ -524,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = []

View File

@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3",
"gemma3_text",
"cohere",
"cohere2",

View File

@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
del features["attention_mask"]
# prepare decoder_input_ids
if (

View File

@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type == "mamba":
if cfg.model_config_type in ["mamba", "gemma3"]:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:

View File

@@ -0,0 +1,100 @@
"""
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
# download the model
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
class TestMultiGPUGemma3:
"""
Test case for Gemma3 models using LoRA
"""
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
"sequence_len": 2048,
"ddp_find_unused_parameters": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"chat_template": "gemma3",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
)

View File

@@ -58,6 +58,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -121,6 +122,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -193,6 +195,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -270,6 +273,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -330,6 +334,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -400,6 +405,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -479,6 +485,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -781,6 +788,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",