Split up LLaMA model loading so that the config is loaded separately from `base_model_config` (via `LlamaConfig.from_pretrained`) and the model weights can be loaded from a different path (`base_model`).
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
|||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import AutoModelForCausalLM # noqa: F401
|
from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401
|
||||||
from transformers import PreTrainedModel # noqa: F401
|
from transformers import PreTrainedModel # noqa: F401
|
||||||
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
|
||||||
@@ -172,8 +172,10 @@ def load_model(
|
|||||||
)
|
)
|
||||||
load_in_8bit = False
|
load_in_8bit = False
|
||||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||||
|
config = LlamaConfig.from_pretrained(base_model_config)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
config=config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
|||||||
Reference in New Issue
Block a user