be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]

* be more robust about checking embedding modules for lora finetunes

* update dynamic error message
This commit is contained in:
Wing Lian
2024-01-09 22:58:54 -05:00
committed by GitHub
parent ead34c516a
commit 0f100800e3
4 changed files with 104 additions and 30 deletions

View File

@@ -10,12 +10,13 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase):
class BaseValidation(unittest.TestCase):
"""
Test the validation module
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
def inject_fixtures(self, caplog):
self._caplog = caplog
class ValidationTest(BaseValidation):
"""
Test the validation module
"""
def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
@@ -687,16 +694,23 @@ class ValidationTest(unittest.TestCase):
validate_config(cfg)
def test_add_tokens_adapter(self):
class ValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
"""
def test_llama_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "llama"})
with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)
cfg = DictDefault(
{
@@ -709,9 +723,9 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)
cfg = DictDefault(
{
@@ -722,10 +736,48 @@ class ValidationTest(unittest.TestCase):
}
)
validate_config(cfg)
check_model_config(cfg, model_config)
def test_phi2_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "phi-msft"})
with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embed_tokens", "lm_head"],
}
)
with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embd", "lm_head.linear"],
}
)
check_model_config(cfg, model_config)
class ValidationWandbTest(ValidationTest):
class ValidationWandbTest(BaseValidation):
"""
Validation test for wandb
"""