fix: properly return causal model
This commit is contained in:
@@ -79,6 +79,9 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
Linear LLaMA model for causal language modeling.
|
Linear LLaMA model for causal language modeling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config_class = LinearLlamaConfig
|
||||||
|
base_model_prefix = "linear_llama"
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = LinearLlamaModel(config)
|
self.model = LinearLlamaModel(config)
|
||||||
@@ -102,20 +105,29 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
|
|
||||||
# Handle LlamaForCausalLM
|
# Handle LlamaForCausalLM
|
||||||
if isinstance(model, LlamaForCausalLM):
|
if isinstance(model, LlamaForCausalLM):
|
||||||
model = model.model
|
llama_model = model.model
|
||||||
|
else:
|
||||||
|
llama_model = model
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("Missing config")
|
raise ValueError("Missing config")
|
||||||
|
|
||||||
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
||||||
|
|
||||||
new_model = convert_attention(
|
llama_model = convert_attention(
|
||||||
model,
|
llama_model,
|
||||||
DictDefault(**config.attention_config),
|
DictDefault(**config.attention_config),
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# initialize the model with prior weights
|
||||||
|
new_model = cls(config=config)
|
||||||
|
del new_model.model # remove the default model
|
||||||
|
del new_model.lm_head # remove the default lm_head
|
||||||
|
new_model.model = llama_model
|
||||||
|
new_model.lm_head = model.lm_head
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
def toggle_attention(self, train: bool = True):
|
def toggle_attention(self, train: bool = True):
|
||||||
|
|||||||
Reference in New Issue
Block a user