Merge branch 'main' into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length

This commit is contained in:
mhenrichsen
2025-05-27 12:31:31 +02:00
committed by GitHub
75 changed files with 2850 additions and 2821 deletions

View File

@@ -1,13 +1,12 @@
"""
unit tests for axolotl.core.trainer_builder
"""
"""Unit tests for axolotl.core.trainer_builder"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="cfg")
@@ -49,7 +48,7 @@ def fixture_tokenizer(cfg):
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
return ModelLoader(cfg, tokenizer).load()
class TestHFRLTrainerBuilder:
@@ -65,3 +64,27 @@ class TestHFRLTrainerBuilder:
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
class TestTrainerClsPlugin:
"""
TestCase class for trainer builder with plugin
"""
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
"""
Test that the trainer cls is not none with plugin
Fixes #2693
"""
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
cfg.rl = RLType.KTO
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
with pytest.raises(
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder.build(100)

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,

View File

@@ -6,9 +6,9 @@ import unittest
import transformers
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from ..utils import with_temp_dir
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
@with_temp_dir
def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
assert (
"torch.jit"

View File

@@ -10,7 +10,7 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.attention.ring_attn import (
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
@@ -84,16 +84,16 @@ class TestRingAttention:
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
# Verify that RuntimeError is raised when no group is registered
with pytest.raises(
RuntimeError, match="register_ring_attn\\(\\) not yet called"
):
get_ring_attn_group()
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@@ -313,18 +313,21 @@ class TestApplySequenceParallelism:
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
def test_world_size_one(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
@@ -336,8 +339,11 @@ class TestApplySequenceParallelism:
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
@@ -359,8 +365,11 @@ class TestApplySequenceParallelism:
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
@@ -419,8 +428,13 @@ class TestApplySequenceParallelism:
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest
import torch
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
ModelLoader(
cfg=self.cfg,
tokenizer="",
inference=False,
reference_model=True,
)
)
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
self.model_loader.load()
self.model_loader._convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():

View File

@@ -9,11 +9,11 @@ from typing import Optional
import pytest
from pydantic import ValidationError
from axolotl.loaders.utils import check_model_config
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
test_cfg = DictDefault(
{
"s2_attention": True,
"sample_packing": True,
}
| minimal_cfg
)
with pytest.raises(
ValidationError,
match=r".*shifted-sparse attention does not currently support sample packing*",
):
validate_config(test_cfg)
class TestTorchCompileValidation(BaseValidation):
"""

View File

@@ -1,7 +1,8 @@
"""
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
`deduplicate_and_log_datasets` function.
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
@@ -11,20 +12,19 @@ from unittest.mock import patch
import pytest
from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
from tests.hf_offline_utils import enable_hf_offline
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
"""
Validates deduplication results and size consistency.
"""Validates deduplication results and size consistency.
Parameters:
- actual_dataset: Deduplicated dataset.
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
class TestDeduplicateIndividualFunctions(unittest.TestCase):
"""
test class for deduplication function in data utils
"""
"""Test class for deduplication function in data utils"""
def setUp(self):
# Sample data with duplicates
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
class TestWrongCollisions(unittest.TestCase):
"""Creating mock datasets for testing wrong collisions"""
"""Creating mock datasets for testing wrong collisions."""
def setUp(self):
self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}

View File

@@ -1,18 +1,18 @@
"""Module for testing models utils file."""
"""Module for `axolotl.loaders`."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model
class TestModelsUtils:
"""Testing module for models utils."""
"""Testing module for `axolotl.loaders`."""
def setup_method(self) -> None:
# load config
@@ -50,7 +50,8 @@ class TestModelsUtils:
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
self.model_loader.set_device_map_config()
# pylint: disable=protected-access
self.model_loader._set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
else:
@@ -59,29 +60,6 @@ class TestModelsUtils:
# check torch_dtype
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
cfg = DictDefault(
{
"s2_attention": True,
"sample_packing": True,
"base_model": "",
"model_type": "AutoModelForCausalLM",
}
)
# Mock out call to HF hub
with patch(
"axolotl.utils.models.load_model_config"
) as mocked_load_model_config:
mocked_load_model_config.return_value = {}
with pytest.raises(ValueError) as exc:
# Should error before hitting tokenizer, so we pass in an empty str
load_model(cfg, tokenizer="") # type: ignore
assert (
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
)
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
@pytest.mark.parametrize("load_in_8bit", [True, False])
@pytest.mark.parametrize("load_in_4bit", [True, False])
@@ -99,7 +77,8 @@ class TestModelsUtils:
self.cfg.gptq = gptq
self.cfg.adapter = adapter
self.model_loader.set_quantization_config()
# pylint: disable=protected-access
self.model_loader._set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
hasattr(self.model_loader.model_kwargs, "load_in_8bit")

View File

@@ -2,9 +2,9 @@
tests for loading loras
"""
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# pylint: disable=duplicate-code
minimal_config = DictDefault(
@@ -46,7 +46,7 @@ class TestLoRALoad:
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
ModelLoader(cfg, tokenizer).load()
def test_load_lora_weights_empty_dropout(self):
cfg = DictDefault(
@@ -67,4 +67,4 @@ class TestLoRALoad:
normalize_config(cfg)
assert cfg.lora_dropout == 0.0
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
ModelLoader(cfg, tokenizer).load()

View File

@@ -6,8 +6,8 @@ import unittest
import pytest
from axolotl.loaders import load_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
from tests.hf_offline_utils import enable_hf_offline