simplify load_tokenizer

This commit is contained in:
Aman Karmani
2023-08-13 01:33:38 +00:00
committed by Aman Gupta Karmani
parent 7b55fe6419
commit efb3b2c95e
3 changed files with 24 additions and 25 deletions

View File

@@ -177,9 +177,8 @@ def train(
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
# load the tokenizer first # load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
LOG.info(f"loading tokenizer... {tokenizer_config}") tokenizer = load_tokenizer(cfg)
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if ( if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference

View File

@@ -32,32 +32,27 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401 from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer( def load_tokenizer(cfg):
tokenizer_config,
tokenizer_type,
cfg,
):
tokenizer_kwargs = {} tokenizer_kwargs = {}
use_fast = True # this is the default use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None: if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None: if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224 # True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained( tokenizer_cls = AutoTokenizer
tokenizer_config, if cfg.tokenizer_type:
trust_remote_code=cfg.trust_remote_code or False, tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
use_fast=use_fast,
**tokenizer_kwargs, tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
) tokenizer = tokenizer_cls.from_pretrained(
else: tokenizer_config,
tokenizer = AutoTokenizer.from_pretrained( trust_remote_code=cfg.trust_remote_code or False,
tokenizer_config, use_fast=use_fast,
trust_remote_code=cfg.trust_remote_code or False, **tokenizer_kwargs,
use_fast=use_fast, )
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [ if tokenizer.__class__.__name__ in [
"LlamaTokenizer", "LlamaTokenizer",

View File

@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
""" """
def test_default_use_fast(self): def test_default_use_fast(self):
cfg = DictDefault({}) cfg = DictDefault(
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg) {
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__ assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self): def test_dont_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False, "tokenizer_use_fast": False,
} }
) )
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__