fix: properly return causal model

This commit is contained in:
NanoCode012
2025-02-05 15:56:57 +07:00
parent 4cc60df876
commit 253dcdd0cf

View File

@@ -79,6 +79,9 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
Linear LLaMA model for causal language modeling.
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config):
super().__init__(config)
self.model = LinearLlamaModel(config)
@@ -102,20 +105,29 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
llama_model = model.model
else:
llama_model = model
if config is None:
raise ValueError("Missing config")
from axolotl.integrations.lolcats.linearize_attention import convert_attention
new_model = convert_attention(
model,
llama_model = convert_attention(
llama_model,
DictDefault(**config.attention_config),
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
# initialize the model with prior weights
new_model = cls(config=config)
del new_model.model # remove the default model
del new_model.lm_head # remove the default lm_head
new_model.model = llama_model
new_model.lm_head = model.lm_head
return new_model
def toggle_attention(self, train: bool = True):