simplify load_tokenizer
This commit is contained in:
committed by
Aman Gupta Karmani
parent
7b55fe6419
commit
efb3b2c95e
@@ -177,9 +177,8 @@ def train(
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
# load the tokenizer first
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
if (
|
||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||
|
||||
@@ -32,32 +32,27 @@ if TYPE_CHECKING:
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
def load_tokenizer(cfg):
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
tokenizer_cls = AutoTokenizer
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if tokenizer.__class__.__name__ in [
|
||||
"LlamaTokenizer",
|
||||
|
||||
@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_default_use_fast(self):
|
||||
cfg = DictDefault({})
|
||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert "Fast" in tokenizer.__class__.__name__
|
||||
|
||||
def test_dont_use_fast(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"tokenizer_use_fast": False,
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert "Fast" not in tokenizer.__class__.__name__
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user