fix up tokenizer config, isort fix

This commit is contained in:
Wing Lian
2023-05-30 23:00:02 -04:00
parent 2520ecd6df
commit 39a208c2bc
2 changed files with 13 additions and 7 deletions

View File

@@ -171,8 +171,9 @@ def train(
validate_config(cfg)
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in(
["inference", "shard", "merge_lora"], kwargs

View File

@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401
from transformers import PreTrainedModel # noqa: F401
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
LlamaConfig,
)
try:
from transformers import LlamaForCausalLM
@@ -31,18 +36,18 @@ if TYPE_CHECKING:
def load_tokenizer(
base_model_config,
tokenizer_config,
tokenizer_type,
cfg,
):
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
)