Merge pull request #240 from OpenAccess-AI-Collective/tokenizer-fast
optionally define whether to use_fast tokenizer
This commit is contained in:
@@ -302,6 +302,8 @@ model_type: AutoModelForCausalLM
|
|||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
# Trust remote code for untrusted source
|
# Trust remote code for untrusted source
|
||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
|
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||||
|
tokenizer_use_fast:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
|
|||||||
@@ -34,15 +34,20 @@ def load_tokenizer(
|
|||||||
tokenizer_type,
|
tokenizer_type,
|
||||||
cfg,
|
cfg,
|
||||||
):
|
):
|
||||||
|
use_fast = True # this is the default
|
||||||
|
if cfg.tokenizer_use_fast is not None:
|
||||||
|
use_fast = cfg.tokenizer_use_fast
|
||||||
if tokenizer_type:
|
if tokenizer_type:
|
||||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
use_fast=use_fast,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
|
|||||||
31
tests/test_tokenizers.py
Normal file
31
tests/test_tokenizers.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Test cases for the tokenizer loading
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizers(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
test class for the load_tokenizer fn
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_default_use_fast(self):
|
||||||
|
cfg = DictDefault({})
|
||||||
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
def test_dont_use_fast(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_use_fast": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||||
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user