diff --git a/scripts/finetune.py b/scripts/finetune.py
index ba3a59a6a..4c24a3c4f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
 
     from axolotl.prompters import ReflectAlpacaPrompter
 
-    instruction = str(input("Give me an instruction: "))
-    instruction = (
-        instruction if not instruction else "Tell me a joke about dromedaries."
-    )
-    prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
-    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+    while True:
+        instruction = str(input("Give me an instruction: "))
+        if not instruction:
+            return
+        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
-    model.eval()
-    with torch.no_grad():
-        # gc = GenerationConfig()  # TODO swap out and use this
-        generated = model.generate(
-            inputs=batch["input_ids"].to("cuda"),
-            do_sample=True,
-            use_cache=True,
-            repetition_penalty=1.1,
-            max_new_tokens=100,
-            temperature=0.9,
-            top_p=0.95,
-            top_k=40,
-            return_dict_in_generate=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            output_scores=False,
-        )
-    print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
+        model.eval()
+        with torch.no_grad():
+            # gc = GenerationConfig()  # TODO swap out and use this
+            generated = model.generate(
+                inputs=batch["input_ids"].to("cuda"),
+                do_sample=True,
+                use_cache=True,
+                repetition_penalty=1.1,
+                max_new_tokens=100,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
 def choose_config(path: Path):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d05e62d29..d66831861 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -66,22 +66,25 @@ def load_model(
         from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
         from huggingface_hub import snapshot_download
 
-        snapshot_download_kwargs = {}
-        if cfg.base_model_ignore_patterns:
-            snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
-        cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
-        files = (
-            list(cache_model_path.glob("*.pt"))
-            + list(cache_model_path.glob("*.safetensors"))
-            + list(cache_model_path.glob("*.bin"))
-        )
-        if len(files) > 0:
-            model_path = str(files[0])
-        else:
-            logging.warning(
-                "unable to find a cached model file, this will likely fail..."
+        try:
+            snapshot_download_kwargs = {}
+            if cfg.base_model_ignore_patterns:
+                snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
+            cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
             )
-            model_path = str(cache_model_path)
+            if len(files) > 0:
+                model_path = str(files[0])
+            else:
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
+                model_path = str(cache_model_path)
+        except:
+            model_path = cfg.base_model
         model, tokenizer = load_llama_model_4bit_low_ram(
             base_model_config if base_model_config else base_model,
             model_path,