Dockerfile torch fix (#987)

* add torch to requirements.txt at build time to force version to stick

* fix xformers check

* better handling of xformers based on installed torch version

* fix for ci w/o torch
This commit is contained in:
Wing Lian
2023-12-21 09:38:20 -05:00
committed by GitHub
parent d25c34caa6
commit 161bcb6517
4 changed files with 12 additions and 10 deletions

View File

@@ -1,5 +1,7 @@
"""setup.py for axolotl"""
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -22,12 +24,13 @@ def parse_requirements():
# Handle standard packages
_install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
try:
torch_version = version("torch")
if torch_version.startswith("2.1.1"):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers==0.0.23")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links