diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 0737d0f12..952aaaa97 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import AutoModelForCausalLM # noqa: F401 +from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401 from transformers import PreTrainedModel # noqa: F401 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig @@ -172,8 +172,10 @@ def load_model( ) load_in_8bit = False elif is_llama_derived_model and "LlamaForCausalLM" in globals(): + config = LlamaConfig.from_pretrained(base_model_config) model = LlamaForCausalLM.from_pretrained( base_model, + config=config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype,