diff --git a/scripts/finetune.py b/scripts/finetune.py
index 597108527..498d03de8 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -85,14 +85,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
         raise e
 
     try:
-        if cfg.load_4bit and "llama" in base_model:
+        if cfg.load_4bit and ("llama" in base_model or "llama" in cfg.model_type.lower()):
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
             cache_model_path = Path(snapshot_download(base_model))
-            # TODO search .glob for a .pt, .safetensor, or .bin
-            cache_model_path.glob("*.pt")
-            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
+            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
             if len(files) > 0:
                 model_path = str(files[0])
             else: