don't use mask expansion for inference (#392)

This commit is contained in:
Wing Lian
2023-08-14 20:52:54 -04:00
committed by GitHub
parent 41ecb451c2
commit 1687be6a35
3 changed files with 6 additions and 2 deletions

View File

@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 load_in_8bit: true
 load_in_4bit: false

View File

@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 load_in_8bit: false
 load_in_4bit: true

View File

@@ -138,8 +138,10 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
-    if cfg.is_llama_derived_model and (
-        cfg.max_packed_sequence_len or cfg.sample_packing
+    if (
+        cfg.is_llama_derived_model
+        and (cfg.max_packed_sequence_len or cfg.sample_packing)
+        and not cfg.inference
     ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask