fix: assign linear head instead of loading state dict
This commit is contained in:
@@ -113,15 +113,17 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
         # initialize the model with prior weights
         new_model = cls(config=config)
 
-        del new_model.model  # remove the default model
+        # remove the default model and lm_head
+        del new_model.model
+        del new_model.lm_head
 
         new_model.model = convert_attention(
             model.model,
             attention_config=config.attention_config,
             train_attention=train_attention,
             remove_base_attn=remove_base_attn,
         )
 
-        new_model.lm_head.load_state_dict(model.lm_head.state_dict())
+        new_model.lm_head = model.lm_head
 
         return new_model
||||
Reference in New Issue
Block a user