From b5212068ac531838e96f0783cd4624df2d7188ef Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 13 Aug 2023 00:50:15 +0900
Subject: [PATCH] Feat: Add rope scaling (#343)

* Feat: Add rope scaling

* fix: move rope config
---
 README.md                   | 4 ++++
 src/axolotl/utils/models.py | 4 +++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8dbb535cd..011b13903 100644
--- a/README.md
+++ b/README.md
@@ -474,6 +474,10 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # llama only
 xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 9224d0f4d..6abbd7265 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -219,7 +219,9 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM
 
-        config = LlamaConfig.from_pretrained(base_model_config)
+        config = LlamaConfig.from_pretrained(
+            base_model_config, rope_scaling=cfg.rope_scaling
+        )
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
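
Note (illustrative, not part of the patch): with this change applied, a user config could enable RoPE scaling with something like the sketch below. The factor value 2.0 is an assumption chosen to roughly double the pre-trained context window; it is not specified anywhere in this PR.

    rope_scaling:
      type: linear   # per the linked transformers PR, "linear" is position interpolation, "dynamic" is dynamic NTK scaling
      factor: 2.0    # hypothetical value: stretch positions by ~2x

The rope_scaling mapping is passed straight through to transformers' LlamaConfig (see the models.py hunk above), so the accepted keys and values follow the upstream PR referenced in the README comment.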