add mps support
This commit is contained in:
15
setup.py
15
setup.py
@@ -1,7 +1,8 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from packaging.version import Version, parse
|
||||
import platform
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
@@ -26,11 +27,15 @@ def parse_requirements():
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
if torch_version.startswith("2.1."):
|
||||
if "Darwin" in platform.system():
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
else:
|
||||
torch_version = parse(version("torch"))
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
if torch_version >= Version("2.1"):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user