gemma3 packing fixes (#2449)
* make gemma3 work with packing
* multi-gpu e2e for ci
* update gemma3 model namespace to use mirror
* add gradient checkpointing to multigpu e2e ci
* update gemma3 examples for use_reentrant and fix ddp find unused params
* fix tests for gemma3
* fix import for test utils
* set correct train loss for gemma3 e2e
This commit is contained in:
@@ -5,6 +5,9 @@ tokenizer_type: AutoTokenizer
|
|||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
strict: false
|
strict: false
|
||||||
@@ -54,6 +57,8 @@ fp16:
|
|||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ skip_prepare_dataset: true
|
|||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
@@ -48,6 +51,8 @@ fp16:
|
|||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|||||||
@@ -524,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.eval_steps
|
and self.cfg.eval_steps
|
||||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||||
) or False
|
) or False
|
||||||
|
|
||||||
|
# handle ddp
|
||||||
|
ddp_find_unused_parameters = None
|
||||||
|
if self.cfg.ddp:
|
||||||
|
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||||
False if self.cfg.ddp else None
|
ddp_find_unused_parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
report_to = []
|
report_to = []
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"phi3",
|
"phi3",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
|
"gemma3",
|
||||||
"gemma3_text",
|
"gemma3_text",
|
||||||
"cohere",
|
"cohere",
|
||||||
"cohere2",
|
"cohere2",
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
|
has_attn_mask = "attention_mask" in features[0].keys()
|
||||||
labels = None
|
labels = None
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return_tensors = self.return_tensors
|
return_tensors = self.return_tensors
|
||||||
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
)
|
)
|
||||||
|
if not has_attn_mask:
|
||||||
|
del features["attention_mask"]
|
||||||
|
|
||||||
# prepare decoder_input_ids
|
# prepare decoder_input_ids
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
|
|||||||
100
tests/e2e/multigpu/test_gemma3.py
Normal file
100
tests/e2e/multigpu/test_gemma3.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for multigpu lora gemma3
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_model():
|
||||||
|
# download the model
|
||||||
|
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiGPUGemma3:
|
||||||
|
"""
|
||||||
|
Test case for Gemma3 models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_lora_ddp_packed(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"ddp_find_unused_parameters": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"eval_sample_packing": False,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"chat_template": "gemma3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
"gradient_checkpointing_kwargs": {
|
||||||
|
"use_reentrant": False,
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
||||||
|
)
|
||||||
@@ -58,6 +58,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -121,6 +122,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -193,6 +195,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -270,6 +273,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -330,6 +334,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
@@ -400,6 +405,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
@@ -479,6 +485,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
@@ -781,6 +788,7 @@ class TestMultiGPULlama:
|
|||||||
"max_steps": 2,
|
"max_steps": 2,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
|
|||||||
Reference in New Issue
Block a user