diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c51d0cd53..649578bd5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -43,6 +43,10 @@ def load_model(
         logging.info("patching with flash attention")
         replace_llama_attn_with_flash_attn()
+    elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        logging.info("patching with xformers attention")
+        hijack_llama_attention()
 
     torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
 
     try: