Add support for Replit LM

This commit is contained in:
Wing Lian
2023-05-17 08:49:03 -04:00
parent b46bc02f0a
commit 8c2f3cb0f8
2 changed files with 67 additions and 3 deletions

View File

@@ -163,11 +163,20 @@ def load_model(
if not tokenizer:
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(model)
tokenizer = LlamaTokenizer.from_pretrained(
model,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
model,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except:
tokenizer = AutoTokenizer.from_pretrained(base_model_config)
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")