Add MPS support (#1264)

* add mps support

* linter stuff

* CI fixes

* install packaging for various tests

* Update setup.py

* Revert "install packaging for various tests"

This reverts commit 980e7aa44d.

* Revert "CI fixes"

This reverts commit 4609e3b166.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Maxime
2024-02-12 14:30:32 +01:00
committed by GitHub
parent ea00dd0852
commit fac2d98c26
5 changed files with 102 additions and 8 deletions

View File

@@ -1,5 +1,7 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -26,11 +28,25 @@ def parse_requirements():
_install_requires.append(line)
try:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
if "Darwin" in platform.system():
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
else:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass