Dockerfile torch fix (#987)

* add torch to requirements.txt at build time to force version to stick

* fix xformers check

* better handling of xformers based on installed torch version

* fix for ci w/o torch
This commit is contained in:
Wing Lian
2023-12-21 09:38:20 -05:00
committed by GitHub
parent d25c34caa6
commit 161bcb6517
4 changed files with 12 additions and 10 deletions

View File

@@ -28,7 +28,7 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.0 pytorch: 2.1.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.0 pytorch: 2.1.1
axolotl_extras: axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
@@ -80,7 +80,7 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.0 pytorch: 2.1.1
axolotl_extras: axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:

View File

@@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \

View File

@@ -1,5 +1,7 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup from setuptools import find_packages, setup
@@ -22,12 +24,13 @@ def parse_requirements():
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0 try:
if "torch==2.1.0" in _install_requires: torch_version = version("torch")
_install_requires.pop(_install_requires.index("xformers>=0.0.22")) if torch_version.startswith("2.1.1"):
_install_requires.append( _install_requires.pop(_install_requires.index("xformers==0.0.22"))
"xformers @ git+https://github.com/facebookresearch/xformers.git@main" _install_requires.append("xformers==0.0.23")
) except PackageNotFoundError:
pass
return _install_requires, _dependency_links return _install_requires, _dependency_links