fix for protected model_ namespace w pydantic (#1345)

This commit is contained in:
Wing Lian
2024-02-28 15:07:49 -05:00
committed by GitHub
parent 3a5a2d2f34
commit 6b3b271925
5 changed files with 76 additions and 22 deletions

View File

@@ -3,6 +3,7 @@
import logging
import os
import warnings
from typing import Optional
import pytest
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
assert new_cfg.learning_rate == 0.00005
def test_model_config_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_config": {"model_type": "mistral"},
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
def test_model_type_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_type": "AutoModelForCausalLM",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.type_of_model == "AutoModelForCausalLM"
def test_model_revision_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_revision": "main",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.revision_of_model == "main"
def test_qlora(self, minimal_cfg):
base_cfg = (
DictDefault(