bump flash attention 2.5.8 -> 2.6.1 (#1738)
* bump flash attention 2.5.8 -> 2.6.1 * use triton implementation of cross entropy from flash attn * add smoke test for flash attn cross entropy patch * fix args to xentropy.apply * handle tuple from triton loss fn * ensure the patch tests run independently * use the wrapper already built into flash attn for cross entropy * mark pytest as forked for patches * use pytest xdist instead of forked, since cuda doesn't like forking * limit to 1 process and use dist loadfile for pytest * change up pytest for fixture to reload transformers w monkeypathc
This commit is contained in:
0
src/axolotl/integrations/__init__.py
Normal file
0
src/axolotl/integrations/__init__.py
Normal file
@@ -104,17 +104,12 @@ def replace_llama_attn_with_flash_attn(
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
|
||||
@@ -371,6 +371,12 @@ def load_model(
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
use_shifted_sparse_attn=True,
|
||||
)
|
||||
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
)
|
||||
elif cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
|
||||
Reference in New Issue
Block a user