don't pass rope_scaling kwarg if it's None (#383)

This commit is contained in:
Wing Lian
2023-08-13 18:57:38 -04:00
committed by GitHub
parent ffac902c1b
commit 919246fbc1

View File

@@ -229,8 +229,12 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM
+        config_kwargs = {}
+        if cfg.rope_scaling:
+            config_kwargs["rope_scaling"] = cfg.rope_scaling
         config = LlamaConfig.from_pretrained(
-            base_model_config, rope_scaling=cfg.rope_scaling
+            base_model_config,
+            **config_kwargs,
         )
         model = LlamaForCausalLM.from_pretrained(
             base_model,