fix for protected model_ namespace w pydantic (#1345)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import check_model_config
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_cfg")
|
||||
def fixture_cfg():
|
||||
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
|
||||
|
||||
assert new_cfg.learning_rate == 0.00005
|
||||
|
||||
def test_model_config_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_config": {"model_type": "mistral"},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
|
||||
|
||||
def test_model_type_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||
|
||||
def test_model_revision_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_revision": "main",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.revision_of_model == "main"
|
||||
|
||||
def test_qlora(self, minimal_cfg):
|
||||
base_cfg = (
|
||||
DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user