use even if not using sample packing
This commit is contained in:
@@ -136,7 +136,11 @@ def load_model(
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
and cfg.flash_attention
|
||||
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
|
||||
):
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
|
||||
Reference in New Issue
Block a user