don't use mask expansion for inference (#392)
This commit is contained in:
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
|
|||||||
base_model_config: meta-llama/Llama-2-7b-hf
|
base_model_config: meta-llama/Llama-2-7b-hf
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
|
|||||||
base_model_config: meta-llama/Llama-2-7b-hf
|
base_model_config: meta-llama/Llama-2-7b-hf
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|||||||
@@ -138,8 +138,10 @@ def load_model(
|
|||||||
LOG.info("patching with xpos rope")
|
LOG.info("patching with xpos rope")
|
||||||
replace_llama_rope_with_xpos_rope()
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and (
|
if (
|
||||||
cfg.max_packed_sequence_len or cfg.sample_packing
|
cfg.is_llama_derived_model
|
||||||
|
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||||
|
and not cfg.inference
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user