From 45f77dd51e3bb72bf7d22103c98874fa57caccea Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 14 Apr 2023 19:30:41 -0400
Subject: [PATCH] better handling of llama model import

---
 scripts/finetune.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 3ed8093ba..c5d467c6f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -19,7 +19,7 @@ from peft import (
     get_peft_model_state_dict,
     PeftModel,
 )
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names
@@ -53,16 +53,23 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-
         replace_llama_attn_with_flash_attn()

     try:
-        model = getattr(transformers, model_type).from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
-            device_map=cfg.device_map,
-        )
+        if "llama" in base_model:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = getattr(transformers, model_type).from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
@@ -72,7 +79,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         )

     try:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        if "llama" in base_model:
+            tokenizer = LlamaTokenizer.from_pretrained(model)
+        else:
+            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)

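
Below is a minimal, self-contained sketch of the dispatch pattern this patch introduces: prefer the concrete LlamaForCausalLM / LlamaTokenizer classes when the model id looks like a llama checkpoint, and fall back to the generic classes otherwise. This is not the axolotl implementation; the SimpleConfig dataclass and the load_model_and_tokenizer name are hypothetical stand-ins for axolotl's cfg object and its load_model function.

    # Sketch of the llama class-dispatch pattern (hypothetical helper,
    # not the axolotl code).
    from dataclasses import dataclass

    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        LlamaForCausalLM,
        LlamaTokenizer,
    )

    @dataclass
    class SimpleConfig:  # stand-in for axolotl's cfg
        load_in_8bit: bool = False
        device_map: str = "auto"

    def load_model_and_tokenizer(base_model: str, cfg: SimpleConfig):
        # Use the concrete classes when the model id suggests a llama
        # checkpoint; otherwise let the Auto* classes resolve the
        # architecture from the checkpoint config.
        is_llama = "llama" in base_model
        model_cls = LlamaForCausalLM if is_llama else AutoModelForCausalLM
        tokenizer_cls = LlamaTokenizer if is_llama else AutoTokenizer

        model = model_cls.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )
        # Load the tokenizer from the same hub id / local path as the weights.
        tokenizer = tokenizer_cls.from_pretrained(base_model)
        return model, tokenizer

One detail worth noting about the patch itself: it passes the already-loaded model object to LlamaTokenizer.from_pretrained, which expects a hub id or path, so that call raises and the bare except falls through to AutoTokenizer.from_pretrained(base_model). The sketch above loads the tokenizer from base_model directly instead.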