models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor
* refactor ModelLoader class
* plugin manager changes
* circular import fix
* pytest
* pytest
* minor improvements
* fix
* minor changes
* fix test
* remove dead code
* coderabbit comments
* lint
* fix
* coderabbit suggestion I liked
* more coderabbit
* review comments, yak shaving
* lint
* updating in light of SP ctx manager changes
* review comment
* review comment 2
@@ -1,13 +1,11 @@
-"""
-unit tests for axolotl.core.trainer_builder
-"""
+"""Unit tests for axolotl.core.trainer_builder"""

 import pytest

 from axolotl.core.trainer_builder import HFRLTrainerBuilder
+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.schemas.enums import RLType


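Across the whole diff, test imports migrate from the retired `axolotl.utils.models` module to the new `axolotl.loaders` package (with `check_model_config` moving to `axolotl.loaders.utils`, as a later hunk shows). A minimal sketch of the package surface those import sites imply — the submodule layout below is an assumption; only the re-exported names are visible in this diff:

    # axolotl/loaders/__init__.py -- hypothetical layout; only the exported
    # names (ModelLoader, load_tokenizer, load_processor) appear in the diff.
    from .model import ModelLoader          # assumed submodule name
    from .processor import load_processor   # assumed submodule name
    from .tokenizer import load_tokenizer   # assumed submodule name

    __all__ = ["ModelLoader", "load_processor", "load_tokenizer"]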
@@ -50,7 +48,7 @@ def fixture_tokenizer(cfg):

 @pytest.fixture(name="model")
 def fixture_model(cfg, tokenizer):
-    return load_model(cfg, tokenizer)
+    return ModelLoader(cfg, tokenizer).load()


 class TestHFRLTrainerBuilder:
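This hunk shows the PR's central call-site change: the free function `load_model(cfg, tokenizer, ...)` becomes `ModelLoader(cfg, tokenizer, ...).load()`. The constructor keeps the old function's arguments (`cfg`, `tokenizer`, plus `inference` and `reference_model` keywords, as later hunks show). A rough skeleton consistent with these call sites — everything beyond the visible names and signatures is an assumption, including whether `load()` returns the model alone or a `(model, lora_config)` pair the way the old `load_model` did:

    class ModelLoader:
        """Hypothetical skeleton inferred from call sites in this diff."""

        def __init__(self, cfg, tokenizer, inference=False, reference_model=False):
            self.cfg = cfg
            self.tokenizer = tokenizer
            self.inference = inference
            self.reference_model = reference_model
            self.model_kwargs = {}  # filled in by the private _set_* helpers
            self.model = None

        def load(self):
            # Presumably orchestrates the private helpers the tests below call
            # directly (_set_device_map_config, _set_quantization_config,
            # _convert_embedding_modules_dtype), then instantiates the model.
            self.model = ...  # real loading logic lives in axolotl.loaders
            return self.model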
@@ -6,9 +6,9 @@ import unittest

 import transformers

+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer

 from ..utils import with_temp_dir

@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()

     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        ModelLoader(cfg, tokenizer, inference=False).load()

         assert (
             "torch.jit"
@@ -6,8 +6,8 @@ import tempfile
 import pytest
 import torch

+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model, load_tokenizer


 @pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
            ModelLoader(
                cfg=self.cfg,
                tokenizer="",
+               inference=False,
+               reference_model=True,
            )
        )

@@ -71,13 +73,8 @@ class TestLoadModelUtils:
    ):
        self.cfg.output_dir = temp_dir
        self.model_loader.tokenizer = load_tokenizer(self.cfg)  # pylint: disable=all
-       self.model_loader.model, _ = load_model(
-           self.cfg,
-           self.model_loader.tokenizer,
-           inference=False,
-           reference_model=True,
-       )
-       self.model_loader.convert_embedding_modules_dtype(
+       self.model_loader.load()
+       self.model_loader._convert_embedding_modules_dtype(
            embedding_modules, dist_dtype, before_kbit_train_or_finetune
        )
        for name, module in self.model_loader.model.named_modules():
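After `load()`, the test drives the now-private `_convert_embedding_modules_dtype` directly, replacing the old public `convert_embedding_modules_dtype` plus the manual `load_model` tuple unpacking. The helper's body is not shown in this diff; as an illustrative guess at what a method with this signature does (not axolotl's actual implementation), it casts embedding-like submodules to the target dtype:

    def _convert_embedding_modules_dtype(
        model, embedding_modules, dist_dtype, before_kbit_train_or_finetune
    ):
        # Illustrative only: cast modules whose names match the embedding
        # list to dist_dtype, e.g. keeping embeddings in a stable dtype
        # around k-bit training; the exact conditions are assumptions here.
        for name, module in model.named_modules():
            if any(name.endswith(target) for target in embedding_modules):
                module.to(dist_dtype)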
@@ -9,11 +9,11 @@ from typing import Optional
 import pytest
 from pydantic import ValidationError

+from axolotl.loaders.utils import check_model_config
 from axolotl.utils import is_comet_available
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import check_model_config
 from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
 from axolotl.utils.wandb_ import setup_wandb_env_vars

@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
             cfg, capabilities=capabilities, env_capabilities=env_capabilities
         )

+    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "s2_attention": True,
+                "sample_packing": True,
+            }
+            | minimal_cfg
+        )
+        with pytest.raises(
+            ValidationError,
+            match=r".*shifted-sparse attention does not currently support sample packing*",
+        ):
+            validate_config(test_cfg)
+

 class TestTorchCompileValidation(BaseValidation):
     """
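This test is relocated from the old models-utils test module (see the removed block further down): it now asserts a pydantic `ValidationError` from `validate_config` rather than a `ValueError` raised at model-load time, so no HF-hub mocking is needed. One detail worth noting is the `|` merge: with dict union, keys from the right operand win, so `minimal_cfg` values would override the literal here if they overlapped. A quick illustration with plain dicts (assuming `DictDefault` mirrors built-in union semantics):

    base = {"s2_attention": True, "sample_packing": True}
    minimal_cfg = {"base_model": "tiny-llama", "sample_packing": False}

    merged = base | minimal_cfg
    # The right-hand operand wins on conflicts:
    assert merged["sample_packing"] is False
    assert merged["s2_attention"] is True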
@@ -1,7 +1,8 @@
-"""
-Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
+"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
+`deduplicate_and_log_datasets` function.

-Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
+Additionally, this test suite includes tests for functions that indirectly call
+`deduplicate_and_log_datasets` during the execution of the preprocess command.
 """

 import hashlib
@@ -11,20 +12,19 @@ from unittest.mock import patch
 import pytest
 from datasets import Dataset

+from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_processor, load_tokenizer

 from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
 from tests.hf_offline_utils import enable_hf_offline


 def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
-    """
-    Validates deduplication results and size consistency.
+    """Validates deduplication results and size consistency.

     Parameters:
     - actual_dataset: Deduplicated dataset.
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):


 class TestDeduplicateIndividualFunctions(unittest.TestCase):
-    """
-    test class for deduplication function in data utils
-    """
+    """Test class for deduplication function in data utils"""

     def setUp(self):
         # Sample data with duplicates
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
        # pylint: disable=duplicate-code
        with (
            patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-           patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+           patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
        ):
            # Set up the mock to return different values on successive calls
            mock_load_dataset.side_effect = [
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
        # pylint: disable=duplicate-code
        with (
            patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
-           patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+           patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
        ):
            # Set up the mock to return different values on successive calls
            mock_load_dataset.side_effect = [
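The only change in these two hunks is the `patch` target: `axolotl.utils.models.load_tokenizer` becomes `axolotl.loaders.load_tokenizer`. That is load-bearing, not cosmetic: `unittest.mock.patch` swaps the attribute on the module object named in the string, so the target must be the module through which the code under test actually resolves the name at call time; patching the retired path would leave the real function untouched. A self-contained illustration of the rule:

    import os
    from unittest.mock import patch

    def where_am_i():
        # Looks up os.getcwd through the module object at call time,
        # so patch("os.getcwd") takes effect for this call.
        return os.getcwd()

    with patch("os.getcwd", return_value="/fake/dir"):
        assert where_am_i() == "/fake/dir"

    # By contrast, a name bound earlier via `from os import getcwd` would
    # not be replaced by patch("os.getcwd").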
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):


 class TestWrongCollisions(unittest.TestCase):
-    """Creating mock datasets for testing wrong collisions"""
+    """Creating mock datasets for testing wrong collisions."""

     def setUp(self):
         self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
@@ -1,18 +1,18 @@
-"""Module for testing models utils file."""
+"""Module for `axolotl.loaders`."""

-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock

 import pytest
 from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils.import_utils import is_torch_mps_available

+from axolotl.loaders import ModelLoader
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model


 class TestModelsUtils:
-    """Testing module for models utils."""
+    """Testing module for `axolotl.loaders`."""

     def setup_method(self) -> None:
         # load config
@@ -50,7 +50,8 @@ class TestModelsUtils:
         device_map = self.cfg.device_map
         if is_torch_mps_available():
             device_map = "mps"
-        self.model_loader.set_device_map_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_device_map_config()
         if is_deepspeed_zero3_enabled():
             assert "device_map" not in self.model_loader.model_kwargs
         else:
@@ -59,29 +60,6 @@ class TestModelsUtils:
         # check torch_dtype
         assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

-    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
-        cfg = DictDefault(
-            {
-                "s2_attention": True,
-                "sample_packing": True,
-                "base_model": "",
-                "model_type": "AutoModelForCausalLM",
-            }
-        )
-
-        # Mock out call to HF hub
-        with patch(
-            "axolotl.utils.models.load_model_config"
-        ) as mocked_load_model_config:
-            mocked_load_model_config.return_value = {}
-            with pytest.raises(ValueError) as exc:
-                # Should error before hitting tokenizer, so we pass in an empty str
-                load_model(cfg, tokenizer="")  # type: ignore
-            assert (
-                "shifted-sparse attention does not currently support sample packing"
-                in str(exc.value)
-            )
-
     @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
     @pytest.mark.parametrize("load_in_8bit", [True, False])
     @pytest.mark.parametrize("load_in_4bit", [True, False])
@@ -99,7 +77,8 @@ class TestModelsUtils:
         self.cfg.gptq = gptq
         self.cfg.adapter = adapter

-        self.model_loader.set_quantization_config()
+        # pylint: disable=protected-access
+        self.model_loader._set_quantization_config()
         if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
             assert not (
                 hasattr(self.model_loader.model_kwargs, "load_in_8bit")
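`set_device_map_config` and `set_quantization_config` are now private (`_set_device_map_config`, `_set_quantization_config`), hence the `# pylint: disable=protected-access` markers where the tests call them directly. The helper bodies are not part of this diff; as a hedged sketch of what a quantization step of this shape typically does (the flag-to-config mapping below is an assumption, not axolotl's logic), it translates config flags into a `BitsAndBytesConfig` entry in `model_kwargs`:

    from transformers import BitsAndBytesConfig

    def _set_quantization_config(cfg: dict, model_kwargs: dict) -> None:
        # Illustrative only: GPTQ checkpoints carry their own quantization
        # metadata, while bitsandbytes 4-/8-bit loading is requested via a
        # quantization_config entry in the model kwargs.
        if cfg.get("gptq"):
            return
        if cfg.get("load_in_4bit"):
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        elif cfg.get("load_in_8bit"):
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)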
@@ -2,9 +2,9 @@
 tests for loading loras
 """

+from axolotl.loaders import ModelLoader, load_tokenizer
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer

 # pylint: disable=duplicate-code
 minimal_config = DictDefault(
@@ -46,7 +46,7 @@ class TestLoRALoad:
         cfg = validate_config(cfg)
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()

     def test_load_lora_weights_empty_dropout(self):
         cfg = DictDefault(
@@ -67,4 +67,4 @@ class TestLoRALoad:
         normalize_config(cfg)
         assert cfg.lora_dropout == 0.0
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer)
+        ModelLoader(cfg, tokenizer).load()
@@ -6,8 +6,8 @@ import unittest

 import pytest

+from axolotl.loaders import load_tokenizer
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_tokenizer

 from tests.hf_offline_utils import enable_hf_offline
