fix: assign lm_head directly instead of loading its state dict
This commit is contained in:
@@ -113,15 +113,17 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
# initialize the model with prior weights
|
# initialize the model with prior weights
|
||||||
new_model = cls(config=config)
|
new_model = cls(config=config)
|
||||||
|
|
||||||
del new_model.model # remove the default model
|
# remove the default model and lm_head
|
||||||
|
del new_model.model
|
||||||
|
del new_model.lm_head
|
||||||
|
|
||||||
new_model.model = convert_attention(
|
new_model.model = convert_attention(
|
||||||
model.model,
|
model.model,
|
||||||
attention_config=config.attention_config,
|
attention_config=config.attention_config,
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
|
new_model.lm_head = model.lm_head
|
||||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user