Feat: Add support for gemma3_text and add e2e for gemma2 (#2406)
This commit is contained in:
@@ -513,7 +513,6 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# in the examples/ for your model and fine-tuning use case.
|
# in the examples/ for your model and fine-tuning use case.
|
||||||
#
|
#
|
||||||
# Valid values for 'optimizer' include:
|
# Valid values for 'optimizer' include:
|
||||||
# - adamw_hf
|
|
||||||
# - adamw_torch
|
# - adamw_torch
|
||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
|
|||||||
74
examples/gemma3/qlora.yml
Normal file
74
examples/gemma3/qlora.yml
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: gemma3_text
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -12,7 +12,7 @@ liger-kernel==0.5.3
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.15.0
|
peft==0.15.0
|
||||||
transformers==4.49.0
|
transformers==4.50.0
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.4.1
|
datasets==3.4.1
|
||||||
|
|||||||
@@ -114,3 +114,5 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
|
||||||
|
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"phi3",
|
"phi3",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
|
"gemma3_text",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -23,6 +23,7 @@ class ChatTemplate(str, Enum):
|
|||||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
|
gemma3_text = "gemma3_text" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
|||||||
def test_geglu_model_integration():
|
def test_geglu_model_integration():
|
||||||
"""Test GeGLU activation with Gemma model."""
|
"""Test GeGLU activation with Gemma model."""
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
|
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
peft_config = get_peft_config(
|
peft_config = get_peft_config(
|
||||||
{
|
{
|
||||||
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
|
|||||||
"""Test LoRA kernel patches across different model architectures."""
|
"""Test LoRA kernel patches across different model architectures."""
|
||||||
# Load model with appropriate dtype
|
# Load model with appropriate dtype
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
|
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply LoRA configuration
|
# Apply LoRA configuration
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
E2E tests for lora llama
|
E2E tests for deepseekv3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
133
tests/e2e/test_gemma2.py
Normal file
133
tests/e2e/test_gemma2.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for gemma2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGemma2:
|
||||||
|
"""
|
||||||
|
Test case for Gemma2 models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_lora_gemma2(self, temp_dir, sample_packing):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sample_packing": sample_packing,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
"drop_system_message": True,
|
||||||
|
"split": "train[:1%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>",
|
||||||
|
},
|
||||||
|
"chat_template": "gemma", # gemma2's template is same as gemma
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_fft_gemma2(self, temp_dir, sample_packing):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sample_packing": sample_packing,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
"split": "train[:1%]",
|
||||||
|
"drop_system_message": True,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"chat_template": "gemma", # gemma2's template is same as gemma
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>",
|
||||||
|
},
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
131
tests/e2e/test_gemma3_text.py
Normal file
131
tests/e2e/test_gemma3_text.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for gemma3_text
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGemma3Text:
|
||||||
|
"""
|
||||||
|
Test case for Gemma3Text models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_lora_gemma3_text(self, temp_dir, sample_packing):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sample_packing": sample_packing,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
"split": "train[:1%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>",
|
||||||
|
},
|
||||||
|
"chat_template": "gemma3_text",
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_fft_gemma3_text(self, temp_dir, sample_packing):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sample_packing": sample_packing,
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
"split": "train[:1%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"chat_template": "gemma3_text",
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>",
|
||||||
|
},
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
@@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase):
|
|||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_hf",
|
"optimizer": "adamw_torch_fused",
|
||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"lr_scheduler": "rex",
|
"lr_scheduler": "rex",
|
||||||
"warmup_steps": 5,
|
"warmup_steps": 5,
|
||||||
|
|||||||
Reference in New Issue
Block a user