Dockerfile torch fix (#987)
* add torch to requirements.txt at build time to force version to stick * fix xformers check * better handling of xformers based on installed torch version * fix for ci w/o torch
This commit is contained in:
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
@@ -80,7 +80,7 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
15
setup.py
15
setup.py
@@ -1,5 +1,7 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
@@ -22,12 +24,13 @@ def parse_requirements():
|
|||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
try:
|
||||||
if "torch==2.1.0" in _install_requires:
|
torch_version = version("torch")
|
||||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
if torch_version.startswith("2.1.1"):
|
||||||
_install_requires.append(
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
_install_requires.append("xformers==0.0.23")
|
||||||
)
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user