Refactor func load_model to class ModelLoader (#1909)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
95
tests/e2e/test_load_model.py
Normal file
95
tests/e2e/test_load_model.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""Module for testing ModelLoader."""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="temp_dir")
|
||||||
|
def fixture_temp_dir():
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
yield temp_dir
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadModelUtils:
|
||||||
|
"""
|
||||||
|
Testing module testing ModelLoader.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
# load config
|
||||||
|
self.cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"tokenizer_config": "JackFram/llama-68m",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": False,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
ModelLoader(
|
||||||
|
cfg=self.cfg,
|
||||||
|
tokenizer="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
|
||||||
|
def test_convert_embedding_modules_dtype(
|
||||||
|
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||||
|
):
|
||||||
|
self.cfg.output_dir = temp_dir
|
||||||
|
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||||
|
self.model_loader.model, _ = load_model(
|
||||||
|
self.cfg,
|
||||||
|
self.model_loader.tokenizer,
|
||||||
|
inference=False,
|
||||||
|
reference_model=True,
|
||||||
|
)
|
||||||
|
self.model_loader.convert_embedding_modules_dtype(
|
||||||
|
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||||
|
)
|
||||||
|
for name, module in self.model_loader.model.named_modules():
|
||||||
|
if (
|
||||||
|
"norm" in name
|
||||||
|
or (before_kbit_train_or_finetune and name.endswith(".gate"))
|
||||||
|
or (
|
||||||
|
any(m in name for m in embedding_modules)
|
||||||
|
and hasattr(module, "weight")
|
||||||
|
)
|
||||||
|
):
|
||||||
|
for _, param in module.named_parameters():
|
||||||
|
assert param.dtype == dist_dtype
|
||||||
@@ -1,18 +1,64 @@
|
|||||||
"""Module for testing models utils file."""
|
"""Module for testing models utils file."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
|
||||||
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.utils.import_utils import is_torch_mps_available
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model
|
from axolotl.utils.models import ModelLoader, load_model
|
||||||
|
|
||||||
|
|
||||||
class ModelsUtilsTest(unittest.TestCase):
|
class TestModelsUtils:
|
||||||
"""Testing module for models utils."""
|
"""Testing module for models utils."""
|
||||||
|
|
||||||
|
def setup_method(self) -> None:
|
||||||
|
# load config
|
||||||
|
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"model_type": "LlamaForCausalLM",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"load_in_4bit": False,
|
||||||
|
"adapter": "lora",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sample_packing": True,
|
||||||
|
"device_map": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
|
||||||
|
spec=PreTrainedTokenizerBase
|
||||||
|
)
|
||||||
|
self.inference = False # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.reference_model = True # pylint: disable=attribute-defined-outside-init
|
||||||
|
|
||||||
|
# init ModelLoader
|
||||||
|
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
ModelLoader(
|
||||||
|
cfg=self.cfg,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
inference=self.inference,
|
||||||
|
reference_model=self.reference_model,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_set_device_map_config(self):
|
||||||
|
# check device_map
|
||||||
|
device_map = self.cfg.device_map
|
||||||
|
if is_torch_mps_available():
|
||||||
|
device_map = "mps"
|
||||||
|
self.model_loader.set_device_map_config()
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
assert "device_map" not in self.model_loader.model_kwargs
|
||||||
|
else:
|
||||||
|
assert device_map in self.model_loader.model_kwargs["device_map"]
|
||||||
|
|
||||||
|
# check torch_dtype
|
||||||
|
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
|
||||||
|
|
||||||
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -35,3 +81,38 @@ class ModelsUtilsTest(unittest.TestCase):
|
|||||||
"shifted-sparse attention does not currently support sample packing"
|
"shifted-sparse attention does not currently support sample packing"
|
||||||
in str(exc.value)
|
in str(exc.value)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
|
||||||
|
@pytest.mark.parametrize("load_in_8bit", [True, False])
|
||||||
|
@pytest.mark.parametrize("load_in_4bit", [True, False])
|
||||||
|
@pytest.mark.parametrize("gptq", [True, False])
|
||||||
|
def test_set_quantization_config(
|
||||||
|
self,
|
||||||
|
adapter,
|
||||||
|
load_in_8bit,
|
||||||
|
load_in_4bit,
|
||||||
|
gptq,
|
||||||
|
):
|
||||||
|
# init cfg as args
|
||||||
|
self.cfg.load_in_8bit = load_in_8bit
|
||||||
|
self.cfg.load_in_4bit = load_in_4bit
|
||||||
|
self.cfg.gptq = gptq
|
||||||
|
self.cfg.adapter = adapter
|
||||||
|
|
||||||
|
self.model_loader.set_quantization_config()
|
||||||
|
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
|
||||||
|
assert not (
|
||||||
|
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||||
|
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
||||||
|
)
|
||||||
|
elif load_in_8bit and self.cfg.adapter is not None:
|
||||||
|
assert self.model_loader.model_kwargs["load_in_8bit"]
|
||||||
|
elif load_in_4bit and self.cfg.adapter is not None:
|
||||||
|
assert self.model_loader.model_kwargs["load_in_4bit"]
|
||||||
|
|
||||||
|
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
||||||
|
self.cfg.adapter == "lora" and load_in_8bit
|
||||||
|
):
|
||||||
|
assert self.model_loader.model_kwargs.get(
|
||||||
|
"quantization_config", BitsAndBytesConfig
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user