Feat: Add support for upstream FA2 (#626)
* Feat: Add support for upstream FA2 * chore: add is_falcon_derived_model: true to examples * chore: add config to readme for documentation * feat: add extra model types * fix: remove old falcon flash patch * chore: pin transformers and accelerate
This commit is contained in:
@@ -114,25 +114,13 @@ def load_model(
|
||||
|
||||
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
||||
|
||||
if hasattr(model_config, "model_type") and model_config.model_type in [
|
||||
"falcon",
|
||||
"RefinedWebModel",
|
||||
"RefinedWeb",
|
||||
]:
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
|
||||
replace_falcon_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
replace_falcon_attn_with_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
@@ -213,6 +201,10 @@ def load_model(
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention and not cfg.sample_packing:
|
||||
if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
|
||||
model_kwargs["use_flash_attention_2"] = True
|
||||
try:
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user