diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
index 57ea6cacb..9aeba81c5 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
@@ -79,6 +79,9 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
     Linear LLaMA model for causal language modeling.
     """
 
+    config_class = LinearLlamaConfig
+    base_model_prefix = "linear_llama"
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LinearLlamaModel(config)
@@ -102,20 +105,29 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
 
         # Handle LlamaForCausalLM
         if isinstance(model, LlamaForCausalLM):
-            model = model.model
+            llama_model = model.model
+        else:
+            llama_model = model
 
         if config is None:
             raise ValueError("Missing config")
 
         from axolotl.integrations.lolcats.linearize_attention import convert_attention
 
-        new_model = convert_attention(
-            model,
+        llama_model = convert_attention(
+            llama_model,
             DictDefault(**config.attention_config),
             train_attention=train_attention,
             remove_base_attn=remove_base_attn,
         )
 
+        # initialize the model with prior weights
+        new_model = cls(config=config)
+        del new_model.model  # remove the default model
+        del new_model.lm_head  # remove the default lm_head
+        new_model.model = llama_model
+        new_model.lm_head = model.lm_head
+
         return new_model
 
     def toggle_attention(self, train: bool = True):
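
Usage sketch for the patched conversion path. The classmethod's name and exact signature fall outside the hunk, so `from_llama`, the checkpoint id, and the `attention_config` contents below are illustrative stand-ins, not the integration's confirmed API; only `model`, `config`, `train_attention`, and `remove_base_attn` are taken from the diff itself.

    from transformers import LlamaForCausalLM

    # Load the softmax-attention base model to be linearized.
    base = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

    # Build a LinearLlamaConfig from the base config. This assumes the config class
    # accepts/stores `attention_config`; the diff only shows it being read back via
    # DictDefault(**config.attention_config). The key/value here is hypothetical.
    config = LinearLlamaConfig(
        **base.config.to_dict(),
        attention_config={"attention_type": "lolcats_llama"},
    )

    # Hypothetical classmethod name; the diff only shows the method body. The sketch
    # passes a full LlamaForCausalLM, since the patched tail reads model.lm_head.
    linear = LinearLlamaForCausalLM.from_llama(
        base,
        config=config,
        train_attention=True,    # keep the base softmax attention for distillation
        remove_base_attn=False,  # drop it later, once attention transfer is done
    )

With this change the method no longer returns the bare converted backbone: it instantiates a fresh `LinearLlamaForCausalLM`, then swaps in the converted model and the original `lm_head`, so the result is a proper `*ForCausalLM` that, together with the new `config_class` attribute, can round-trip through `save_pretrained`/`from_pretrained`.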