367 lines
12 KiB
Python
367 lines
12 KiB
Python
"""
Unit tests for axolotl.core.trainer_builder
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
|
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
|
from axolotl.utils.config import normalize_config
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.models import load_model, load_tokenizer
|
|
from axolotl.utils.schemas.enums import RLType
|
|
|
|
|
|
@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
    """
    Build the config shared by the SFT and RLHF fixtures.

    Settings are declared one group at a time and merged, in declaration
    order, into a single ``DictDefault`` which is run through
    ``normalize_config`` before being returned.
    """
    model_settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "sequence_len": 2048,
        "model_config_type": "llama",  # example type
    }
    training_settings = {
        "micro_batch_size": 2,
        "eval_batch_size": 2,
        "num_epochs": 1,
        "gradient_accumulation_steps": 1,
        "max_steps": 100,
        "val_set_size": 0,
    }
    optimizer_settings = {
        "optimizer": "adamw_torch_fused",
        "learning_rate": 0.00005,
        "weight_decay": 0.01,
        "adam_beta1": 0.998,
        "adam_beta2": 0.9,
        "adam_epsilon": 0.00001,
        "max_grad_norm": 1.0,
    }
    scheduler_settings = {
        "lr_scheduler": "cosine",
        "lr_scheduler_kwargs": {"foo": "bar"},
        "warmup_steps": 10,
        "warmup_ratio": None,
        "cosine_min_lr_ratio": 0.1,
        "cosine_constant_lr_ratio": 0.2,
    }
    checkpoint_settings = {
        "save_steps": 100,
        "output_dir": "./model-out",
        "save_safetensors": True,
        "save_total_limit": 4,
        "save_only_model": False,
    }
    hardware_settings = {
        "gradient_checkpointing": False,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "dataloader_num_workers": 1,
        "dataloader_pin_memory": True,
        "dataloader_prefetch_factor": 2,
        "sequence_parallel_degree": 1,
    }
    dtype_settings = {
        "fp16": False,
        "bf16": False,
        "tf32": False,
    }
    logging_settings = {
        "logging_steps": 10,
        "eval_steps": 50,
        "eval_strategy": "steps",
        "save_strategy": "steps",
        "include_tokens_per_second": True,
    }
    misc_settings = {
        "seed": 42,
        "remove_unused_columns": True,
        "ddp_timeout": 1800,
        "ddp_bucket_cap_mb": 25,
        "ddp_broadcast_buffers": False,
    }

    shared_cfg = DictDefault(
        {
            **model_settings,
            **training_settings,
            **optimizer_settings,
            **scheduler_settings,
            **checkpoint_settings,
            **hardware_settings,
            **dtype_settings,
            **logging_settings,
            **misc_settings,
        }
    )

    normalize_config(shared_cfg)
    return shared_cfg
|
|
|
|
|
|
@pytest.fixture(name="dpo_cfg")
def fixture_dpo_cfg(base_cfg):
    """Return a copy of ``base_cfg`` with the DPO settings layered on top."""
    dpo_overrides = {
        "rl": RLType.DPO,
        "dpo_use_weighting": True,
        "dpo_use_logits_to_keep": True,
        "dpo_label_smoothing": 0.1,
        "beta": 0.1,  # DPO beta
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(dpo_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="orpo_cfg")
def fixture_orpo_cfg(base_cfg):
    """Return a copy of ``base_cfg`` with the ORPO settings layered on top."""
    orpo_overrides = {
        "rl": RLType.ORPO,
        "orpo_alpha": 0.1,
        "max_prompt_len": 512,
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(orpo_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="kto_cfg")
def fixture_kto_cfg(base_cfg):
    """Return a copy of ``base_cfg`` with the KTO settings layered on top."""
    kto_overrides = {
        "rl": RLType.KTO,
        "kto_desirable_weight": 1.0,
        "kto_undesirable_weight": 1.0,
        "max_prompt_len": 512,
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(kto_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="grpo_cfg")
def fixture_grpo_cfg(base_cfg):
    """
    Return a copy of ``base_cfg`` with the GRPO settings layered on top.

    The GRPO options live in a nested ``trl`` section, and the copy
    overrides ``micro_batch_size`` with 4.
    """
    trl_section = DictDefault(
        {
            "beta": 0.001,
            "max_completion_length": 256,
            "use_vllm": False,  # run on CPU
            # "vllm_device": "auto",
            # "vllm_gpu_memory_utilization": 0.15,
            "num_generations": 4,
            "reward_funcs": ["rewards.rand_reward_func"],
        }
    )
    grpo_overrides = {
        "rl": RLType.GRPO,
        "trl": trl_section,
        # Must be evenly divisible by num_generations
        "micro_batch_size": 4,
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(grpo_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="ipo_cfg")
def fixture_ipo_cfg(base_cfg):
    """Return a copy of ``base_cfg`` with the IPO settings layered on top."""
    ipo_overrides = {
        "rl": RLType.IPO,
        "dpo_label_smoothing": 0.1,
        "beta": 0.1,
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(ipo_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="sft_cfg")
def fixture_sft_cfg(base_cfg):
    """
    Return a copy of ``base_cfg`` set up for SFT.

    ``rl`` is set to ``None`` and sample packing, eval sample packing and
    flash attention are all switched off.
    """
    sft_overrides = {
        "rl": None,
        "sample_packing": False,
        "eval_sample_packing": False,
        "flash_attention": False,
    }
    derived_cfg = base_cfg.copy()
    derived_cfg.update(sft_overrides)
    return derived_cfg
|
|
|
|
|
|
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(base_cfg):
    """Return the result of ``load_tokenizer`` for the base config."""
    loaded_tokenizer = load_tokenizer(base_cfg)
    return loaded_tokenizer
|
|
|
|
|
|
@pytest.fixture(name="model")
def fixture_model(base_cfg, tokenizer):
    """
    Return the first element of the pair produced by ``load_model``.

    The second element of the pair is discarded.
    """
    loaded_model, _unused = load_model(base_cfg, tokenizer)
    return loaded_model
|
|
|
|
|
|
class TestHFRLTrainerBuilder:
    """
    TestCase class for RLHF trainer builders
    """

    def _test_common_training_arguments(self, training_arguments, rl: str):
        """
        Helper to test common arguments across all variants.

        Args:
            training_arguments: the object returned by
                ``HFRLTrainerBuilder.build_training_arguments``.
            rl: RL type of the config under test. Only ``"grpo"`` changes an
                expectation: its fixture sets ``micro_batch_size`` to 4,
                every other fixture inherits 2 from ``base_cfg``.
        """
        # Basic training settings
        if rl == "grpo":
            # grpo_cfg's micro_batch_size is diff from others
            assert training_arguments.per_device_train_batch_size == 4
        else:
            assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100

        # Optimizer settings
        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0

        # LR scheduler settings
        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1
        assert training_arguments.cosine_constant_lr_ratio == 0.2

        # Other settings
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True
        assert training_arguments.gradient_checkpointing is False

    def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
        """DPO config: common arguments plus ``beta`` and ``use_weighting``."""
        builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
        # DPO specific
        assert training_arguments.beta == 0.1
        assert hasattr(training_arguments, "use_weighting")
        assert training_arguments.use_weighting is True

    def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
        """ORPO config: common arguments plus ``beta`` and ``max_prompt_length``."""
        builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
        # ORPO specific
        assert training_arguments.beta == 0.1  # maps from orpo_alpha
        assert training_arguments.max_prompt_length == 512

    def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
        """KTO config: common arguments plus the two weights and prompt length."""
        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
        # KTO specific
        assert training_arguments.desirable_weight == 1.0
        assert training_arguments.undesirable_weight == 1.0
        assert training_arguments.max_prompt_length == 512

    def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
        """
        GRPO config: common arguments, GRPO-specific arguments, and the
        reward function named in the config being loaded by the trainer.

        A throw-away ``rewards.py`` is written under ``tmp_path`` and its
        directory is put on ``sys.path`` so ``rewards.rand_reward_func`` can
        be imported. Both the path entry and the cached module are removed
        again once the test finishes.
        """

        def _write_rewards_file(rewards_dir: Path):
            """
            Writes reward function to local tmp path to be loaded on trainer building
            """
            # Create rewards.py in a directory we can import from
            rewards_dir.mkdir()
            rewards_file = rewards_dir / "rewards.py"
            # Spelled out line by line so the indentation of the generated
            # module does not depend on how this test file is indented.
            rewards_file.write_text(
                "import random\n"
                "def rand_reward_func(prompts, completions) -> list[float]:\n"
                "    return [random.uniform(0, 1) for _ in completions]\n"
            )

        rewards_dir = tmp_path / "rewards_test"
        _write_rewards_file(rewards_dir)

        # Remember whether a module named "rewards" was already imported, so
        # the cleanup below only evicts what this test introduced.
        rewards_preloaded = "rewards" in sys.modules

        # Add the directory to Python path so we can import the module
        sys.path.insert(0, str(rewards_dir))

        try:
            builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
            training_arguments = builder.build_training_arguments(100)

            self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
            # GRPO specific
            assert training_arguments.beta == 0.001
            assert training_arguments.max_completion_length == 256
            assert training_arguments.use_vllm is False
            # assert training_arguments.vllm_device == "auto"
            # assert training_arguments.vllm_gpu_memory_utilization == 0.15
            assert training_arguments.num_generations == 4

            # Test trainer creation to verify reward_funcs
            trainer = builder.build(100)

            # Verify reward functions are properly loaded
            assert len(trainer.reward_funcs) == 1
            assert trainer.reward_funcs[0].__module__ == "rewards"
            assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
        finally:
            # Remove the temporary directory from the import path
            if str(rewards_dir) in sys.path:
                sys.path.remove(str(rewards_dir))
            # Also drop the cached module: it was loaded from a per-test
            # temporary directory, so leaving it in sys.modules would hand a
            # stale module to any later import of "rewards".
            if not rewards_preloaded:
                sys.modules.pop("rewards", None)

    def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
        """IPO config: common arguments plus ``beta`` and ``loss_type``."""
        builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
        # IPO specific
        assert training_arguments.beta == 0.1
        assert training_arguments.loss_type == "ipo"
|
|
|
|
|
|
class TestHFCausalTrainerBuilder:
    """
    TestCase class for SFT trainer builder
    """

    def test_training_arguments(self, sft_cfg, model, tokenizer):
        """
        Build an SFT trainer and check the arguments it ends up with.

        Expectations are listed as (attribute, value) pairs and checked in
        order: first the ones compared by equality, then the boolean flags
        compared by identity.
        """
        sft_trainer = HFCausalTrainerBuilder(sft_cfg, model, tokenizer).build(100)
        args = sft_trainer.args

        # Compared with ==
        expected_equal = (
            # Test common arguments
            ("per_device_train_batch_size", 2),
            ("gradient_accumulation_steps", 1),
            ("max_steps", 100),
            ("learning_rate", 0.00005),
            ("weight_decay", 0.01),
            ("adam_beta1", 0.998),
            ("adam_beta2", 0.9),
            ("adam_epsilon", 0.00001),
            ("max_grad_norm", 1.0),
            ("lr_scheduler_type", "cosine"),
            ("warmup_steps", 10),
            ("cosine_min_lr_ratio", 0.1),
            ("dataloader_num_workers", 1),
        )
        for attr_name, wanted in expected_equal:
            assert getattr(args, attr_name) == wanted

        # Compared with `is`: must be the actual bool, not just truthy/falsy
        expected_flags = (
            ("dataloader_pin_memory", True),
            ("gradient_checkpointing", False),
            # SFT specific
            ("sample_packing", False),
            ("eval_sample_packing", False),
        )
        for attr_name, wanted in expected_flags:
            assert getattr(args, attr_name) is wanted
|
|
|
|
|
|
class TestTrainerClsPlugin:
    """
    TestCase class for trainer builder with plugin
    """

    def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
        """
        Check that the trainer class is not None when a plugin is configured.

        Fixes #2693
        """
        kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

        # An AttributeError is the expected outcome here, because regular
        # model configs are not passed to the RL trainer builder.
        # A `TypeError: None is not a callable object` instead would suggest
        # that trainer_cls is None.
        expected_message = r".*'tuple' object has no attribute 'config'.*"
        with pytest.raises(AttributeError, match=expected_message):
            rl_builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
            rl_builder.build(100)
|