add mps support

This commit is contained in:
Maxime
2024-02-04 14:20:40 +01:00
parent 2d65f470d5
commit eb300b6c57
5 changed files with 91 additions and 9 deletions

View File

@@ -1,7 +1,8 @@
"""setup.py for axolotl"""
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version, parse
import platform
from setuptools import find_packages, setup
@@ -26,11 +27,15 @@ def parse_requirements():
_install_requires.append(line)
try:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
if "Darwin" in platform.system():
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
else:
torch_version = parse(version("torch"))
_install_requires.append(f"torch=={torch_version}")
if torch_version >= Version("2.1"):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass