"""Unit tests for axolotl.core.builders"""
# pylint: disable=protected-access
import sys
from pathlib import Path
from unittest.mock import patch

import pytest

from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType

from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
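# NOTE: these fixtures load a real tokenizer/model ("HuggingFaceTB/SmolLM2-135M-Instruct")
# and, in some tests, small real dataset splits from the Hugging Face Hub, so a local
# cache or network access is assumed when running this module.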
@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
"""
    Base config with arguments common to both the SFT and RLHF trainer builders
"""
cfg = DictDefault(
{
# Model and tokenizer settings
"base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
"sequence_len": 2048,
"model_config_type": "llama", # example type
# Basic training settings
"micro_batch_size": 2,
"eval_batch_size": 2,
"num_epochs": 1,
"gradient_accumulation_steps": 1,
"max_steps": 100,
"val_set_size": 0,
# Optimizer settings
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00005,
"weight_decay": 0.01,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"max_grad_norm": 1.0,
# LR scheduler settings
"lr_scheduler": "cosine",
"lr_scheduler_kwargs": {"foo": "bar"},
"warmup_steps": 10,
"warmup_ratio": None,
"cosine_min_lr_ratio": 0.1,
"cosine_constant_lr_ratio": 0.2,
# Checkpointing and saving
"save_steps": 100,
"output_dir": "./model-out",
"save_safetensors": True,
"save_total_limit": 4,
"save_only_model": False,
# Hardware/performance settings
"gradient_checkpointing": False,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
# Dtype
"fp16": False,
"bf16": False,
"tf32": False,
# Logging and evaluation
"logging_steps": 10,
"eval_steps": 50,
"eval_strategy": "steps",
"save_strategy": "steps",
"include_tokens_per_second": True,
# Other common settings
"seed": 42,
"remove_unused_columns": True,
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
}
)
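    # normalize_config derives additional settings (e.g., the effective batch size)
    # from the raw values above before the cfg reaches the trainer builders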
normalize_config(cfg)
return cfg
@pytest.fixture(name="dpo_cfg")
def fixture_dpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
}
)
return cfg
@pytest.fixture(name="orpo_cfg")
def fixture_orpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.ORPO,
"orpo_alpha": 0.1,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="kto_cfg")
def fixture_kto_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.KTO,
"kto_desirable_weight": 1.0,
"kto_undesirable_weight": 1.0,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="grpo_cfg")
def fixture_grpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.GRPO,
"trl": DictDefault(
{
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": False, # run on CPU
# "vllm_device": "auto",
# "vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": ["rewards.rand_reward_func"],
}
),
# Must be evenly divisible by num_generations
"micro_batch_size": 4,
}
)
return cfg
@pytest.fixture(name="ipo_cfg")
def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.IPO,
"dpo_label_smoothing": 0,
"beta": 0.1,
}
)
return cfg
@pytest.fixture(name="simpo_cfg")
def fixture_simpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.SIMPO,
"rl_beta": 0.2,
"cpo_alpha": 0.9,
"simpo_gamma": 0.4,
}
)
return cfg
@pytest.fixture(name="sft_cfg")
def fixture_sft_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": None,
"sample_packing": False,
"eval_sample_packing": False,
"flash_attention": False,
}
)
return cfg
@pytest.fixture(name="rm_cfg")
def fixture_rm_cfg(sft_cfg):
cfg = sft_cfg.copy()
cfg.update(
DictDefault(
{
"reward_model": True,
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
"split": "train[:1%]",
}
],
}
)
)
return cfg
@pytest.fixture(name="prm_cfg")
def fixture_prm_cfg(sft_cfg):
cfg = sft_cfg.copy()
cfg.update(
DictDefault(
{
"process_reward_model": True,
"datasets": [
{
"path": "trl-lib/math_shepherd",
"type": "stepwise_supervised",
"split": "train[:1%]",
}
],
}
)
)
return cfg
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(base_cfg):
return load_tokenizer(base_cfg)
@pytest.fixture(name="model")
def fixture_model(base_cfg, tokenizer):
model, _ = ModelLoader(base_cfg, tokenizer).load()
return model
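# The tokenizer/model fixtures use pytest's default function scope, so the 135M model
# is reloaded for each test that requests it.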
class TestHFRLTrainerBuilder:
"""
TestCase class for RLHF trainer builders
"""
    def _test_common_training_arguments(self, training_arguments, rl: RLType):
        """Helper to check the training arguments common across all RLHF variants"""
        # Basic training settings
        if rl == RLType.GRPO:
            # grpo_cfg uses a larger micro_batch_size than the other configs
            assert training_arguments.per_device_train_batch_size == 4
else:
assert training_arguments.per_device_train_batch_size == 2
assert training_arguments.gradient_accumulation_steps == 1
assert training_arguments.max_steps == 100
# Optimizer settings
assert training_arguments.learning_rate == 0.00005
assert training_arguments.weight_decay == 0.01
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.max_grad_norm == 1.0
# LR scheduler settings
assert training_arguments.lr_scheduler_type == "cosine"
assert training_arguments.warmup_steps == 10
assert training_arguments.cosine_min_lr_ratio == 0.1
assert training_arguments.cosine_constant_lr_ratio == 0.2
# Other settings
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
assert training_arguments.gradient_checkpointing is False
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
# DPO specific
assert training_arguments.beta == 0.1
assert hasattr(training_arguments, "use_weighting")
assert training_arguments.use_weighting is True
assert training_arguments.label_smoothing == 0.1
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
# KTO specific
assert training_arguments.desirable_weight == 1.0
assert training_arguments.undesirable_weight == 1.0
assert training_arguments.max_prompt_length == 512
def _write_rewards_file(self, rewards_dir: Path):
"""
Writes reward function to local tmp path to be loaded on trainer building
"""
# Create rewards.py in a directory we can import from
rewards_dir.mkdir()
rewards_file = rewards_dir / "rewards.py"
        rewards_file.write_text(
            """import random
def rand_reward_func(prompts, completions) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]
"""
        )
def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
rewards_dir = tmp_path / "rewards_test"
self._write_rewards_file(rewards_dir)
# Add the directory to Python path so we can import the module
sys.path.insert(0, str(rewards_dir))
try:
builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
# GRPO specific
assert training_arguments.beta == 0.001
assert training_arguments.max_completion_length == 256
assert training_arguments.use_vllm is False
# assert training_arguments.vllm_device == "auto"
# assert training_arguments.vllm_gpu_memory_utilization == 0.15
assert training_arguments.num_generations == 4
# Test trainer creation to verify reward_funcs
trainer = builder.build(100)
# Verify reward functions are properly loaded
assert len(trainer.reward_funcs) == 1
assert trainer.reward_funcs[0].__module__ == "rewards"
assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
finally:
            # remove the rewards directory from sys.path
if str(rewards_dir) in sys.path:
sys.path.remove(str(rewards_dir))
def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
# IPO specific
assert training_arguments.beta == 0.1
assert training_arguments.loss_type == "ipo"
assert training_arguments.label_smoothing == 0
def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(simpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=simpo_cfg.rl)
# SIMPO specific
assert training_arguments.beta == 0.2
assert training_arguments.cpo_alpha == 0.9
assert training_arguments.simpo_gamma == 0.4
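    # The second parametrize element names a dataset fixture (presumably defined in a
    # conftest) used to mock dataset loading; `None` means a small real split is
    # downloaded instead.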
@pytest.mark.parametrize(
("cfg_string", "dataset_name"),
[
(
"dpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
(
"ipo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
(
"grpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
("orpo_cfg", None), # don't use fixture for orpo to use smaller split
("kto_cfg", None), # no fixture for kto
(
"simpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
],
)
def test_custom_optimizer_cls_and_kwargs(
self,
request,
cfg_string,
dataset_name,
tmp_path,
model,
tokenizer,
):
cfg = request.getfixturevalue(cfg_string)
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"
if cfg_string in ["dpo_cfg", "ipo_cfg", "grpo_cfg", "simpo_cfg"]:
cfg["datasets"] = [DictDefault(ALPACA_MESSAGES_CONFIG_REVISION)]
elif cfg_string == "kto_cfg":
cfg["datasets"] = [
DictDefault(
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
"type": "llama3.ultra",
"split": "train[:1%]",
}
)
]
elif cfg_string == "orpo_cfg":
cfg["datasets"] = [
DictDefault(
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train[:1%]",
}
)
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"
self._write_rewards_file(rewards_dir)
# Add the directory to Python path so we can import the module
sys.path.insert(0, str(rewards_dir))
try:
            # Only mock dataset loading for configs that provide a cached dataset fixture
if dataset_name is not None:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue(
dataset_name
)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
else:
# Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset
trainer = builder.build(100)
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
assert optimizer_kwargs["lr"] == 0.00005
assert optimizer_kwargs["weight_decay"] == 0.01
assert optimizer_kwargs["betas"] == (0.998, 0.9)
assert optimizer_kwargs["eps"] == 0.00001
# Ensure optimizer is created with correct class
optim = trainer.create_optimizer()
assert isinstance(optim, Muon)
finally:
            # remove the rewards directory from sys.path
if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
sys.path.remove(str(rewards_dir))
class TestHFCausalTrainerBuilder:
"""
TestCase class for SFT trainer builder
"""
def test_training_arguments(self, sft_cfg, model, tokenizer):
builder = HFCausalTrainerBuilder(sft_cfg, model, tokenizer)
trainer = builder.build(100)
training_arguments = trainer.args
# Test common arguments
assert training_arguments.per_device_train_batch_size == 2
assert training_arguments.gradient_accumulation_steps == 1
assert training_arguments.max_steps == 100
assert training_arguments.learning_rate == 0.00005
assert training_arguments.weight_decay == 0.01
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.max_grad_norm == 1.0
assert training_arguments.lr_scheduler_type == "cosine"
assert training_arguments.warmup_steps == 10
assert training_arguments.cosine_min_lr_ratio == 0.1
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
assert training_arguments.gradient_checkpointing is False
# SFT specific
assert training_arguments.sample_packing is False
assert training_arguments.eval_sample_packing is False
@pytest.mark.parametrize(
"cfg_string",
[
"sft_cfg",
"rm_cfg",
"prm_cfg",
],
)
def test_custom_optimizer_cls_and_kwargs(
self, request, cfg_string, model, tokenizer
):
cfg = request.getfixturevalue(cfg_string)
builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"
# need to load datasets for reward model and process reward model trainer
if cfg_string in ["rm_cfg", "prm_cfg"]:
dataset_meta = load_datasets(cfg=cfg)
builder.train_dataset = dataset_meta.train_dataset
builder.eval_dataset = dataset_meta.eval_dataset
trainer = builder.build(100)
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
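        # These values mirror the learning_rate/weight_decay/adam_* settings from base_cfg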
assert optimizer_kwargs["lr"] == 0.00005
assert optimizer_kwargs["weight_decay"] == 0.01
assert optimizer_kwargs["betas"] == (0.998, 0.9)
assert optimizer_kwargs["eps"] == 0.00001
# Ensure optimizer is created with correct class
optim = trainer.create_optimizer()
assert isinstance(optim, Muon)
class TestTrainerClsPlugin:
"""
TestCase class for trainer builder with plugin
"""
def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
"""
Test that the trainer cls is not none with plugin
Fixes #2693
"""
cfg = kto_cfg.copy()
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
        # An AttributeError is expected, since we don't pass regular model configs to the RL trainer builder
        # If it throws `TypeError: None is not a callable object`, trainer_cls could be None
try:
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder.build(100)
except TypeError as e:
# Error raised if trainer_cls is None
assert "'tuple' object has no attribute 'config'" not in str(e)
except Exception: # pylint: disable=broad-exception-caught
            # Some other error occurred, which means a trainer_cls was passed to the builder
pass