Don't pass the rope_scaling kwarg to LlamaConfig.from_pretrained when it is unset (None) (#383)
This commit is contained in:
@@ -229,8 +229,12 @@ def load_model(
|
|||||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
config_kwargs = {}
|
||||||
|
if cfg.rope_scaling:
|
||||||
|
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||||
config = LlamaConfig.from_pretrained(
|
config = LlamaConfig.from_pretrained(
|
||||||
base_model_config, rope_scaling=cfg.rope_scaling
|
base_model_config,
|
||||||
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user