use even if not using sample packing
This commit is contained in:
@@ -136,7 +136,11 @@ def load_model(
|
|||||||
|
|
||||||
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if (
|
||||||
|
cfg.is_llama_derived_model
|
||||||
|
and cfg.flash_attention
|
||||||
|
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
|
||||||
|
):
|
||||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
|
|||||||
Reference in New Issue
Block a user