simplify load_tokenizer

This commit is contained in:
Aman Karmani
2023-08-13 01:33:38 +00:00
committed by Aman Gupta Karmani
parent 7b55fe6419
commit efb3b2c95e
3 changed files with 24 additions and 25 deletions

View File

@@ -32,32 +32,27 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(
tokenizer_config,
tokenizer_type,
cfg,
):
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",