From 919246fbc1f6d925e792703d8fe989d05e200029 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Aug 2023 18:57:38 -0400 Subject: [PATCH] don't pass rope_scaling kwarg if it's None (#383) --- src/axolotl/utils/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1f36c50db..ce2d14f47 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -229,8 +229,12 @@ def load_model( elif cfg.is_llama_derived_model and not cfg.trust_remote_code: from transformers import LlamaForCausalLM + config_kwargs = {} + if cfg.rope_scaling: + config_kwargs["rope_scaling"] = cfg.rope_scaling config = LlamaConfig.from_pretrained( - base_model_config, rope_scaling=cfg.rope_scaling + base_model_config, + **config_kwargs, ) model = LlamaForCausalLM.from_pretrained( base_model,