bump flash attention 2.5.8 -> 2.6.1 (#1738)

* bump flash attention 2.5.8 -> 2.6.1

* use triton implementation of cross entropy from flash attn

* add smoke test for flash attn cross entropy patch

* fix args to xentropy.apply

* handle tuple from triton loss fn

* ensure the patch tests run independently

* use the wrapper already built into flash attn for cross entropy

* mark the patch tests as forked for pytest

* use pytest-xdist instead of pytest-forked, since CUDA doesn't handle forked processes well

* limit pytest-xdist to 1 process and use --dist loadfile

* change up pytest to use a fixture that reloads transformers with monkeypatch (see the sketch below)
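
The last few items describe how the patch tests are isolated: pytest-xdist with `-n 1 --dist loadfile` keeps each test file inside a single worker process, and a fixture re-imports transformers between tests so one monkeypatch can't bleed into the next. A minimal sketch of such a fixture, assuming it lives in the tests' conftest (the fixture name and exact reload strategy here are illustrative, not the commit's code):

# illustrative conftest.py sketch; the commit's actual fixture may differ
import importlib

import pytest


@pytest.fixture(autouse=True)
def reload_transformers():
    """Reload the llama modeling module around each test so flash-attn
    monkeypatches applied in one test don't leak into the next."""
    import transformers.models.llama.modeling_llama

    importlib.reload(transformers.models.llama.modeling_llama)
    yield
    importlib.reload(transformers.models.llama.modeling_llama)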
Author: Wing Lian
Date: 2024-07-14 19:11:31 -04:00
Committed by: GitHub
Parent: 219cd0d3c5
Commit: 98af5388ba
8 changed files with 103 additions and 14 deletions

@@ -104,17 +104,12 @@ def replace_llama_attn_with_flash_attn(
     # skip only if explicitly disabled
     if cross_entropy:
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
-            LOG.info("patching with flash_attn.losses.cross_entropy")
-            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-                CrossEntropyLoss, inplace_backward=True
-            )
-        except ImportError:
-            LOG.warning(
-                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-            )
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
 
     # skip only if explicitly disabled
     if rms_norm:
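
For reference, `flash_attn.losses.cross_entropy.CrossEntropyLoss` is a drop-in replacement for `torch.nn.CrossEntropyLoss` backed by a bundled Triton kernel, which is why the try/except around the old xentropy_cuda_lib extension can be dropped. A minimal parity check in the spirit of the smoke test this commit adds, assuming a CUDA device and flash-attn >= 2.6 installed (the actual test contents may differ):

# illustrative parity check, not the commit's exact smoke test
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss

torch.manual_seed(0)
logits = torch.randn(8, 32000, device="cuda", dtype=torch.float16)
labels = torch.randint(0, 32000, (8,), device="cuda")

# inplace_backward=True matches how the patch constructs the loss above
flash_loss = FlashCrossEntropyLoss(inplace_backward=True)(logits, labels)
ref_loss = torch.nn.functional.cross_entropy(logits.float(), labels)

torch.testing.assert_close(flash_loss.float(), ref_loss, rtol=1e-3, atol=1e-3)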

@@ -371,6 +371,12 @@ def load_model(
                 rms_norm=cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
+        elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
+            replace_llama_attn_with_flash_attn(
+                packed=False,
+                cross_entropy=cfg.flash_attn_cross_entropy,
+                rms_norm=cfg.flash_attn_rms_norm,
+            )
         elif cfg.xformers_attention:
             from axolotl.monkeypatch.llama_attn_hijack_xformers import (
                 hijack_llama_attention,
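
Net effect of this hunk: the Triton cross-entropy and fused RMSNorm patches can now be applied on their own, without the packed flash-attention path. Illustratively, when a config enables flash_attn_cross_entropy and/or flash_attn_rms_norm, the new branch reduces to a call like the following (this mirrors the diff above rather than adding new behavior):

# illustrative only: the effective call made by the new elif branch
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn(
    packed=False,          # no sample packing in this code path
    cross_entropy=True,    # swap in flash-attn's Triton CrossEntropyLoss
    rms_norm=True,         # swap in flash-attn's fused RMSNorm kernel
)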