From 1687be6a35d886bfcff25b646cd96831bdfd274f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 14 Aug 2023 20:52:54 -0400 Subject: [PATCH] don't use mask expansion for inference (#392) --- examples/llama-2/lora.yml | 1 + examples/llama-2/qlora.yml | 1 + src/axolotl/utils/models.py | 6 ++++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index b4a121c98..2a0af130b 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf base_model_config: meta-llama/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer +is_llama_derived_model: true load_in_8bit: true load_in_4bit: false diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 9a1c4c8c3..3ad2a7e4f 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf base_model_config: meta-llama/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer +is_llama_derived_model: true load_in_8bit: false load_in_4bit: true diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ce2d14f47..0a438ed21 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -138,8 +138,10 @@ def load_model( LOG.info("patching with xpos rope") replace_llama_rope_with_xpos_rope() - if cfg.is_llama_derived_model and ( - cfg.max_packed_sequence_len or cfg.sample_packing + if ( + cfg.is_llama_derived_model + and (cfg.max_packed_sequence_len or cfg.sample_packing) + and not cfg.inference ): from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask