split up llama model loading so the config can be loaded from the base config and the model weights can be loaded from a path
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
@@ -172,8 +172,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
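
For reference, a minimal standalone sketch of the pattern this hunk introduces, pulled out of the surrounding load_model() code: the LlamaConfig is read from base_model_config while the weights come from base_model, so the two can point at different locations (for example, a local checkpoint directory plus a hub repo that holds the config). The helper name load_llama is hypothetical; base_model, base_model_config, cfg, and torch_dtype stand in for the variables visible in the diff, and cfg is assumed to expose load_in_8bit, load_in_4bit, and adapter as in load_model().

# Hypothetical helper illustrating the split introduced by this commit.
from transformers import LlamaConfig, LlamaForCausalLM

def load_llama(base_model, base_model_config, cfg, torch_dtype):
    # Read only the config from the base config location ...
    config = LlamaConfig.from_pretrained(base_model_config)
    # ... then load the weights from the (possibly different) model path,
    # overriding whatever config ships alongside those weights.
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        config=config,
        load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
        load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
        torch_dtype=torch_dtype,
    )
    return model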