diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index b4a121c98..2a0af130b 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf base_model_config: meta-llama/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer +is_llama_derived_model: true load_in_8bit: true load_in_4bit: false diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 9a1c4c8c3..3ad2a7e4f 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf base_model_config: meta-llama/Llama-2-7b-hf model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer +is_llama_derived_model: true load_in_8bit: false load_in_4bit: true diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ce2d14f47..0a438ed21 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -138,8 +138,10 @@ def load_model( LOG.info("patching with xpos rope") replace_llama_rope_with_xpos_rope() - if cfg.is_llama_derived_model and ( - cfg.max_packed_sequence_len or cfg.sample_packing + if ( + cfg.is_llama_derived_model + and (cfg.max_packed_sequence_len or cfg.sample_packing) + and not cfg.inference ): from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask