From eb808903e5dca6a1a516097b41a58fd7e2dca453 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 18 Apr 2023 01:19:53 -0400
Subject: [PATCH] fix llama check

---
 scripts/finetune.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index b6573ebe9..fa09f401a 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -60,12 +60,14 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
+    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
 
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
-    if "llama" in base_model and cfg.flash_attention:
+    if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
     torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
@@ -85,7 +87,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
         raise e
 
     try:
-        if cfg.load_4bit and ("llama" in base_model or "llama" in cfg.model_type.lower()):
+        if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
@@ -104,7 +106,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
             )
             load_in_8bit = False
-        elif "llama" in base_model:
+        elif is_llama_derived_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
@@ -128,13 +130,18 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
 
     if not tokenizer:
         try:
-            if "llama" in base_model:
+            if is_llama_derived_model:
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
         except:
             tokenizer = AutoTokenizer.from_pretrained(base_model)
 
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
@@ -144,6 +151,7 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     if load_in_8bit and not cfg.load_4bit:
+        logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
     lora_config = LoraConfig(
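
Note: below is a minimal illustrative sketch (not part of the patch) of the llama-detection check that this change factors out into the single is_llama_derived_model flag. The standalone helper name and the SimpleNamespace stand-in for axolotl's cfg object are assumptions for demonstration only.

    # Hypothetical, self-contained illustration of the consolidated check.
    from types import SimpleNamespace

    def is_llama_derived(base_model: str, cfg) -> bool:
        # True if either the model path/name or the configured model_type mentions llama,
        # mirroring the expression introduced in scripts/finetune.py.
        return "llama" in base_model or "llama" in cfg.model_type.lower()

    cfg = SimpleNamespace(model_type="LlamaForCausalLM")
    print(is_llama_derived("huggyllama/llama-7b", cfg))     # True: model name contains "llama"
    print(is_llama_derived("some-org/other-model", cfg))    # True: model_type matches

    cfg = SimpleNamespace(model_type="GPTNeoXForCausalLM")
    print(is_llama_derived("EleutherAI/pythia-1.4b", cfg))  # False: neither matches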