move ring flash attn to extras with flash-attn (#2414)
This commit is contained in:
@@ -68,4 +68,3 @@ axolotl-contribs-mit==0.0.3
|
|||||||
|
|
||||||
# for sequence parallelism
|
# for sequence parallelism
|
||||||
yunchang==0.6.0
|
yunchang==0.6.0
|
||||||
ring-flash-attn>=0.1.4
|
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -17,11 +17,7 @@ def parse_requirements():
|
|||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = (
|
||||||
"flash-attn" in line
|
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
@@ -39,7 +35,6 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -124,9 +119,7 @@ setup(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": ["flash-attn==2.7.4.post1", "ring-flash-attn>=0.1.4"],
|
||||||
"flash-attn==2.7.4.post1",
|
|
||||||
],
|
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
|
|||||||
Reference in New Issue
Block a user