bump flash attention 2.5.8 -> 2.6.1 (#1738)

* bump flash attention 2.5.8 -> 2.6.1

* use triton implementation of cross entropy from flash attn

* add smoke test for flash attn cross entropy patch

* fix args to xentropy.apply

* handle tuple from triton loss fn

* ensure the patch tests run independently

* use the wrapper already built into flash attn for cross entropy

* mark pytest as forked for patches

* use pytest xdist instead of forked, since cuda doesn't like forking

* limit to 1 process and use dist loadfile for pytest

* change up pytest for fixture to reload transformers w monkeypathc
This commit is contained in:
Wing Lian
2024-07-14 19:11:31 -04:00
committed by GitHub
parent 219cd0d3c5
commit 98af5388ba
8 changed files with 103 additions and 14 deletions

View File

@@ -12,7 +12,7 @@ fire
PyYAML>=6.0
requests
datasets==2.19.1
flash-attn==2.5.8
flash-attn==2.6.1
sentencepiece
wandb
einops