add uv axolotl builds (#3431)
This commit is contained in:
32
.github/workflows/base.yml
vendored
32
.github/workflows/base.yml
vendored
@@ -59,14 +59,14 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "129"
|
# - cuda: "129"
|
||||||
cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
cudnn_version: ""
|
# cudnn_version: ""
|
||||||
python_version: "3.12"
|
# python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
# dockerfile: "Dockerfile-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
# platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "130"
|
- cuda: "130"
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -181,14 +181,14 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "129"
|
# - cuda: "129"
|
||||||
cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
cudnn_version: ""
|
# cudnn_version: ""
|
||||||
python_version: "3.12"
|
# python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
# dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
# platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "130"
|
- cuda: "130"
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
162
.github/workflows/main.yml
vendored
162
.github/workflows/main.yml
vendored
@@ -40,12 +40,12 @@ jobs:
|
|||||||
pytorch: 2.10.0
|
pytorch: 2.10.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 129
|
# - cuda: 129
|
||||||
cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
python_version: "3.12"
|
# python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
axolotl_extras:
|
# axolotl_extras:
|
||||||
platforms: "linux/amd64,linux/arm64"
|
# platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -98,6 +98,77 @@ jobs:
|
|||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
|
build-axolotl-uv:
|
||||||
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
is_latest: true
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
runs-on: axolotl-gpu-runner
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Docker metadata
|
||||||
|
id: metadata
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: |
|
||||||
|
axolotlai/axolotl-uv
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=pep440,pattern={{version}}
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||||
|
- name: Build and export to Docker
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
|
build-args: |
|
||||||
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
|
CUDA=${{ matrix.cuda }}
|
||||||
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
|
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||||
|
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||||
|
file: ./docker/Dockerfile-uv
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
tags: |
|
||||||
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
build-axolotl-cloud:
|
build-axolotl-cloud:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
@@ -130,12 +201,12 @@ jobs:
|
|||||||
pytorch: 2.10.0
|
pytorch: 2.10.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 129
|
# - cuda: 129
|
||||||
cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
python_version: "3.12"
|
# python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
axolotl_extras:
|
# axolotl_extras:
|
||||||
platforms: "linux/amd64,linux/arm64"
|
# platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -183,6 +254,73 @@ jobs:
|
|||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
|
build-axolotl-cloud-uv:
|
||||||
|
needs: build-axolotl-uv
|
||||||
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.12"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
runs-on: axolotl-gpu-runner
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Docker metadata
|
||||||
|
id: metadata
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: |
|
||||||
|
axolotlai/axolotl-cloud-uv
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=pep440,pattern={{version}}
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Build
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
|
build-args: |
|
||||||
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
CUDA=${{ matrix.cuda }}
|
||||||
|
file: ./docker/Dockerfile-cloud-uv
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
tags: |
|
||||||
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
build-axolotl-cloud-no-tmux:
|
build-axolotl-cloud-no-tmux:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -264,8 +264,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 129
|
- cuda: 130
|
||||||
cuda_version: 12.9.1
|
cuda_version: 13.0.0
|
||||||
python_version: "3.12"
|
python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
|||||||
@@ -59,34 +59,18 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
RUN case "$PYTORCH_VERSION" in \
|
# Map Python version (e.g., 3.12 -> cp312)
|
||||||
2.9.[0-9]*) \
|
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||||
if [ "$CUDA" = "128" ]; then \
|
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
# Map architecture
|
||||||
WHL_VERSION="v0.5.4"; \
|
case "$TARGETARCH" in \
|
||||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
amd64) ARCH_TAG="x86_64" ;; \
|
||||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
arm64) ARCH_TAG="aarch64" ;; \
|
||||||
WHL_VERSION="v0.6.4"; \
|
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||||
else \
|
esac && \
|
||||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
WHL_VERSION="v0.7.16" && \
|
||||||
fi; \
|
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||||
rm ${WHL_FILE}; \
|
rm "${WHL_FILE}"
|
||||||
elif [ "$CUDA" = "130" ]; then \
|
|
||||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
|
||||||
WHL_VERSION="v0.5.4"; \
|
|
||||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
|
||||||
WHL_VERSION="v0.6.4"; \
|
|
||||||
else \
|
|
||||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
|
||||||
fi; \
|
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
|
||||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
|
||||||
rm ${WHL_FILE}; \
|
|
||||||
fi \
|
|
||||||
;; \
|
|
||||||
esac
|
|
||||||
|
|||||||
30
docker/Dockerfile-cloud-uv
Normal file
30
docker/Dockerfile-cloud-uv
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
ARG BASE_TAG=main
|
||||||
|
FROM axolotlai/axolotl-uv:$BASE_TAG
|
||||||
|
|
||||||
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
|
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
EXPOSE 8888
|
||||||
|
EXPOSE 22
|
||||||
|
|
||||||
|
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||||
|
COPY scripts/motd /etc/motd
|
||||||
|
|
||||||
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
|
jupyter lab clean
|
||||||
|
RUN apt update && \
|
||||||
|
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||||
|
rm -rf /var/cache/apt/archives && \
|
||||||
|
rm -rf /var/lib/apt/lists/* && \
|
||||||
|
mkdir -p ~/.ssh && \
|
||||||
|
chmod 700 ~/.ssh && \
|
||||||
|
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||||
|
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||||
|
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||||
|
chmod +x /root/cloud-entrypoint.sh && \
|
||||||
|
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||||
|
|
||||||
|
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||||
|
CMD ["sleep", "infinity"]
|
||||||
47
docker/Dockerfile-uv
Normal file
47
docker/Dockerfile-uv
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
ARG BASE_TAG=main-base
|
||||||
|
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
||||||
|
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
ARG AXOLOTL_ARGS=""
|
||||||
|
ARG CUDA="118"
|
||||||
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
|
ARG TARGETARCH
|
||||||
|
|
||||||
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||||
|
rm -rf /var/cache/apt/archives && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
|
||||||
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
|
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
|
else \
|
||||||
|
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
|
fi && \
|
||||||
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
else \
|
||||||
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
fi && \
|
||||||
|
python scripts/unsloth_install.py --uv | sh && \
|
||||||
|
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||||
|
uv pip install pytest && \
|
||||||
|
uv pip cache purge
|
||||||
|
|
||||||
|
# fix so that git fetch/pull from remote works with shallow clone
|
||||||
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
git config --get remote.origin.fetch && \
|
||||||
|
git config --global credential.helper store
|
||||||
|
|
||||||
|
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||||
|
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||||
|
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||||
@@ -6,6 +6,7 @@ ARG TARGETARCH
|
|||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
|
ARG TARGETARCH
|
||||||
ARG PYTHON_VERSION="3.11"
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="2.6.0"
|
ARG PYTORCH_VERSION="2.6.0"
|
||||||
ARG CUDA="126"
|
ARG CUDA="126"
|
||||||
@@ -39,28 +40,18 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
|||||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN case "$PYTORCH_VERSION" in \
|
# Map Python version (e.g., 3.12 -> cp312)
|
||||||
2.9.[0-9]*) \
|
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||||
if [ "$CUDA" = "128" ]; then \
|
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
# Map architecture
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
case "$TARGETARCH" in \
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
amd64) ARCH_TAG="x86_64" ;; \
|
||||||
elif [ "$CUDA" = "130" ]; then \
|
arm64) ARCH_TAG="aarch64" ;; \
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
esac && \
|
||||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
WHL_VERSION="v0.7.16" && \
|
||||||
fi \
|
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||||
if [ "$CUDA" = "128" ]; then \
|
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
rm "${WHL_FILE}"
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
|
||||||
elif [ "$CUDA" = "130" ]; then \
|
|
||||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
|
||||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
|
||||||
fi \
|
|
||||||
fi \
|
|
||||||
;; \
|
|
||||||
esac
|
|
||||||
|
|||||||
Reference in New Issue
Block a user