models.py -> loaders/ module refactor (#2680)

* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
This commit is contained in:
Dan Saunders
2025-05-23 15:51:11 -04:00
committed by GitHub
parent 8cde256db2
commit b5f1e53a0f
33 changed files with 2249 additions and 2039 deletions

View File

@@ -6,8 +6,8 @@ import tempfile
import pytest
import torch
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
ModelLoader(
cfg=self.cfg,
tokenizer="",
inference=False,
reference_model=True,
)
)
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
self.model_loader.load()
self.model_loader._convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():