imrpove llama check and fix safetensors file check

This commit is contained in:
Wing Lian
2023-04-17 23:49:21 -04:00
parent e1076430ff
commit 69164da079

View File

@@ -85,14 +85,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
raise e
try:
if cfg.load_4bit and "llama" in base_model:
if cfg.load_4bit and "llama" in base_model or "llama" in cfg.model_type.lower():
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
cache_model_path = Path(snapshot_download(base_model))
# TODO search .glob for a .pt, .safetensor, or .bin
cache_model_path.glob("*.pt")
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
if len(files) > 0:
model_path = str(files[0])
else: