attempt xformers hijack attention
@@ -43,6 +43,10 @@ def load_model(
         logging.info("patching with flash attention")
         replace_llama_attn_with_flash_attn()
+    elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        logging.info("patching with xformers attention")
+        hijack_llama_attention()
 
     torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
 
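For context on what the new branch does: hijack_llama_attention() is a monkeypatch that swaps the model's attention forward for one backed by xformers' memory-efficient kernels, mirroring the flash-attention patch in the branch above. The alpaca_lora_4bit implementation itself is not shown in this diff; the sketch below is only a hypothetical, minimal illustration of the same hijack pattern on a toy attention module (ToyAttention, xformers_forward, and hijack_toy_attention are invented names), assuming xformers with a CUDA build is installed.

# Hypothetical sketch of the attention-hijack pattern; this is NOT the
# alpaca_lora_4bit code, only an illustration of the same idea.
import torch
import torch.nn as nn
import xformers.ops as xops


class ToyAttention(nn.Module):
    # Stand-in for a transformer attention block (e.g. LlamaAttention).
    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Naive attention: materializes the full (seq x seq) score matrix.
        bsz, seq, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.reshape(bsz, seq, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        ctx = scores.softmax(dim=-1) @ v
        return self.out(ctx.transpose(1, 2).reshape(bsz, seq, -1))


def xformers_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same projections, but the attention itself runs through xformers'
    # memory-efficient kernel (expects shape [batch, seq, heads, head_dim];
    # in practice float16 tensors on a CUDA device).
    bsz, seq, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    q, k, v = (t.reshape(bsz, seq, self.num_heads, self.head_dim) for t in (q, k, v))
    ctx = xops.memory_efficient_attention(q, k, v)
    return self.out(ctx.reshape(bsz, seq, -1))


def hijack_toy_attention() -> None:
    # The "hijack": reassign the class method so every instance, existing or
    # future, uses the xformers-backed forward. hijack_llama_attention() does
    # the equivalent to the model's attention forward.
    ToyAttention.forward = xformers_forward

Because Python resolves methods through the class, calling the hijack once before any forward pass is enough to reroute every attention layer, which is why the patch is applied here in load_model() before the model is set up.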