From 47d601fa2389a7f7a0dac0bd767e669c3a326cbe Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 25 Jun 2023 10:19:49 -0400
Subject: [PATCH] optionally define whether to use_fast tokenizer

---
 README.md                   |  2 ++
 src/axolotl/utils/models.py |  5 +++++
 tests/test_tokenizers.py    | 31 +++++++++++++++++++++++++++++++
 3 files changed, 38 insertions(+)
 create mode 100644 tests/test_tokenizers.py

diff --git a/README.md b/README.md
index 5fbac1a48..047d6aa34 100644
--- a/README.md
+++ b/README.md
@@ -302,6 +302,8 @@ model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
 # Trust remote code for untrusted source
 trust_remote_code:
+# use_fast option for tokenizer loading from_pretrained, default to True
+tokenizer_use_fast:
 
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2ae9a26aa..6d94cd674 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -34,15 +34,20 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    use_fast = True  # this is the default
+    if cfg.tokenizer_use_fast is not None:
+        use_fast = cfg.tokenizer_use_fast
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )
 
     logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
new file mode 100644
index 000000000..f2521e8e7
--- /dev/null
+++ b/tests/test_tokenizers.py
@@ -0,0 +1,31 @@
+"""
+Test cases for the tokenizer loading
+"""
+import unittest
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_tokenizer
+
+
+class TestTokenizers(unittest.TestCase):
+    """
+    test class for the load_tokenizer fn
+    """
+
+    def test_default_use_fast(self):
+        cfg = DictDefault({})
+        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        assert "Fast" in tokenizer.__class__.__name__
+
+    def test_dont_use_fast(self):
+        cfg = DictDefault(
+            {
+                "tokenizer_use_fast": False,
+            }
+        )
+        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        assert "Fast" not in tokenizer.__class__.__name__
+
+
+if __name__ == "__main__":
+    unittest.main()