Feat: Add support for gemma3_text and add e2e for gemma2 (#2406)

NanoCode012
2025-03-23 07:33:21 +07:00
committed by GitHub
parent 86bac48d14
commit 9f00465a5c
12 changed files with 348 additions and 6 deletions


@@ -513,7 +513,6 @@ lr_div_factor: # Learning rate div factor
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
-# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla

examples/gemma3/qlora.yml (new file)

@@ -0,0 +1,74 @@
base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma3_text
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
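Like the other files under examples/, this config is meant to be launched directly. A typical invocation at the time of this commit (assuming the usual CLI entry point; adjust for your environment) would be:

accelerate launch -m axolotl.cli.train examples/gemma3/qlora.yml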


@@ -12,7 +12,7 @@ liger-kernel==0.5.3
packaging==23.2
peft==0.15.0
-transformers==4.49.0
+transformers==4.50.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
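The pin moves from 4.49.0 to 4.50.0, the transformers release that ships the Gemma 3 model classes; the text-only variant registers as model_type "gemma3_text", matching the identifiers added throughout this commit.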


@@ -114,3 +114,5 @@ class LigerPlugin(BasePlugin):
                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
+        elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
+            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
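The new branch makes the plugin fail fast for model types that have no liger patches yet, rather than silently training with an unpatched loss. A minimal standalone sketch of the pattern (the function name is hypothetical; the real logic lives in LigerPlugin):

# Sketch only: the function name is hypothetical and the surrounding
# plugin code is omitted. Patch the loss per model type, and raise on
# types with no liger integration instead of training unpatched.
def apply_liger_patches(model_config_type: str, fused_lce: bool) -> None:
    if model_config_type == "deepseek_v2":
        if fused_lce:
            pass  # swap in the fused linear cross-entropy forward here
    elif model_config_type in ["gemma3_text", "deepseek_v3"]:
        raise ValueError(f"Unsupported model config type: {model_config_type}")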


@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3_text",
"gemmoe",
"starcoder2",
"deepseek_v2",

File diff suppressed because one or more lines are too long


@@ -23,6 +23,7 @@ class ChatTemplate(str, Enum):
    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
+    gemma3_text = "gemma3_text"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
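Because ChatTemplate subclasses str, the literal from a YAML config maps directly onto the enum member. A self-contained sketch with a trimmed enum (not the project's actual module):

from enum import Enum

class ChatTemplate(str, Enum):
    gemma = "gemma"
    gemma3_text = "gemma3_text"

# The YAML value chat_template: gemma3_text parses straight to the member,
# and the str base class keeps equality with the raw string intact.
assert ChatTemplate("gemma3_text") is ChatTemplate.gemma3_text
assert ChatTemplate.gemma3_text == "gemma3_text"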


@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
    """Test GeGLU activation with Gemma model."""
    model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
+        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
    )
    peft_config = get_peft_config(
        {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
)
# Apply LoRA configuration


@@ -1,5 +1,5 @@
"""
E2E tests for lora llama
E2E tests for deepseekv3
"""
import logging

tests/e2e/test_gemma2.py (new file)

@@ -0,0 +1,133 @@
"""
E2E tests for gemma2
"""
import logging
import os
from pathlib import Path

import pytest

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestGemma2:
    """
    Test case for Gemma2 models
    """

    @pytest.mark.parametrize(
        "sample_packing",
        [True, False],
    )
    def test_lora_gemma2(self, temp_dir, sample_packing):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "axolotl-ai-co/gemma-2-33M",
                "trust_remote_code": True,
                "sample_packing": sample_packing,
                "flash_attention": True,
                "sequence_len": 2048,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0,
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "field_messages": "conversations",
                        "message_property_mappings": {
                            "role": "from",
                            "content": "value",
                        },
                        "drop_system_message": True,
                        "split": "train[:1%]",
                    },
                ],
                "special_tokens": {
                    "bos_token": "<bos>",
                    "eos_token": "<eos>",
                },
                "chat_template": "gemma",  # gemma2's template is the same as gemma's
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_safetensors": True,
                "bf16": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.safetensors").exists()

    @pytest.mark.parametrize(
        "sample_packing",
        [True, False],
    )
    def test_fft_gemma2(self, temp_dir, sample_packing):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "axolotl-ai-co/gemma-2-33M",
                "trust_remote_code": True,
                "sample_packing": sample_packing,
                "flash_attention": True,
                "sequence_len": 2048,
                "val_set_size": 0,
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "field_messages": "conversations",
                        "message_property_mappings": {
                            "role": "from",
                            "content": "value",
                        },
                        "split": "train[:1%]",
                        "drop_system_message": True,
                    },
                ],
                "chat_template": "gemma",  # gemma2's template is the same as gemma's
                "special_tokens": {
                    "bos_token": "<bos>",
                    "eos_token": "<eos>",
                },
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_safetensors": True,
                "bf16": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "model.safetensors").exists()
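Both parametrized variants (packed and unpacked) run under plain pytest, e.g. pytest tests/e2e/test_gemma2.py. The LoRA test asserts that an adapter checkpoint (adapter_model.safetensors) was written, while the full fine-tune asserts full model weights (model.safetensors).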

tests/e2e/test_gemma3_text.py (new file)

@@ -0,0 +1,131 @@
"""
E2E tests for gemma3_text
"""
import logging
import os
from pathlib import Path

import pytest

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestGemma3Text:
    """
    Test case for Gemma3Text models
    """

    @pytest.mark.parametrize(
        "sample_packing",
        [True, False],
    )
    def test_lora_gemma3_text(self, temp_dir, sample_packing):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "axolotl-ai-co/gemma-3-34M",
                "trust_remote_code": True,
                "sample_packing": sample_packing,
                "flash_attention": True,
                "sequence_len": 2048,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0,
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "field_messages": "conversations",
                        "message_property_mappings": {
                            "role": "from",
                            "content": "value",
                        },
                        "split": "train[:1%]",
                    },
                ],
                "special_tokens": {
                    "bos_token": "<bos>",
                    "eos_token": "<eos>",
                },
                "chat_template": "gemma3_text",
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_safetensors": True,
                "bf16": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.safetensors").exists()

    @pytest.mark.parametrize(
        "sample_packing",
        [True, False],
    )
    def test_fft_gemma3_text(self, temp_dir, sample_packing):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "axolotl-ai-co/gemma-3-34M",
                "trust_remote_code": True,
                "sample_packing": sample_packing,
                "flash_attention": True,
                "sequence_len": 2048,
                "val_set_size": 0,
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "field_messages": "conversations",
                        "message_property_mappings": {
                            "role": "from",
                            "content": "value",
                        },
                        "split": "train[:1%]",
                    },
                ],
                "chat_template": "gemma3_text",
                "special_tokens": {
                    "bos_token": "<bos>",
                    "eos_token": "<eos>",
                },
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 4,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_safetensors": True,
                "bf16": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "model.safetensors").exists()
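These tests mirror the gemma2 suite above with two deltas: the tiny axolotl-ai-co/gemma-3-34M base model and chat_template set to the new gemma3_text entry, since Gemma 3 gets its own template rather than reusing gemma's.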


@@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"optimizer": "adamw_torch_fused",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,