From 9e1c4de13c912584e5b579b3d8240d7a19d2687d Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 5 Feb 2025 18:24:31 +0700 Subject: [PATCH] fix: assign linear head instead of loading state dict --- .../lolcats/linear_llama/modeling_linear_llama.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py index aff906b19..a3497b55e 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py @@ -113,15 +113,17 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): # initialize the model with prior weights new_model = cls(config=config) - del new_model.model # remove the default model + # remove the default model and lm_head + del new_model.model + del new_model.lm_head + new_model.model = convert_attention( model.model, attention_config=config.attention_config, train_attention=train_attention, remove_base_attn=remove_base_attn, ) - - new_model.lm_head.load_state_dict(model.lm_head.state_dict()) + new_model.lm_head = model.lm_head return new_model