From 69164da07920b7365bf2785027edf958e490ff12 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 17 Apr 2023 23:49:21 -0400 Subject: [PATCH] improve llama check and fix safetensors file check --- scripts/finetune.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 597108527..498d03de8 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -85,14 +85,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a raise e try: - if cfg.load_4bit and "llama" in base_model: + if cfg.load_4bit and "llama" in base_model or "llama" in cfg.model_type.lower(): from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram from huggingface_hub import snapshot_download cache_model_path = Path(snapshot_download(base_model)) - # TODO search .glob for a .pt, .safetensor, or .bin - cache_model_path.glob("*.pt") - files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin')) + files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin')) if len(files) > 0: model_path = str(files[0]) else: