models.py -> loaders/ module refactor (#2680)

* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
This commit is contained in:
Dan Saunders
2025-05-23 15:51:11 -04:00
committed by GitHub
parent 8cde256db2
commit b5f1e53a0f
33 changed files with 2249 additions and 2039 deletions

View File

@@ -6,9 +6,9 @@ import unittest
import transformers
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from ..utils import with_temp_dir
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
@with_temp_dir
def test_mistral_multipack(self, temp_dir):
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
ModelLoader(cfg, tokenizer, inference=False).load()
assert (
"torch.jit"