Feat: Add rope scaling (#343)

* Feat: Add rope scaling

* fix: move rope config
Author: NanoCode012
Date: 2023-08-13 00:50:15 +09:00
Committed by: GitHub
Parent: 289d5c403d
Commit: b5212068ac
2 changed files with 7 additions and 1 deletion

@@ -474,6 +474,10 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # llama only
 xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
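
For reference, a filled-in version of the new option could look like the following in a training config; the values are illustrative only (linear scaling with a factor of 2.0, i.e. roughly twice the pretraining context length):

rope_scaling:
  type: linear
  factor: 2.0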

@@ -219,7 +219,9 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM
-        config = LlamaConfig.from_pretrained(base_model_config)
+        config = LlamaConfig.from_pretrained(
+            base_model_config, rope_scaling=cfg.rope_scaling
+        )
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
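
As a rough sketch of how the new value reaches Hugging Face transformers (the shape of the rope_scaling dict comes from transformers PR #24653; the concrete factor below is an illustrative assumption, not part of this change):

from transformers import LlamaConfig

# Construct a LLaMA config with linear RoPE scaling; a factor of 2.0 roughly
# doubles the context length the model was pretrained with. This mirrors what
# LlamaConfig.from_pretrained(base_model_config, rope_scaling=cfg.rope_scaling)
# does via its override kwarg in the patch above.
config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}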