"""Unit tests for axolotl.core.builders"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from axolotl.common.datasets import load_datasets
|
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
from axolotl.loaders import ModelLoader, load_tokenizer
|
|
from axolotl.utils.config import normalize_config
|
|
from axolotl.utils.data import prepare_preference_datasets
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.schemas.enums import RLType
|
|
|
|
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
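
# NOTE: `DictDefault` (axolotl.utils.dict) returns None for missing keys, so
# each fixture below only needs to set the options its trainer path actually
# reads.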


@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
    """
    Base config with the arguments shared by the SFT and RLHF trainer builders.
    """
    cfg = DictDefault(
        {
            # Model and tokenizer settings
            "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
            "sequence_len": 2048,
            "model_config_type": "llama",  # example type
            # Basic training settings
            "micro_batch_size": 2,
            "eval_batch_size": 2,
            "num_epochs": 1,
            "gradient_accumulation_steps": 1,
            "max_steps": 100,
            "val_set_size": 0,
            # Optimizer settings
            "optimizer": "adamw_torch_fused",
            "learning_rate": 0.00005,
            "weight_decay": 0.01,
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "max_grad_norm": 1.0,
            # LR scheduler settings
            "lr_scheduler": "cosine",
            "lr_scheduler_kwargs": {"foo": "bar"},
            "warmup_steps": 10,
            "warmup_ratio": None,
            "cosine_min_lr_ratio": 0.1,
            "cosine_constant_lr_ratio": 0.2,
            # Checkpointing and saving
            "save_steps": 100,
            "output_dir": "./model-out",
            "save_safetensors": True,
            "save_total_limit": 4,
            "save_only_model": False,
            # Hardware/performance settings
            "gradient_checkpointing": False,
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
            "dataloader_num_workers": 1,
            "dataloader_pin_memory": True,
            "dataloader_prefetch_factor": 2,
            "context_parallel_size": 1,
            "tensor_parallel_size": 1,
            # Dtype
            "fp16": False,
            "bf16": False,
            "tf32": False,
            # Logging and evaluation
            "logging_steps": 10,
            "eval_steps": 50,
            "eval_strategy": "steps",
            "save_strategy": "steps",
            "include_tokens_per_second": True,
            # Other common settings
            "seed": 42,
            "remove_unused_columns": True,
            "ddp_timeout": 1800,
            "ddp_bucket_cap_mb": 25,
            "ddp_broadcast_buffers": False,
            "dataset_processes": 4,
        }
    )

    # Fill in derived config values in place before the fixture is consumed
    normalize_config(cfg)
    return cfg
@pytest.fixture(name="dpo_cfg")
|
|
def fixture_dpo_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.DPO,
|
|
"dpo_use_weighting": True,
|
|
"dpo_use_logits_to_keep": True,
|
|
"dpo_label_smoothing": 0.1,
|
|
"beta": 0.1, # DPO beta
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="orpo_cfg")
|
|
def fixture_orpo_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.ORPO,
|
|
"orpo_alpha": 0.1,
|
|
"max_prompt_len": 512,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="kto_cfg")
|
|
def fixture_kto_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.KTO,
|
|
"kto_desirable_weight": 1.0,
|
|
"kto_undesirable_weight": 1.0,
|
|
"max_prompt_len": 512,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="grpo_cfg")
|
|
def fixture_grpo_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.GRPO,
|
|
"trl": DictDefault(
|
|
{
|
|
"beta": 0.001,
|
|
"max_completion_length": 256,
|
|
"use_vllm": False, # run on CPU
|
|
# "vllm_device": "auto",
|
|
# "vllm_gpu_memory_utilization": 0.15,
|
|
"num_generations": 4,
|
|
"reward_funcs": ["rewards.rand_reward_func"],
|
|
}
|
|
),
|
|
# Must be evenly divisible by num_generations
|
|
"micro_batch_size": 4,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="ipo_cfg")
|
|
def fixture_ipo_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.IPO,
|
|
"dpo_label_smoothing": 0,
|
|
"beta": 0.1,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="simpo_cfg")
|
|
def fixture_simpo_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": RLType.SIMPO,
|
|
"rl_beta": 0.2,
|
|
"cpo_alpha": 0.9,
|
|
"simpo_gamma": 0.4,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="sft_cfg")
|
|
def fixture_sft_cfg(base_cfg):
|
|
cfg = base_cfg.copy()
|
|
cfg.update(
|
|
{
|
|
"rl": None,
|
|
"sample_packing": False,
|
|
"eval_sample_packing": False,
|
|
"flash_attention": False,
|
|
}
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="rm_cfg")
|
|
def fixture_rm_cfg(sft_cfg):
|
|
cfg = sft_cfg.copy()
|
|
cfg.update(
|
|
DictDefault(
|
|
{
|
|
"reward_model": True,
|
|
"datasets": [
|
|
{
|
|
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
|
"type": "bradley_terry.chat_template",
|
|
"split": "train[:1%]",
|
|
}
|
|
],
|
|
}
|
|
)
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="prm_cfg")
|
|
def fixture_prm_cfg(sft_cfg):
|
|
cfg = sft_cfg.copy()
|
|
cfg.update(
|
|
DictDefault(
|
|
{
|
|
"process_reward_model": True,
|
|
"datasets": [
|
|
{
|
|
"path": "trl-lib/math_shepherd",
|
|
"type": "stepwise_supervised",
|
|
"split": "train[:1%]",
|
|
}
|
|
],
|
|
}
|
|
)
|
|
)
|
|
return cfg
|
|
|
|
|
|
@pytest.fixture(name="tokenizer")
|
|
def fixture_tokenizer(base_cfg):
|
|
return load_tokenizer(base_cfg)
|
|
|
|
|
|
@pytest.fixture(name="model")
|
|
def fixture_model(base_cfg, tokenizer):
|
|
model, _ = ModelLoader(base_cfg, tokenizer).load()
|
|
return model
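
# NOTE: the tokenizer/model fixtures load the small Hub checkpoint named in
# base_cfg ("HuggingFaceTB/SmolLM2-135M-Instruct"), so these tests need the
# checkpoint cached locally or network access.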


class TestHFRLTrainerBuilder:
    """
    TestCase class for RLHF trainer builders
    """

    def _test_common_training_arguments(self, training_arguments, rl: RLType):
        """Helper to check the arguments shared across all RLHF variants"""
        # Basic training settings
        if rl == RLType.GRPO:
            # grpo_cfg uses a different micro_batch_size than the other fixtures
            assert training_arguments.per_device_train_batch_size == 4
        else:
            assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100

        # Optimizer settings
        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0

        # LR scheduler settings
        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1
        assert training_arguments.cosine_constant_lr_ratio == 0.2

        # Other settings
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True

        # TODO(wing): restore once trl releases 0.22.0
        # assert training_arguments.gradient_checkpointing is True

    def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
        # DPO specific
        assert training_arguments.beta == 0.1
        assert hasattr(training_arguments, "use_weighting")
        assert training_arguments.use_weighting is True
        assert training_arguments.label_smoothing == 0.1

    def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
        # ORPO specific
        assert training_arguments.beta == 0.1  # maps from orpo_alpha
        assert training_arguments.max_prompt_length == 512

    def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
        # KTO specific
        assert training_arguments.desirable_weight == 1.0
        assert training_arguments.undesirable_weight == 1.0
        assert training_arguments.max_prompt_length == 512

    def _write_rewards_file(self, rewards_dir: Path):
        """
        Write a reward function to a local tmp path so it can be imported
        when the trainer is built.
        """
        # Create rewards.py in a directory we can import from
        rewards_dir.mkdir()
        rewards_file = rewards_dir / "rewards.py"
        rewards_file.write_text(
            """import random

def rand_reward_func(prompts, completions) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]
"""
        )
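
    # GRPO reward functions are configured as dotted strings (grpo_cfg uses
    # "rewards.rand_reward_func") and resolved by importing the module, which
    # is why the GRPO tests below add the tmp rewards dir to sys.path.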

    def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
        rewards_dir = tmp_path / "rewards_test"
        self._write_rewards_file(rewards_dir)

        # Add the directory to the Python path so we can import the module
        sys.path.insert(0, str(rewards_dir))

        try:
            builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
            training_arguments, _ = builder._build_training_arguments(100)

            self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
            # GRPO specific
            assert training_arguments.beta == 0.001
            assert training_arguments.max_completion_length == 256
            assert training_arguments.use_vllm is False
            # assert training_arguments.vllm_device == "auto"
            # assert training_arguments.vllm_gpu_memory_utilization == 0.15
            assert training_arguments.num_generations == 4

            # Build the trainer to verify reward_funcs
            trainer = builder.build(100)

            # Verify reward functions are properly loaded
            assert len(trainer.reward_funcs) == 1
            assert trainer.reward_funcs[0].__module__ == "rewards"
            assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
        finally:
            # Remove the imported module's directory from the path
            if str(rewards_dir) in sys.path:
                sys.path.remove(str(rewards_dir))

    def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
        # IPO specific
        assert training_arguments.beta == 0.1
        assert training_arguments.loss_type == "ipo"
        assert training_arguments.label_smoothing == 0

    def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
        builder = HFRLTrainerBuilder(simpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=simpo_cfg.rl)
        # SimPO specific
        assert training_arguments.beta == 0.2
        assert training_arguments.cpo_alpha == 0.9
        assert training_arguments.simpo_gamma == 0.4
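
    # The custom "muon" optimizer is exposed via the trainer's
    # `optimizer_cls_and_kwargs` pair (class plus kwargs, instantiated by
    # `create_optimizer`), so the test below checks both the stored pair and
    # the optimizer instance it produces.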

    @pytest.mark.parametrize(
        ("cfg_string", "dataset_name"),
        [
            (
                "dpo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            (
                "ipo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            (
                "grpo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            ("orpo_cfg", None),  # no dataset fixture for orpo; use a smaller split
            ("kto_cfg", None),  # no dataset fixture for kto
            # (
            #     "simpo_cfg",
            #     "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            # ),
        ],
    )
    def test_custom_optimizer_cls_and_kwargs(
        self,
        request,
        cfg_string,
        dataset_name,
        tmp_path,
        model,
        tokenizer,
    ):
        cfg = request.getfixturevalue(cfg_string)

        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
        cfg["optimizer"] = "muon"

        if cfg_string in ["dpo_cfg", "ipo_cfg", "grpo_cfg", "simpo_cfg"]:
            cfg["datasets"] = [DictDefault(ALPACA_MESSAGES_CONFIG_REVISION)]
        elif cfg_string == "kto_cfg":
            cfg["datasets"] = [
                DictDefault(
                    {
                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
                        "type": "llama3.ultra",
                        "split": "train[:1%]",
                    }
                )
            ]
        elif cfg_string == "orpo_cfg":
            cfg["datasets"] = [
                DictDefault(
                    {
                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
                        "type": "chat_template.argilla",
                        "split": "train[:1%]",
                    }
                )
            ]
        else:
            raise ValueError(f"Unhandled cfg_string: {cfg_string}")
        cfg["dataset_num_proc"] = 4

        if cfg_string == "grpo_cfg":
            rewards_dir = tmp_path / "rewards_test"
            self._write_rewards_file(rewards_dir)

            # Add the directory to the Python path so we can import the module
            sys.path.insert(0, str(rewards_dir))

        try:
            # Only mock dataset loading when a dataset fixture is provided
            if dataset_name is not None:
                with patch(
                    "axolotl.utils.data.rl.load_dataset_with_config"
                ) as mock_load_dataset:
                    mock_load_dataset.return_value = request.getfixturevalue(
                        dataset_name
                    )
                    train_dataset, eval_dataset = prepare_preference_datasets(
                        cfg, tokenizer
                    )
            else:
                # Load the actual datasets for orpo_cfg and kto_cfg
                train_dataset, eval_dataset = prepare_preference_datasets(
                    cfg, tokenizer
                )

            builder.train_dataset = train_dataset
            builder.eval_dataset = eval_dataset

            trainer = builder.build(100)

            assert trainer.optimizer_cls_and_kwargs is not None

            from axolotl.contribs.mit.muon import (
                Muon,
                MuonOptimizerFactory,
            )

            optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
            assert optimizer_cls is MuonOptimizerFactory
            assert optimizer_kwargs["lr"] == 0.00005
            assert optimizer_kwargs["weight_decay"] == 0.01
            assert optimizer_kwargs["betas"] == (0.998, 0.9)  # (adam_beta1, adam_beta2)
            assert optimizer_kwargs["eps"] == 0.00001

            # Ensure the optimizer is created with the correct class
            optim = trainer.create_optimizer()
            assert isinstance(optim, Muon)

        finally:
            # Remove the imported module's directory from the path
            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
                sys.path.remove(str(rewards_dir))


class TestHFCausalTrainerBuilder:
    """
    TestCase class for the SFT trainer builder
    """

    def test_training_arguments(self, sft_cfg, model, tokenizer):
        builder = HFCausalTrainerBuilder(sft_cfg, model, tokenizer)
        trainer = builder.build(100)
        training_arguments = trainer.args

        # Test common arguments
        assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100

        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0

        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1

        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True
        assert training_arguments.gradient_checkpointing is False

        # SFT specific
        assert training_arguments.sample_packing is False
        assert training_arguments.eval_sample_packing is False

    @pytest.mark.parametrize(
        "cfg_string",
        [
            "sft_cfg",
            "rm_cfg",
            "prm_cfg",
        ],
    )
    def test_custom_optimizer_cls_and_kwargs(
        self, request, cfg_string, model, tokenizer
    ):
        cfg = request.getfixturevalue(cfg_string)
        builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
        cfg["optimizer"] = "muon"

        # Datasets must be loaded for the reward model and process reward
        # model trainers
        if cfg_string in ["rm_cfg", "prm_cfg"]:
            dataset_meta = load_datasets(cfg=cfg)

            builder.train_dataset = dataset_meta.train_dataset
            builder.eval_dataset = dataset_meta.eval_dataset

        trainer = builder.build(100)

        assert trainer.optimizer_cls_and_kwargs is not None

        from axolotl.contribs.mit.muon import (
            Muon,
            MuonOptimizerFactory,
        )

        optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
        assert optimizer_cls is MuonOptimizerFactory
        assert optimizer_kwargs["lr"] == 0.00005
        assert optimizer_kwargs["weight_decay"] == 0.01
        assert optimizer_kwargs["betas"] == (0.998, 0.9)  # (adam_beta1, adam_beta2)
        assert optimizer_kwargs["eps"] == 0.00001

        # Ensure the optimizer is created with the correct class
        optim = trainer.create_optimizer()
        assert isinstance(optim, Muon)


class TestTrainerClsPlugin:
    """
    TestCase class for the trainer builder with a plugin
    """

    def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
        """
        Test that the trainer cls is not None with a plugin enabled

        Fixes #2693
        """
        cfg = kto_cfg.copy()
        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

        # An AttributeError is expected because we don't pass regular model
        # configs to the RL trainer builder. If it instead raises
        # `TypeError: None is not a callable object`, trainer_cls may be None.
        try:
            builder = HFRLTrainerBuilder(cfg, model, tokenizer)

            builder.build(100)
        except TypeError as e:
            # Error raised if trainer_cls is None
            assert "'tuple' object has no attribute 'config'" not in str(e)
        except Exception:
            # A different error occurred, so trainer_cls was passed to the builder
            pass