models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor * refactor ModelLoader class * plugin manager changes * circular import fix * pytest * pytest * minor improvements * fix * minor changes * fix test * remove dead code * coderabbit comments * lint * fix * coderabbit suggestion I liked * more coderabbit * review comments, yak shaving * lint * updating in light of SP ctx manager changes * review comment * review comment 2
This commit is contained in:
@@ -6,9 +6,9 @@ import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
@with_temp_dir
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
assert (
|
||||
"torch.jit"
|
||||
|
||||
@@ -6,8 +6,8 @@ import tempfile
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="temp_dir")
|
||||
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
|
||||
):
|
||||
self.cfg.output_dir = temp_dir
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||
self.model_loader.model, _ = load_model(
|
||||
self.cfg,
|
||||
self.model_loader.tokenizer,
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
self.model_loader.convert_embedding_modules_dtype(
|
||||
self.model_loader.load()
|
||||
self.model_loader._convert_embedding_modules_dtype(
|
||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
)
|
||||
for name, module in self.model_loader.model.named_modules():
|
||||
|
||||
Reference in New Issue
Block a user