attempt xformers hijack attention

Wing Lian
2023-04-18 10:44:56 -04:00
parent 6045345d6b
commit 8746b701fe


@@ -43,6 +43,10 @@ def load_model(
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
logging.info("patching with xformers attention")
hijack_llama_attention()
torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
try:
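
For context, a minimal sketch of the pattern this patch relies on: hijack_llama_attention() monkeypatches the attention forward so the softmax(QK^T)V step runs through xformers' fused memory_efficient_attention kernel. The ToyAttention class, shapes, and hijack_toy_attention helper below are illustrative stand-ins, not the real LlamaAttention internals; only xformers.ops.memory_efficient_attention is an actual library call.

# Illustrative sketch only: a stand-in attention module, not transformers' LlamaAttention.
import torch
import xformers.ops


class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # Reference path over (batch, seq, heads, head_dim) tensors:
        # materializes the full (seq x seq) score matrix per head.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (batch, heads, seq, dim)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        out = torch.softmax(scores, dim=-1) @ v
        return out.transpose(1, 2)  # back to (batch, seq, heads, dim)


def xformers_forward(self, q, k, v):
    # Fused kernel; same (batch, seq, heads, head_dim) layout, no full score matrix.
    return xformers.ops.memory_efficient_attention(q, k, v)


def hijack_toy_attention():
    # The "hijack": every existing and future ToyAttention now routes through xformers.
    ToyAttention.forward = xformers_forward


hijack_toy_attention()
q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
print(ToyAttention()(q, k, v).shape)  # torch.Size([1, 128, 8, 64])

Running the sketch needs an xformers install and a CUDA device; the actual monkeypatch presumably performs the equivalent class-attribute swap on the Llama attention module inside transformers, reusing its existing q/k/v projections.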