diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py index aff906b19..a3497b55e 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py @@ -113,15 +113,17 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): # initialize the model with prior weights new_model = cls(config=config) - del new_model.model # remove the default model + # remove the default model and lm_head + del new_model.model + del new_model.lm_head + new_model.model = convert_attention( model.model, attention_config=config.attention_config, train_attention=train_attention, remove_base_attn=remove_base_attn, ) - - new_model.lm_head.load_state_dict(model.lm_head.state_dict()) + new_model.lm_head = model.lm_head return new_model