fix up tokenizer config, isort fix

This commit is contained in:
Wing Lian
2023-05-30 23:00:02 -04:00
parent 2520ecd6df
commit 39a208c2bc
2 changed files with 13 additions and 7 deletions

View File

@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401
from transformers import PreTrainedModel # noqa: F401
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
LlamaConfig,
)
try:
from transformers import LlamaForCausalLM
@@ -31,18 +36,18 @@ if TYPE_CHECKING:
def load_tokenizer(
base_model_config,
tokenizer_config,
tokenizer_type,
cfg,
):
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
)