From 2520ecd6df3e0eb1d1813f3ad6dcb429d84e61fc Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 30 May 2023 22:32:44 -0400
Subject: [PATCH] split up llama model loading so config can be loaded from
 base config and models can be loaded from a path

---
 src/axolotl/utils/models.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0737d0f12..952aaaa97 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
 
@@ -172,8 +172,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,