fix: properly return causal model
@@ -79,6 +79,9 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
     Linear LLaMA model for causal language modeling.
     """
 
+    config_class = LinearLlamaConfig
+    base_model_prefix = "linear_llama"
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LinearLlamaModel(config)
@@ -102,20 +105,29 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
 
         # Handle LlamaForCausalLM
         if isinstance(model, LlamaForCausalLM):
-            model = model.model
+            llama_model = model.model
+        else:
+            llama_model = model
 
         if config is None:
             raise ValueError("Missing config")
 
         from axolotl.integrations.lolcats.linearize_attention import convert_attention
 
-        new_model = convert_attention(
-            model,
+        llama_model = convert_attention(
+            llama_model,
             DictDefault(**config.attention_config),
             train_attention=train_attention,
             remove_base_attn=remove_base_attn,
         )
 
+        # initialize the model with prior weights
+        new_model = cls(config=config)
+        del new_model.model  # remove the default model
+        del new_model.lm_head  # remove the default lm_head
+        new_model.model = llama_model
+        new_model.lm_head = model.lm_head
+
         return new_model
 
     def toggle_attention(self, train: bool = True):
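Note: the substance of the fix is that the conversion helper no longer returns the bare converted backbone; it constructs a fresh LinearLlamaForCausalLM via cls(config=config) and grafts the converted backbone and the original lm_head onto it. Below is a minimal, self-contained sketch of that graft pattern. TinyBackbone, TinyCausalLM, and convert() are hypothetical stand-ins for LinearLlamaModel, LinearLlamaForCausalLM, and the classmethod whose tail appears in the diff; they are not part of this repo.

# Sketch of the "construct wrapper, then graft modules" pattern used above.
# All names here are illustrative stand-ins, not the real classes.
import torch.nn as nn


class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


class TinyCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()       # freshly initialized backbone
        self.lm_head = nn.Linear(8, 100)  # freshly initialized head


def convert(src: TinyCausalLM) -> TinyCausalLM:
    converted = src.model            # stand-in for convert_attention(...)
    new_model = TinyCausalLM()       # analogous to cls(config=config)
    del new_model.model              # drop the default backbone
    del new_model.lm_head            # drop the default head
    new_model.model = converted      # graft the converted backbone
    new_model.lm_head = src.lm_head  # reuse the source model's head
    return new_model


wrapped = convert(TinyCausalLM())
assert isinstance(wrapped, TinyCausalLM)  # a causal LM, not a bare backbone

Strictly speaking, PyTorch lets you reassign a registered submodule directly; the explicit del just makes the intent to discard the randomly initialized defaults obvious and releases their parameters immediately rather than on the next assignment.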