models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor
* refactor ModelLoader class
* plugin manager changes
* circular import fix
* pytest
* pytest
* minor improvements
* fix
* minor changes
* fix test
* remove dead code
* coderabbit comments
* lint
* fix
* coderabbit suggestion I liked
* more coderabbit
* review comments, yak shaving
* lint
* updating in light of SP ctx manager changes
* review comment
* review comment 2
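For context, the deleted test file below exercised the step-wise ModelLoader API that this refactor reorganizes. Here is a minimal usage sketch of that API, assuming the constructor signature, the set_device_map_config() / set_quantization_config() methods, and the model_kwargs attribute keep the shape the tests relied on (the import path may differ after the move into the loaders/ module):

from unittest.mock import MagicMock

from transformers import PreTrainedTokenizerBase

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader  # may now live under the loaders/ module

cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "AutoTokenizer",
        "load_in_8bit": True,
        "adapter": "lora",
        "device_map": "auto",
    }
)
# the tests stub the tokenizer the same way, so no real model download is needed
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

loader = ModelLoader(
    cfg=cfg, tokenizer=tokenizer, inference=False, reference_model=True
)
loader.set_device_map_config()  # fills model_kwargs["device_map"] / ["torch_dtype"]
loader.set_quantization_config()  # fills 8-bit/4-bit or quantization_config kwargs
print(loader.model_kwargs)

Splitting kwargs construction into discrete set_*_config() steps is what lets the tests below assert on model_kwargs after each step instead of loading a real model.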
@@ -1,194 +0,0 @@
"""Module for testing models utils file."""

from unittest.mock import MagicMock, patch

import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model


class TestModelsUtils:
    """Testing module for models utils."""

    def setup_method(self) -> None:
        # load config
        self.cfg = DictDefault(  # pylint: disable=attribute-defined-outside-init
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "model_type": "AutoModelForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "load_in_8bit": True,
                "load_in_4bit": False,
                "adapter": "lora",
                "flash_attention": False,
                "sample_packing": True,
                "device_map": "auto",
            }
        )
        self.tokenizer = MagicMock(  # pylint: disable=attribute-defined-outside-init
            spec=PreTrainedTokenizerBase
        )
        self.inference = False  # pylint: disable=attribute-defined-outside-init
        self.reference_model = True  # pylint: disable=attribute-defined-outside-init

        # init ModelLoader
        self.model_loader = (  # pylint: disable=attribute-defined-outside-init
            ModelLoader(
                cfg=self.cfg,
                tokenizer=self.tokenizer,
                inference=self.inference,
                reference_model=self.reference_model,
            )
        )

    def test_set_device_map_config(self):
        # check device_map
        device_map = self.cfg.device_map
        if is_torch_mps_available():
            device_map = "mps"
        self.model_loader.set_device_map_config()
        if is_deepspeed_zero3_enabled():
            assert "device_map" not in self.model_loader.model_kwargs
        else:
            assert device_map in self.model_loader.model_kwargs["device_map"]

        # check torch_dtype
        assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        cfg = DictDefault(
            {
                "s2_attention": True,
                "sample_packing": True,
                "base_model": "",
                "model_type": "AutoModelForCausalLM",
            }
        )

        # Mock out call to HF hub
        with patch(
            "axolotl.utils.models.load_model_config"
        ) as mocked_load_model_config:
            mocked_load_model_config.return_value = {}
            with pytest.raises(ValueError) as exc:
                # Should error before hitting tokenizer, so we pass in an empty str
                load_model(cfg, tokenizer="")  # type: ignore
            assert (
                "shifted-sparse attention does not currently support sample packing"
                in str(exc.value)
            )

    @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
    @pytest.mark.parametrize("load_in_8bit", [True, False])
    @pytest.mark.parametrize("load_in_4bit", [True, False])
    @pytest.mark.parametrize("gptq", [True, False])
    def test_set_quantization_config(
        self,
        adapter,
        load_in_8bit,
        load_in_4bit,
        gptq,
    ):
        # init cfg as args
        self.cfg.load_in_8bit = load_in_8bit
        self.cfg.load_in_4bit = load_in_4bit
        self.cfg.gptq = gptq
        self.cfg.adapter = adapter

        self.model_loader.set_quantization_config()
        if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
            # model_kwargs is a dict, so key membership is the meaningful check
            # (hasattr() on a dict never sees its keys and made this assert vacuous)
            assert not (
                "load_in_8bit" in self.model_loader.model_kwargs
                and "load_in_4bit" in self.model_loader.model_kwargs
            )
        elif load_in_8bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_8bit"]
        elif load_in_4bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_4bit"]

        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
            self.cfg.adapter == "lora" and load_in_8bit
        ):
            assert self.model_loader.model_kwargs.get(
                "quantization_config", BitsAndBytesConfig
            )

    def test_message_property_mapping(self):
        """Test message property mapping configuration validation"""
        from axolotl.utils.schemas.datasets import SFTDataset

        # Test legacy fields are mapped correctly
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_field_content="content_field",
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test direct message_property_mappings works
        dataset = SFTDataset(
            path="test_path",
            message_property_mappings={
                "role": "custom_role",
                "content": "custom_content",
            },
        )
        assert dataset.message_property_mappings == {
            "role": "custom_role",
            "content": "custom_content",
        }

        # Test both legacy and new fields work when they match
        dataset = SFTDataset(
            path="test_path",
            message_field_role="same_role",
            message_property_mappings={"role": "same_role"},
        )
        assert dataset.message_property_mappings == {
            "role": "same_role",
            "content": "content",
        }

        # Test both legacy and new fields work when they don't overlap
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_property_mappings={"content": "content_field"},
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test no role or content provided
        dataset = SFTDataset(
            path="test_path",
        )
        assert dataset.message_property_mappings == {
            "role": "role",
            "content": "content",
        }

        # Test error when legacy and new fields conflict
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_role="legacy_role",
                message_property_mappings={"role": "different_role"},
            )
        assert "Conflicting message role fields" in str(exc_info.value)

        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_content="legacy_content",
                message_property_mappings={"content": "different_content"},
            )
        assert "Conflicting message content fields" in str(exc_info.value)
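The last test pins down the migration from the legacy message_field_role / message_field_content options to message_property_mappings. For clarity, here is a minimal sketch of the merge rules those assertions encode; this is not the actual SFTDataset validator, just its observable behavior, and merge_message_mappings is a hypothetical helper name:

from typing import Optional


def merge_message_mappings(
    message_field_role: Optional[str] = None,
    message_field_content: Optional[str] = None,
    message_property_mappings: Optional[dict] = None,
) -> dict:
    """Merge legacy message fields into a property mapping (hypothetical helper)."""
    mappings = dict(message_property_mappings or {})
    for key, legacy in (
        ("role", message_field_role),
        ("content", message_field_content),
    ):
        if legacy is None:
            continue
        if key in mappings and mappings[key] != legacy:
            # a matching duplicate is allowed; a true conflict raises, as the tests expect
            raise ValueError(f"Conflicting message {key} fields")
        mappings[key] = legacy
    # unset properties fall back to identity mappings, per the "no fields" case
    mappings.setdefault("role", "role")
    mappings.setdefault("content", "content")
    return mappings


# mirrors the legacy-field behavior asserted in the first case above
assert merge_message_mappings(
    message_field_role="role_field", message_field_content="content_field"
) == {"role": "role_field", "content": "content_field"}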