* fix: transformers deprecate load_in_Xbit in model_kwargs
* fix: test to read from quantization_config kwarg
* fix: test
* fix: access
* fix: test weirdly entering incorrect config
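For context, a minimal sketch of the pattern these fixes move the tests toward, assuming a recent transformers release: the 8-bit/4-bit flags are no longer passed as bare model kwargs but wrapped in a `BitsAndBytesConfig` under `quantization_config`.

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Deprecated: passing the flags directly as from_pretrained / model kwargs
# model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", load_in_8bit=True)

# Current: wrap the flags in a BitsAndBytesConfig and pass it via quantization_config
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    quantization_config=bnb_config,
    device_map="auto",
)
```

(Running this needs `bitsandbytes` and a CUDA device; the tests below only check that `ModelLoader` places a `BitsAndBytesConfig` into `model_kwargs["quantization_config"]` rather than the raw flags.)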
"""Module for `axolotl.loaders`."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
|
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
from transformers.utils.import_utils import is_torch_mps_available
|
|
|
|
from axolotl.loaders import ModelLoader
|
|
from axolotl.utils.dict import DictDefault
|
|
from axolotl.utils.distributed import _get_parallel_config_kwargs
|
|
|
|
|
|
class TestModelsUtils:
|
|
"""Testing module for `axolotl.loaders`."""
|
|
|
|
def setup_method(self) -> None:
|
|
# load config
|
|
self.cfg = DictDefault(
|
|
{
|
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
"model_type": "AutoModelForCausalLM",
|
|
"tokenizer_type": "AutoTokenizer",
|
|
"load_in_8bit": True,
|
|
"load_in_4bit": False,
|
|
"adapter": "lora",
|
|
"flash_attention": False,
|
|
"sample_packing": True,
|
|
"device_map": "auto",
|
|
}
|
|
)
|
|
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
|
self.inference = False
|
|
self.reference_model = True
|
|
|
|
# init ModelLoader
|
|
self.model_loader = ModelLoader(
|
|
cfg=self.cfg,
|
|
tokenizer=self.tokenizer,
|
|
inference=self.inference,
|
|
reference_model=self.reference_model,
|
|
)
|
|
|
|
def test_set_device_map_config(self):
|
|
# check device_map
|
|
device_map = self.cfg.device_map
|
|
if is_torch_mps_available():
|
|
device_map = "mps"
|
|
|
|
self.model_loader._set_device_map_config()
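        # Under DeepSpeed ZeRO-3 the loader is expected to leave device placement to
        # DeepSpeed, so no device_map should end up in model_kwargs.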
        if is_deepspeed_zero3_enabled():
            assert "device_map" not in self.model_loader.model_kwargs
        else:
            assert device_map in self.model_loader.model_kwargs["device_map"]

        # check torch_dtype
        assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

    @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
    @pytest.mark.parametrize("load_in_8bit", [True, False])
    @pytest.mark.parametrize("load_in_4bit", [True, False])
    @pytest.mark.parametrize("gptq", [True, False])
    def test_set_quantization_config(
        self,
        adapter,
        load_in_8bit,
        load_in_4bit,
        gptq,
    ):
        # init cfg as args
        self.cfg.load_in_8bit = load_in_8bit
        self.cfg.load_in_4bit = load_in_4bit
        self.cfg.gptq = gptq
        self.cfg.adapter = adapter
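
        # transformers deprecated load_in_8bit / load_in_4bit as bare model kwargs;
        # the loader should express them via a quantization_config entry instead.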
        self.model_loader._set_quantization_config()
        if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
            assert not (
                "load_in_8bit" in self.model_loader.model_kwargs
                and "load_in_4bit" in self.model_loader.model_kwargs
            )

        if self.cfg.adapter == "qlora" and load_in_4bit:
            assert isinstance(
                self.model_loader.model_kwargs.get("quantization_config"),
                BitsAndBytesConfig,
            )

            assert (
                self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
                is True
            )
        if self.cfg.adapter == "lora" and load_in_8bit:
            assert isinstance(
                self.model_loader.model_kwargs.get("quantization_config"),
                BitsAndBytesConfig,
            )

            assert (
                self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
                is True
            )
    def test_message_property_mapping(self):
        """Test message property mapping configuration validation"""
        from axolotl.utils.schemas.datasets import SFTDataset

        # Test legacy fields are mapped correctly
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_field_content="content_field",
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test direct message_property_mapping works
        dataset = SFTDataset(
            path="test_path",
            message_property_mappings={
                "role": "custom_role",
                "content": "custom_content",
            },
        )
        assert dataset.message_property_mappings == {
            "role": "custom_role",
            "content": "custom_content",
        }

        # Test both legacy and new fields work when they match
        dataset = SFTDataset(
            path="test_path",
            message_field_role="same_role",
            message_property_mappings={"role": "same_role"},
        )
        assert dataset.message_property_mappings == {
            "role": "same_role",
            "content": "content",
        }

        # Test both legacy and new fields work when they don't overlap
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_property_mappings={"content": "content_field"},
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test no role or content provided
        dataset = SFTDataset(
            path="test_path",
        )
        assert dataset.message_property_mappings == {
            "role": "role",
            "content": "content",
        }

        # Test error when legacy and new fields conflict
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_role="legacy_role",
                message_property_mappings={"role": "different_role"},
            )
        assert "Conflicting message role fields" in str(exc_info.value)

        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_content="legacy_content",
                message_property_mappings={"content": "different_content"},
            )
        assert "Conflicting message content fields" in str(exc_info.value)

    @pytest.mark.parametrize(
        "world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
        [
            (16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
            (16, 1, 1, None, None, True, (0, 0, 16, 1)),
            (16, 2, 2, 2, None, True, (2, 2, 2, 2)),
            (16, 2, 2, None, 2, True, (2, 2, 2, 2)),
            (16, 1, 1, None, 2, True, (0, 0, 8, 2)),
            (2, 1, 1, None, None, True, (0, 0, 2, 1)),
        ],
    )
    def test_get_parallel_config_kwargs(
        self,
        world_size,
        tensor_parallel_size,
        context_parallel_size,
        dp_shard_size,
        dp_replicate_size,
        is_fsdp,
        expected,
    ):
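        # expected unpacks as (tp_size, cp_size, dp_shard_size, dp_replicate_size);
        # entries of 0 or 1 are simply not asserted below.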
        res = _get_parallel_config_kwargs(
            world_size,
            tensor_parallel_size,
            context_parallel_size,
            dp_shard_size,
            dp_replicate_size,
            is_fsdp,
        )

        if expected[0] > 1:
            assert res["tp_size"] == expected[0]
        if expected[1] > 1:
            assert res["cp_size"] == expected[1]
        if expected[2] > 1:
            assert res["dp_shard_size"] == expected[2]
        if expected[3] > 1:
            assert res["dp_replicate_size"] == expected[3]