models.py -> loaders/ module refactor (#2680)

* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
Author: Dan Saunders
Date: 2025-05-23 15:51:11 -04:00
Committed by: GitHub
Parent: 8cde256db2
Commit: b5f1e53a0f
33 changed files with 2249 additions and 2039 deletions
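
For readers skimming the diffs below: the net API change is that model loading moves from the `load_model` function in `axolotl.utils.models` to the `ModelLoader` class in the new `axolotl.loaders` package (with `load_tokenizer` and `load_processor` now imported from `axolotl.loaders` as well). A minimal usage sketch of the new entry point follows; the `base_model` value and config keys are illustrative placeholders, not taken from this PR, and the return shape of `.load()` is used here exactly as the updated test fixtures use it.

# Old call sites (removed in this PR):
#   from axolotl.utils.models import load_model, load_tokenizer
#   load_model(cfg, tokenizer, inference=False)
#
# New call sites (as used throughout the updated tests):
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

# Placeholder config for illustration only; real configs carry many more keys.
cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # hypothetical model id
        "sequence_len": 1024,
        "micro_batch_size": 1,
    }
)
cfg = validate_config(cfg)
normalize_config(cfg)

tokenizer = load_tokenizer(cfg)
# ModelLoader takes the same (cfg, tokenizer, inference=..., reference_model=...)
# arguments the old load_model accepted; .load() performs the actual load and its
# result is what the tests use as the model fixture.
model = ModelLoader(cfg, tokenizer, inference=False).load()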

View File

@@ -1,13 +1,11 @@
"""
unit tests for axolotl.core.trainer_builder
"""
"""Unit tests for axolotl.core.trainer_builder"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@@ -50,7 +48,7 @@ def fixture_tokenizer(cfg):
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
return ModelLoader(cfg, tokenizer).load()
class TestHFRLTrainerBuilder:

View File

@@ -6,9 +6,9 @@ import unittest
 import transformers
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 from ..utils import with_temp_dir
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()
     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()
         assert (
             "torch.jit"

View File

@@ -6,8 +6,8 @@ import tempfile
 import pytest
 import torch
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
 @pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
             ModelLoader(
                 cfg=self.cfg,
                 tokenizer="",
+                inference=False,
+                reference_model=True,
             )
         )
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
     ):
         self.cfg.output_dir = temp_dir
         self.model_loader.tokenizer = load_tokenizer(self.cfg)  # pylint: disable=all
-        self.model_loader.model, _ = load_model(
-            self.cfg,
-            self.model_loader.tokenizer,
-            inference=False,
-            reference_model=True,
-        )
-        self.model_loader.convert_embedding_modules_dtype(
+        self.model_loader.load()
+        self.model_loader._convert_embedding_modules_dtype(
             embedding_modules, dist_dtype, before_kbit_train_or_finetune
         )
         for name, module in self.model_loader.model.named_modules():

View File

@@ -9,11 +9,11 @@ from typing import Optional
 import pytest
 from pydantic import ValidationError
+from axolotl.loaders.utils import check_model_config
 from axolotl.utils import is_comet_available
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import check_model_config
 from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
 from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
             cfg, capabilities=capabilities, env_capabilities=env_capabilities
         )
+    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "s2_attention": True,
+                "sample_packing": True,
+            }
+            | minimal_cfg
+        )
+        with pytest.raises(
+            ValidationError,
+            match=r".*shifted-sparse attention does not currently support sample packing*",
+        ):
+            validate_config(test_cfg)
 class TestTorchCompileValidation(BaseValidation):
     """

View File

@@ -1,7 +1,8 @@
"""
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
`deduplicate_and_log_datasets` function.
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
@@ -11,20 +12,19 @@ from unittest.mock import patch
 import pytest
 from datasets import Dataset
+from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_processor, load_tokenizer
 from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
 from tests.hf_offline_utils import enable_hf_offline
 def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
-    """
-    Validates deduplication results and size consistency.
+    """Validates deduplication results and size consistency.
     Parameters:
     - actual_dataset: Deduplicated dataset.
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
 class TestDeduplicateIndividualFunctions(unittest.TestCase):
-    """
-    test class for deduplication function in data utils
-    """
+    """Test class for deduplication function in data utils"""
     def setUp(self):
         # Sample data with duplicates
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
         # pylint: disable=duplicate-code
         with (
             patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+            patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
             mock_load_dataset.side_effect = [
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
         # pylint: disable=duplicate-code
         with (
             patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+            patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
             mock_load_dataset.side_effect = [
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
 class TestWrongCollisions(unittest.TestCase):
-    """Creating mock datasets for testing wrong collisions"""
+    """Creating mock datasets for testing wrong collisions."""
     def setUp(self):
         self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}

View File

@@ -1,18 +1,18 @@
"""Module for testing models utils file."""
"""Module for `axolotl.loaders`."""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model
class TestModelsUtils:
"""Testing module for models utils."""
"""Testing module for `axolotl.loaders`."""
def setup_method(self) -> None:
# load config
@@ -50,7 +50,8 @@ class TestModelsUtils:
         device_map = self.cfg.device_map
         if is_torch_mps_available():
             device_map = "mps"
-        self.model_loader.set_device_map_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_device_map_config()
         if is_deepspeed_zero3_enabled():
             assert "device_map" not in self.model_loader.model_kwargs
         else:
@@ -59,29 +60,6 @@ class TestModelsUtils:
         # check torch_dtype
         assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
-    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
-        cfg = DictDefault(
-            {
-                "s2_attention": True,
-                "sample_packing": True,
-                "base_model": "",
-                "model_type": "AutoModelForCausalLM",
-            }
-        )
-        # Mock out call to HF hub
-        with patch(
-            "axolotl.utils.models.load_model_config"
-        ) as mocked_load_model_config:
-            mocked_load_model_config.return_value = {}
-            with pytest.raises(ValueError) as exc:
-                # Should error before hitting tokenizer, so we pass in an empty str
-                load_model(cfg, tokenizer="")  # type: ignore
-            assert (
-                "shifted-sparse attention does not currently support sample packing"
-                in str(exc.value)
-            )
     @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
     @pytest.mark.parametrize("load_in_8bit", [True, False])
     @pytest.mark.parametrize("load_in_4bit", [True, False])
@@ -99,7 +77,8 @@ class TestModelsUtils:
         self.cfg.gptq = gptq
         self.cfg.adapter = adapter
-        self.model_loader.set_quantization_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_quantization_config()
         if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
             assert not (
                 hasattr(self.model_loader.model_kwargs, "load_in_8bit")

View File

@@ -2,9 +2,9 @@
 tests for loading loras
 """
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 # pylint: disable=duplicate-code
 minimal_config = DictDefault(
@@ -46,7 +46,7 @@ class TestLoRALoad:
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()
     def test_load_lora_weights_empty_dropout(self):
         cfg = DictDefault(
@@ -67,4 +67,4 @@ class TestLoRALoad:
         normalize_config(cfg)
         assert cfg.lora_dropout == 0.0
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()

View File

@@ -6,8 +6,8 @@ import unittest
 import pytest
+from axolotl.loaders import load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_tokenizer
 from tests.hf_offline_utils import enable_hf_offline