Feat: Warns to add to modules_to_save when adding tokens or switching special_tokens (#787)
* Feat: Auto add to modules_to_save when adding tokens * fix: swap to error instead of warning * feat: add check when special_tokens differ and add test
This commit is contained in:
@@ -448,6 +448,20 @@ def validate_config(cfg):
|
|||||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.tokens
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -136,6 +136,23 @@ def load_tokenizer(cfg):
|
|||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
|
# check if new special token is not already in tokenizer and
|
||||||
|
# is adapter training to make sure lora_modules_to_save is set
|
||||||
|
if (
|
||||||
|
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||||
|
and cfg.adapter
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save
|
||||||
|
for x in ["embed_tokens", "lm_head"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
|
|||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
def test_special_tokens_modules_to_save(self):
|
||||||
|
# setting special_tokens to new token
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"adapter": "lora",
|
||||||
|
"special_tokens": {"bos_token": "[INST]"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*Please set lora_modules_to_save*",
|
||||||
|
):
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# setting special_tokens but not changing from default
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"adapter": "lora",
|
||||||
|
"special_tokens": {"bos_token": "<s>"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# non-adapter setting special_tokens
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"special_tokens": {"bos_token": "[INST]"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_add_tokens_adapter(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embed_tokens"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
class ValidationWandbTest(ValidationTest):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user