use even if not using sample packing

This commit is contained in:
Wing Lian
2023-10-13 17:54:35 -04:00
parent f95858d369
commit 080612219b

View File

@@ -136,7 +136,11 @@ def load_model(
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if (
cfg.is_llama_derived_model
and cfg.flash_attention
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
):
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,