Files
axolotl/tests/e2e/test_load_model.py
Dan Saunders 79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00

91 lines
2.9 KiB
Python

"""Module for testing ModelLoader."""
import shutil
import tempfile
import pytest
import torch
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
    """Yield the path of a freshly created temporary directory.

    The directory and everything written into it are deleted once the
    consuming test has finished with it.
    """
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # Teardown: remove the directory tree the test may have populated.
    shutil.rmtree(scratch_path)
class TestLoadModelUtils:
    """
    Tests for ``ModelLoader`` helpers, run against a small SmolLM2 model
    configured with a LoRA adapter.
    """

    def setup_method(self):
        """Build a fresh config and ``ModelLoader`` before every test."""
        # load config
        self.cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "tokenizer_type": "AutoTokenizer",
                "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "load_in_8bit": False,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.02,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "tensor_parallel_size": 1,
                "context_parallel_size": 1,
            }
        )
        # A placeholder tokenizer is passed here; each test assigns the real
        # one via load_tokenizer() before calling load().
        self.model_loader = ModelLoader(
            cfg=self.cfg,
            tokenizer="",
            inference=False,
            reference_model=True,
        )

    # Each parameter value must be a *list* of module-name substrings. A bare
    # string such as "embed_tokens" would make
    # ``any(m in name for m in embedding_modules)`` iterate over single
    # characters, so almost every module would match and the test would no
    # longer target the embedding layers.
    @pytest.mark.parametrize(
        "embedding_modules",
        [["embed_tokens"], ["lm_head"]],
        ids=["embed_tokens", "lm_head"],
    )
    @pytest.mark.parametrize(
        "dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
    )
    @pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
    def test_convert_embedding_modules_dtype(
        self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
    ):
        """Check that converted modules end up with parameters in ``dist_dtype``.

        After ``_convert_embedding_modules_dtype`` runs, every parameter of
        the following modules must have dtype ``dist_dtype``:
        - modules whose name contains "norm";
        - modules whose name ends with ".gate", when
          ``before_kbit_train_or_finetune`` is set;
        - modules with a ``weight`` whose name contains one of
          ``embedding_modules``.
        """
        self.cfg.output_dir = temp_dir
        self.model_loader.tokenizer = load_tokenizer(self.cfg)
        self.model_loader.load()
        self.model_loader._convert_embedding_modules_dtype(
            embedding_modules, dist_dtype, before_kbit_train_or_finetune
        )

        matched_embedding_modules = []
        for name, module in self.model_loader.model.named_modules():
            is_embedding = any(m in name for m in embedding_modules) and hasattr(
                module, "weight"
            )
            if is_embedding:
                matched_embedding_modules.append(name)
            if (
                "norm" in name
                or (before_kbit_train_or_finetune and name.endswith(".gate"))
                or is_embedding
            ):
                for _, param in module.named_parameters():
                    assert param.dtype == dist_dtype

        # Guard against a vacuous pass: the requested embedding module(s) must
        # actually exist in the loaded model and have been checked above.
        assert matched_embedding_modules, (
            f"no module with a weight matched {embedding_modules!r}"
        )