models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor * refactor ModelLoader class * plugin manager changes * circular import fix * pytest * pytest * minor improvements * fix * minor changes * fix test * remove dead code * coderabbit comments * lint * fix * coderabbit suggestion I liked * more coderabbit * review comments, yak shaving * lint * updating in light of SP ctx manager changes * review comment * review comment 2
This commit is contained in:
@@ -6,9 +6,9 @@ import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
@with_temp_dir
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
assert (
|
||||
"torch.jit"
|
||||
|
||||
Reference in New Issue
Block a user