fix: assign linear head instead of loading state dict

NanoCode012
2025-02-05 18:24:31 +07:00
parent 2d5f692fc0
commit 9e1c4de13c


@@ -113,15 +113,17 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
         # initialize the model with prior weights
         new_model = cls(config=config)
-        del new_model.model  # remove the default model
+        # remove the default model and lm_head
+        del new_model.model
+        del new_model.lm_head
         new_model.model = convert_attention(
             model.model,
             attention_config=config.attention_config,
             train_attention=train_attention,
             remove_base_attn=remove_base_attn,
         )
-        new_model.lm_head.load_state_dict(model.lm_head.state_dict())
+        new_model.lm_head = model.lm_head
         return new_model
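
Because the patched code now deletes new_model.lm_head before copying anything in, calling load_state_dict on it would fail; assigning the prior model's lm_head module directly sidesteps that and also avoids allocating and filling a second set of head weights. Below is a minimal sketch of the difference between the two approaches, using illustrative dimensions and variable names that are not from this repository:

    import torch.nn as nn

    # head taken from an already-loaded source model (illustrative size)
    src_head = nn.Linear(16, 32, bias=False)

    # old approach: allocate a fresh head, then copy the weights into it;
    # this keeps two separate weight tensors in memory
    copied_head = nn.Linear(16, 32, bias=False)
    copied_head.load_state_dict(src_head.state_dict())
    assert copied_head.weight is not src_head.weight

    # new approach: reuse the source module itself; no duplicate allocation,
    # and it still works after the freshly initialized head has been deleted
    reused_head = src_head
    assert reused_head.weight is src_head.weight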