be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]
* be more robust about checking embedding modules for lora finetunes
* update dynamic error message
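The new tests pin down the behaviour this commit is after: when a LoRA/QLoRA run adds new tokens, the embedding and output modules for the specific architecture (e.g. embed_tokens/lm_head for llama, embd/lm_head.linear for phi-msft) must be listed in lora_modules_to_save, otherwise a ValueError is raised. Below is a minimal sketch of such a check; the helper name check_lora_modules_to_save and the lookup table are illustrative assumptions inferred from the tests, not the repository's actual check_model_config implementation.

# Illustrative sketch only -- not the axolotl implementation. It mirrors the
# behaviour the tests in this diff assert for check_model_config(cfg, model_config).

# Assumed per-architecture embedding/output module names (taken from the tests below).
EMBEDDING_MODULES_BY_MODEL_TYPE = {
    "llama": ["embed_tokens", "lm_head"],
    "phi-msft": ["embd", "lm_head.linear"],
}


def check_lora_modules_to_save(cfg, model_config):
    """Raise if new tokens are added but the adapter will not save the resized embeddings."""
    if not (cfg.get("adapter") and cfg.get("tokens")):
        return  # nothing to check without an adapter and added tokens

    required = EMBEDDING_MODULES_BY_MODEL_TYPE.get(model_config.get("model_type"), [])
    saved = cfg.get("lora_modules_to_save") or []
    if not all(module in saved for module in required):
        raise ValueError(
            "`lora_modules_to_save` not properly set when adding new tokens. "
            f"Please include {required} in `lora_modules_to_save`."
        )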
@@ -10,12 +10,13 @@ from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import check_model_config
 from axolotl.utils.wandb_ import setup_wandb_env_vars


-class ValidationTest(unittest.TestCase):
+class BaseValidation(unittest.TestCase):
     """
-    Test the validation module
+    Base validation module to setup the log capture
     """

     _caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
     def inject_fixtures(self, caplog):
         self._caplog = caplog

+
+class ValidationTest(BaseValidation):
+    """
+    Test the validation module
+    """
+
     def test_load_4bit_deprecate(self):
         cfg = DictDefault(
             {
@@ -687,16 +694,23 @@ class ValidationTest(unittest.TestCase):

         validate_config(cfg)

-    def test_add_tokens_adapter(self):
+
+class ValidationCheckModelConfig(BaseValidation):
+    """
+    Test the validation for the config when the model config is available
+    """
+
+    def test_llama_add_tokens_adapter(self):
         cfg = DictDefault(
             {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
         )
+        model_config = DictDefault({"model_type": "llama"})

         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)

         cfg = DictDefault(
             {
@@ -709,9 +723,9 @@ class ValidationTest(unittest.TestCase):

         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)

         cfg = DictDefault(
             {
@@ -722,10 +736,48 @@ class ValidationTest(unittest.TestCase):
             }
         )

-        validate_config(cfg)
+        check_model_config(cfg, model_config)
+
+    def test_phi2_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+        model_config = DictDefault({"model_type": "phi-msft"})
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embd", "lm_head.linear"],
+            }
+        )
+
+        check_model_config(cfg, model_config)


-class ValidationWandbTest(ValidationTest):
+class ValidationWandbTest(BaseValidation):
     """
     Validation test for wandb
     """