don't use mask expansion for inference (#392)

This commit is contained in:
Wing Lian
2023-08-14 20:52:54 -04:00
committed by GitHub
parent 41ecb451c2
commit 1687be6a35
3 changed files with 6 additions and 2 deletions

View File

@@ -138,8 +138,10 @@ def load_model(
LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not cfg.inference
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask