fix: load vocab_size
This commit is contained in:
@@ -108,13 +108,14 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("Missing config")
|
raise ValueError("Missing config")
|
||||||
|
|
||||||
# initialize the model with prior weights
|
# initialize a new model with config
|
||||||
new_model = cls(config=config)
|
new_model = cls(config=config)
|
||||||
|
|
||||||
# remove the default model and lm_head
|
# remove the default model and lm_head
|
||||||
del new_model.model
|
del new_model.model
|
||||||
del new_model.lm_head
|
del new_model.lm_head
|
||||||
|
|
||||||
|
# load converted model, lm_head, and vocab_size from llama model
|
||||||
new_model.model = convert_attention(
|
new_model.model = convert_attention(
|
||||||
model.model,
|
model.model,
|
||||||
attention_config=config.attention_config,
|
attention_config=config.attention_config,
|
||||||
@@ -122,6 +123,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
new_model.lm_head = model.lm_head
|
new_model.lm_head = model.lm_head
|
||||||
|
new_model.vocab_size = model.vocab_size
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user