be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]
* be more robust about checking embedding modules for lora finetunes * update dynamic error message
This commit is contained in:
@@ -151,6 +151,10 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
|
"""
|
||||||
|
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||||
|
information about the model architecture
|
||||||
|
"""
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
if not cfg.bf16 and not cfg.bfloat16:
|
if not cfg.bf16 and not cfg.bfloat16:
|
||||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||||
@@ -443,20 +447,6 @@ def validate_config(cfg):
|
|||||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.adapter
|
|
||||||
and cfg.tokens
|
|
||||||
and (
|
|
||||||
not cfg.lora_modules_to_save
|
|
||||||
or not all(
|
|
||||||
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||||
|
|||||||
12
src/axolotl/utils/lora_embeddings.py
Normal file
12
src/axolotl/utils/lora_embeddings.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
helpers for lora embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear_embedding_layers(model_type):
|
||||||
|
"""
|
||||||
|
returns the linear embedding layers needed for loras, dependent on the model arch
|
||||||
|
"""
|
||||||
|
if model_type == "phi-msft":
|
||||||
|
return ["embd", "lm_head.linear"]
|
||||||
|
return ["lm_head", "embed_tokens"]
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Tuple # noqa: F401
|
from typing import Any, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -28,12 +28,16 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
quant_config_exists = (
|
||||||
|
hasattr(model_config, "quantization_config")
|
||||||
|
and model_config.quantization_config
|
||||||
|
)
|
||||||
quant_config_method_is_gptq = (
|
quant_config_method_is_gptq = (
|
||||||
quant_config_exists
|
quant_config_exists
|
||||||
and "quant_method" in model_config.quantization_config
|
and "quant_method" in model_config.quantization_config
|
||||||
@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||||
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.tokens
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
|
||||||
|
raise ValueError(
|
||||||
|
f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
|
|||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
|
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
# check if new special token is not already in tokenizer and
|
# check if new special token is not already in tokenizer and
|
||||||
# is adapter training to make sure lora_modules_to_save is set
|
# is adapter training to make sure lora_modules_to_save is set
|
||||||
@@ -149,14 +168,15 @@ def load_tokenizer(cfg):
|
|||||||
and (
|
and (
|
||||||
not cfg.lora_modules_to_save
|
not cfg.lora_modules_to_save
|
||||||
or not all(
|
or not all(
|
||||||
x in cfg.lora_modules_to_save
|
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||||
for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
and (model_config.model_type in ["llama", "mistral", "mixtral"])
|
|
||||||
):
|
):
|
||||||
|
lora_modules_to_save = ", ".join(
|
||||||
|
[f"`{x}`" for x in lora_modules_to_save]
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
|
|||||||
@@ -10,12 +10,13 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class BaseValidation(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test the validation module
|
Base validation module to setup the log capture
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||||
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
|
|||||||
def inject_fixtures(self, caplog):
|
def inject_fixtures(self, caplog):
|
||||||
self._caplog = caplog
|
self._caplog = caplog
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationTest(BaseValidation):
|
||||||
|
"""
|
||||||
|
Test the validation module
|
||||||
|
"""
|
||||||
|
|
||||||
def test_load_4bit_deprecate(self):
|
def test_load_4bit_deprecate(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -687,16 +694,23 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_add_tokens_adapter(self):
|
|
||||||
|
class ValidationCheckModelConfig(BaseValidation):
|
||||||
|
"""
|
||||||
|
Test the validation for the config when the model config is available
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_llama_add_tokens_adapter(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
||||||
)
|
)
|
||||||
|
model_config = DictDefault({"model_type": "llama"})
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -709,9 +723,9 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -722,10 +736,48 @@ class ValidationTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
|
def test_phi2_add_tokens_adapter(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
||||||
|
)
|
||||||
|
model_config = DictDefault({"model_type": "phi-msft"})
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
|
||||||
|
):
|
||||||
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
|
||||||
|
):
|
||||||
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embd", "lm_head.linear"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
check_model_config(cfg, model_config)
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
class ValidationWandbTest(BaseValidation):
|
||||||
"""
|
"""
|
||||||
Validation test for wandb
|
Validation test for wandb
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user