bump flash attention 2.5.8 -> 2.6.1 (#1738)
* bump flash attention 2.5.8 -> 2.6.1
* use triton implementation of cross entropy from flash attn
* add smoke test for flash attn cross entropy patch
* fix args to xentropy.apply
* handle tuple from triton loss fn
* ensure the patch tests run independently
* use the wrapper already built into flash attn for cross entropy
* mark pytest as forked for patches
* use pytest xdist instead of forked, since cuda doesn't like forking
* limit to 1 process and use dist loadfile for pytest
* change up pytest fixture to reload transformers w/ monkeypatch
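The cross-entropy bullets above boil down to routing the model's loss through the Triton-backed wrapper that ships with flash-attn instead of calling the raw Triton loss function (which returns a tuple) directly. Below is a minimal sketch of that kind of monkeypatch; the module path and function name are illustrative, not the repo's actual patch code.

    # Sketch only: assumes a LLaMA-style transformers module that resolves
    # CrossEntropyLoss at call time; patch_cross_entropy_with_flash_attn is a
    # made-up name for illustration.
    import transformers.models.llama.modeling_llama as modeling_llama
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss

    def patch_cross_entropy_with_flash_attn():
        # The flash-attn wrapper invokes the Triton kernel internally and unpacks
        # the tuple it returns, so callers still get back a plain loss tensor,
        # the same as with torch.nn.CrossEntropyLoss.
        modeling_llama.CrossEntropyLoss = FlashCrossEntropyLoss

The pytest bullets amount to running the patch tests in a single pytest-xdist worker with per-file distribution, i.e. an invocation along the lines of: pytest -n1 --dist loadfile tests/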
 setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
@@ -80,10 +80,10 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.5.8",
+            "flash-attn==2.6.1",
         ],
         "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
         ],
         "deepspeed": [
             "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
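With the pins bumped, the optional dependencies are pulled in the usual way for setuptools extras; for example, from a repo checkout (illustrative command, extras named as in the diff above):

    pip install -e ".[flash-attn,deepspeed]"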