From 4512738a73c68e2f148cf079d0c6a934ef16b3f9 Mon Sep 17 00:00:00 2001 From: Akshaya Shanbhogue Date: Sat, 13 Jul 2024 11:04:31 -0700 Subject: [PATCH] bump xformers to 0.0.27 (#1740) * Update requirements.txt Preserve compatibility with torch 2.3.1. [Reference](https://github.com/facebookresearch/xformers/issues/1052) * fix setup.py to extract the current xformers dep from requirements for replacement * xformers 0.0.27 wheels not built for torch 2.3.0 --------- Co-authored-by: Wing Lian --- requirements.txt | 2 +- setup.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index c8d168734..e24845a44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.5.8 sentencepiece wandb einops -xformers==0.0.26.post1 +xformers==0.0.27 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index c7b4e15de..58d279475 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,10 @@ def parse_requirements(): _install_requires.append(line) try: + xformers_version = [req for req in _install_requires if "xformers" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) else: # detect the version of torch already installed # and set it so dependencies don't clobber the torch version @@ -49,12 +50,14 @@ def parse_requirements(): raise ValueError("Invalid version format") if (major, minor) >= (2, 3): - pass + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.26.post1") elif (major, minor) >= (2, 2): - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1") except PackageNotFoundError: