split up llama model loading so the config can be loaded from the base config and the model weights can be loaded from a separate path

Wing Lian
2023-05-30 22:32:44 -04:00
parent c5b0af1a7e
commit 2520ecd6df


@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
@@ -172,8 +172,10 @@ def load_model(
             )
             load_in_8bit = False
         elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+            config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
+                config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
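
Taken out of the diff, the pattern this hunk introduces is sketched below: the architecture config is resolved from base_model_config, which may point somewhere other than the weights, and is then passed into from_pretrained so the weights at base_model load under that config. This is a minimal sketch, not the full load_model function; the two identifiers are hypothetical placeholders, and the cfg-driven quantization flags are omitted.

# Minimal sketch of the config/weights split, assuming transformers >= 4.28
# (where LlamaConfig and LlamaForCausalLM are available). The two identifiers
# below are hypothetical placeholders, not values from this commit.
import torch
from transformers import LlamaConfig, LlamaForCausalLM

base_model_config = "huggyllama/llama-7b"     # hub repo that holds config.json
base_model = "/data/checkpoints/llama-7b-ft"  # path that holds the weights

# Resolve the architecture config independently of the weights path...
config = LlamaConfig.from_pretrained(base_model_config)

# ...then load the weights under that explicit config.
model = LlamaForCausalLM.from_pretrained(
    base_model,
    config=config,
    torch_dtype=torch.float16,
)

Without the explicit config= argument, from_pretrained would look for a config.json alongside the weights; passing the config separately is what lets base_model be a bare weights path while the config comes from the base model's repo.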