diff --git a/scripts/finetune.py b/scripts/finetune.py
index 6c42b3061..e1b0b2e59 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -171,8 +171,9 @@ def train(
     validate_config(cfg)
 
     # load the tokenizer first
-    logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    logging.info(f"loading tokenizer... {tokenizer_config}")
+    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 952aaaa97..dc303bca6 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
+from transformers import (  # noqa: F401
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaConfig,
+)
 
 try:
     from transformers import LlamaForCausalLM
@@ -31,18 +36,18 @@ if TYPE_CHECKING:
 
 
 def load_tokenizer(
-    base_model_config,
+    tokenizer_config,
     tokenizer_type,
     cfg,
 ):
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
        )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
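
For reviewers: the net effect is that `load_tokenizer` no longer assumes the tokenizer lives alongside the base model config. A new `tokenizer_config` setting is preferred when present, with `base_model_config` as the fallback. Below is a minimal sketch of that fallback behavior; the repo ids and the `SimpleNamespace` stand-in for `cfg` are hypothetical placeholders, not part of this PR.

```python
# Minimal sketch of the fallback added in train(); the cfg values here are
# hypothetical placeholders, not values from the diff.
from types import SimpleNamespace

cfg = SimpleNamespace(
    base_model_config="org/base-model",  # hypothetical repo id
    tokenizer_config=None,               # unset, so we fall back
    tokenizer_type="LlamaTokenizer",
)

# Mirrors the new line in finetune.py: prefer an explicit tokenizer_config,
# otherwise reuse the base model's config path.
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
print(tokenizer_config)  # -> "org/base-model"

cfg.tokenizer_config = "org/custom-tokenizer"  # hypothetical override
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
print(tokenizer_config)  # -> "org/custom-tokenizer"
```

Note that when `tokenizer_type` is set, `load_tokenizer` resolves the class via `getattr(transformers, tokenizer_type)`, so any tokenizer class exported by transformers can be named in the config; otherwise `AutoTokenizer` infers the class from whatever `tokenizer_config` points at.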