set torch version to what is installed during axolotl install (#1234)

This commit is contained in:
Wing Lian
2024-01-31 08:47:34 -05:00
committed by GitHub
parent 5787e1a23f
commit 8f2b591baf

View File

@@ -27,6 +27,7 @@ def parse_requirements():
try:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
@@ -50,7 +51,7 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.3.3",
"flash-attn==2.5.0",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",