simplify load_tokenizer
This commit is contained in:
committed by
Aman Gupta Karmani
parent
7b55fe6419
commit
efb3b2c95e
@@ -177,9 +177,8 @@ def train(
|
|||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
tokenizer = load_tokenizer(cfg)
|
||||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||||
|
|||||||
@@ -32,32 +32,27 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
def load_tokenizer(cfg):
|
||||||
tokenizer_config,
|
|
||||||
tokenizer_type,
|
|
||||||
cfg,
|
|
||||||
):
|
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
|
|
||||||
if cfg.tokenizer_use_fast is not None:
|
if cfg.tokenizer_use_fast is not None:
|
||||||
use_fast = cfg.tokenizer_use_fast
|
use_fast = cfg.tokenizer_use_fast
|
||||||
if cfg.tokenizer_legacy is not None:
|
if cfg.tokenizer_legacy is not None:
|
||||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||||
if tokenizer_type:
|
|
||||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
tokenizer_cls = AutoTokenizer
|
||||||
tokenizer_config,
|
if cfg.tokenizer_type:
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
use_fast=use_fast,
|
|
||||||
**tokenizer_kwargs,
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
)
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
else:
|
tokenizer_config,
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
tokenizer_config,
|
use_fast=use_fast,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
**tokenizer_kwargs,
|
||||||
use_fast=use_fast,
|
)
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in [
|
if tokenizer.__class__.__name__ in [
|
||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
|
|||||||
@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
cfg = DictDefault({})
|
cfg = DictDefault(
|
||||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" in tokenizer.__class__.__name__
|
assert "Fast" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
def test_dont_use_fast(self):
|
def test_dont_use_fast(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"tokenizer_use_fast": False,
|
"tokenizer_use_fast": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user