"""
unit tests for axolotl.core.trainer_builder
"""

import sys
from pathlib import Path

import pytest

from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
"""
Base config with all common arguments between SFT and RLHF
"""
cfg = DictDefault(
{
# Model and tokenizer settings
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"model_config_type": "llama", # example type
# Basic training settings
"micro_batch_size": 2,
"eval_batch_size": 2,
"num_epochs": 1,
"gradient_accumulation_steps": 1,
"max_steps": 100,
"val_set_size": 0,
# Optimizer settings
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00005,
"weight_decay": 0.01,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"max_grad_norm": 1.0,
# LR scheduler settings
"lr_scheduler": "cosine",
"lr_scheduler_kwargs": {"foo": "bar"},
"warmup_steps": 10,
"warmup_ratio": None,
"cosine_min_lr_ratio": 0.1,
"cosine_constant_lr_ratio": 0.2,
# Checkpointing and saving
"save_steps": 100,
"output_dir": "./model-out",
"save_safetensors": True,
"save_total_limit": 4,
"save_only_model": False,
# Hardware/performance settings
"gradient_checkpointing": False,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
# Dtype
"fp16": False,
"bf16": False,
"tf32": False,
# Logging and evaluation
"logging_steps": 10,
"eval_steps": 50,
"eval_strategy": "steps",
"save_strategy": "steps",
"include_tokens_per_second": True,
# Other common settings
"seed": 42,
"remove_unused_columns": True,
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
}
)
normalize_config(cfg)
return cfg
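

# Note: DictDefault supports attribute-style access as well as dict access
# (e.g. ``cfg.rl`` in the tests below), so the per-variant fixtures can both
# ``update`` the base config and read values back as attributes.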
@pytest.fixture(name="dpo_cfg")
def fixture_dpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
}
)
return cfg
@pytest.fixture(name="orpo_cfg")
def fixture_orpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.ORPO,
"orpo_alpha": 0.1,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="kto_cfg")
def fixture_kto_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.KTO,
"kto_desirable_weight": 1.0,
"kto_undesirable_weight": 1.0,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="grpo_cfg")
def fixture_grpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.GRPO,
"trl": DictDefault(
{
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": False, # run on CPU
# "vllm_device": "auto",
# "vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": ["rewards.rand_reward_func"],
}
),
# Must be evenly divisible by num_generations
"micro_batch_size": 4,
}
)
return cfg
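

# Unlike the other RL fixtures, the GRPO fixture above nests its TRL-specific
# settings (beta, num_generations, reward_funcs, ...) under a ``trl`` sub-config.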
@pytest.fixture(name="ipo_cfg")
def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.IPO,
"dpo_label_smoothing": 0.1,
"beta": 0.1,
}
)
return cfg
@pytest.fixture(name="sft_cfg")
def fixture_sft_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": None,
"sample_packing": False,
"eval_sample_packing": False,
"flash_attention": False,
}
)
return cfg
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(base_cfg):
return load_tokenizer(base_cfg)
@pytest.fixture(name="model")
def fixture_model(base_cfg, tokenizer):
model, _ = load_model(base_cfg, tokenizer)
return model
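
# Note: the tokenizer and model fixtures load HuggingFaceTB/SmolLM2-135M through
# the regular axolotl loaders, so running these tests assumes the checkpoint is
# available in the local HF cache or downloadable from the Hub.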


class TestHFRLTrainerBuilder:
    """
    TestCase class for RLHF trainer builders
    """

    def _test_common_training_arguments(self, training_arguments, rl: RLType):
        """Helper to test the common arguments across all RLHF variants"""
        # Basic training settings
        if rl == RLType.GRPO:
            # grpo_cfg's micro_batch_size differs from the other fixtures
            assert training_arguments.per_device_train_batch_size == 4
        else:
            assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100
        # Optimizer settings
        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0
        # LR scheduler settings
        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1
        assert training_arguments.cosine_constant_lr_ratio == 0.2
        # Other settings
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True
        assert training_arguments.gradient_checkpointing is False

    def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
        # DPO specific
        assert training_arguments.beta == 0.1
        assert hasattr(training_arguments, "use_weighting")
        assert training_arguments.use_weighting is True

    def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
        # ORPO specific
        assert training_arguments.beta == 0.1  # maps from orpo_alpha
        assert training_arguments.max_prompt_length == 512

    def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
        # KTO specific
        assert training_arguments.desirable_weight == 1.0
        assert training_arguments.undesirable_weight == 1.0
        assert training_arguments.max_prompt_length == 512

    def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
        def _write_rewards_file(rewards_dir: Path):
            """
            Write the reward function to a local tmp path so it can be imported
            when the trainer is built
            """
            # Create rewards.py in a directory we can import from
            rewards_dir.mkdir()
            rewards_file = rewards_dir / "rewards.py"
            rewards_file.write_text(
                """import random


def rand_reward_func(prompts, completions) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]
"""
            )

        rewards_dir = tmp_path / "rewards_test"
        _write_rewards_file(rewards_dir)
        # Add the directory to the Python path so we can import the module
        sys.path.insert(0, str(rewards_dir))
        try:
            builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
            training_arguments = builder.build_training_arguments(100)
            self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
            # GRPO specific
            assert training_arguments.beta == 0.001
            assert training_arguments.max_completion_length == 256
            assert training_arguments.use_vllm is False
            # assert training_arguments.vllm_device == "auto"
            # assert training_arguments.vllm_gpu_memory_utilization == 0.15
            assert training_arguments.num_generations == 4

            # Test trainer creation to verify reward_funcs
            trainer = builder.build(100)
            # Verify the reward functions were properly loaded
            assert len(trainer.reward_funcs) == 1
            assert trainer.reward_funcs[0].__module__ == "rewards"
            assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
        finally:
            # Remove the rewards directory from the import path
            if str(rewards_dir) in sys.path:
                sys.path.remove(str(rewards_dir))

    def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
        # IPO specific
        assert training_arguments.beta == 0.1
        assert training_arguments.loss_type == "ipo"


class TestHFCausalTrainerBuilder:
    """
    TestCase class for SFT trainer builder
    """

    def test_training_arguments(self, sft_cfg, model, tokenizer):
        builder = HFCausalTrainerBuilder(sft_cfg, model, tokenizer)
        trainer = builder.build(100)
        training_arguments = trainer.args
        # Test common arguments
        assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100
        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0
        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True
        assert training_arguments.gradient_checkpointing is False
        # SFT specific
        assert training_arguments.sample_packing is False
        assert training_arguments.eval_sample_packing is False


class TestTrainerClsPlugin:
    """
    TestCase class for trainer builder with plugins
    """

    def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
        """
        Test that the trainer cls is not None when a plugin is active

        Fixes #2693
        """
        kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
        # An AttributeError is expected because we don't pass a regular model
        # config to the RL trainer builder. If this instead raised
        # ``TypeError: 'NoneType' object is not callable``, trainer_cls would be None.
        with pytest.raises(
            AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
        ):
            builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
            builder.build(100)
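

# A minimal sketch of how this module might be run locally (the exact pytest
# flags are an assumption, not part of the original file):
#   pytest tests/core/test_trainer_builder.py -v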