From 790df757cb40bf796ad143e010064b1fddff6f06 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 16 Jan 2026 09:02:37 -0500 Subject: [PATCH] don't install xformers for arm64 (#3359) * install xformers in the base docker image * install numba and numpy first * set CUDA_HOME for xformers install * Set cuda home env * don't install xformers by default on aarch64/arm64 --- setup.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 10c9a8453..5f51dbee0 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ def parse_requirements(extras_require_map): _install_requires.append(line) try: xformers_version = [req for req in _install_requires if "xformers" in req][0] + install_xformers = platform.machine() != "aarch64" if "Darwin" in platform.system(): # skip packages not compatible with OSX skip_packages = [ @@ -66,40 +67,49 @@ def parse_requirements(extras_require_map): extras_require_map.pop("fbgemm-gpu") extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"] extras_require_map["vllm"] = ["vllm==0.11.1"] + if not install_xformers: + _install_requires.pop(_install_requires.index(xformers_version)) elif (major, minor) >= (2, 8): extras_require_map.pop("fbgemm-gpu") extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] extras_require_map["vllm"] = ["vllm==0.11.0"] + if not install_xformers: + _install_requires.pop(_install_requires.index(xformers_version)) elif (major, minor) >= (2, 7): _install_requires.pop(_install_requires.index(xformers_version)) if patch == 0: - _install_requires.append("xformers==0.0.30") + if install_xformers: + _install_requires.append("xformers==0.0.30") # vllm 0.9.x is incompatible with latest transformers extras_require_map.pop("vllm") else: - _install_requires.append("xformers==0.0.31") + if install_xformers: + _install_requires.append("xformers==0.0.31") extras_require_map["vllm"] = ["vllm==0.10.1"] elif (major, minor) >= (2, 6): 
_install_requires.pop(_install_requires.index(xformers_version)) - _install_requires.append("xformers==0.0.29.post3") + if install_xformers: + _install_requires.append("xformers==0.0.29.post3") # since we only support 2.6.0+cu126 _dependency_links.append("https://download.pytorch.org/whl/cu126") extras_require_map.pop("vllm") elif (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) - if patch == 0: - _install_requires.append("xformers==0.0.28.post2") - else: - _install_requires.append("xformers>=0.0.28.post3") + if install_xformers: + if patch == 0: + _install_requires.append("xformers==0.0.28.post2") + else: + _install_requires.append("xformers>=0.0.28.post3") extras_require_map.pop("vllm") elif (major, minor) >= (2, 4): extras_require_map.pop("vllm") - if patch == 0: - _install_requires.pop(_install_requires.index(xformers_version)) - _install_requires.append("xformers>=0.0.27") - else: - _install_requires.pop(_install_requires.index(xformers_version)) - _install_requires.append("xformers==0.0.28.post1") + if install_xformers: + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers==0.0.28.post1") else: raise ValueError("axolotl requires torch>=2.4")