diff --git a/requirements.txt b/requirements.txt index c8d168734..e24845a44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.5.8 sentencepiece wandb einops -xformers==0.0.26.post1 +xformers==0.0.27 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index c7b4e15de..58d279475 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,10 @@ def parse_requirements(): _install_requires.append(line) try: + xformers_version = [req for req in _install_requires if "xformers" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) else: # detect the version of torch already installed # and set it so dependencies don't clobber the torch version @@ -49,12 +50,14 @@ def parse_requirements(): raise ValueError("Invalid version format") if (major, minor) >= (2, 3): - pass + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.26.post1") elif (major, minor) >= (2, 2): - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1") except PackageNotFoundError: