From 3e0bbd33ece5cb03abf18b8258ec1b6af8468e02 Mon Sep 17 00:00:00 2001 From: "@TT" <1sand0s@users.noreply.github.com> Date: Mon, 12 Jan 2026 11:00:02 -0600 Subject: [PATCH] feat: add ARM64/AArch64 build support to Dockerfile-base (#3346) * Add support for capability to build arm64 image * Fixing wrong variable TARGETPLATFORM bug * Adding missing semicolons * skip docker hub login if PR (no push) or no credentials * Enabling arm64 builds for Dockerfile-base in Github actions * TARGETARCH automatically default to platform arch under build * Enabling arm64 builds for axolotl docker builds * Enabling arm64 builds for axolotl-cloud docker build Github actions --------- Co-authored-by: Wing Lian --- .github/workflows/base.yml | 2 ++ .github/workflows/main.yml | 3 +++ docker/Dockerfile-base | 46 +++++++++++++++++++++++++++++--------- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index ea721bff4..260215ef6 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -79,6 +79,7 @@ jobs: axolotlai/axolotl-base - name: Login to Docker Hub uses: docker/login-action@v2 + if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }} with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} @@ -89,6 +90,7 @@ jobs: with: context: . file: ./docker/${{ matrix.dockerfile }} + platforms: linux/amd64,linux/arm64 push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} labels: ${{ steps.metadata.outputs.labels }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 052f9aa72..19cef5de4 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -61,6 +61,7 @@ jobs: uses: docker/build-push-action@v5 with: context: . + platforms: linux/amd64,linux/arm64 build-args: | BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} CUDA=${{ matrix.cuda }} @@ -127,6 +128,7 @@ jobs: uses: docker/build-push-action@v5 with: context: . + platforms: linux/amd64,linux/arm64 build-args: | BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} CUDA=${{ matrix.cuda }} @@ -180,6 +182,7 @@ jobs: uses: docker/build-push-action@v5 with: context: . + platforms: linux/amd64,linux/arm64 build-args: | BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} CUDA=${{ matrix.cuda }} diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index e080758c2..96367207f 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -2,11 +2,13 @@ ARG CUDA_VERSION="11.8.0" ARG CUDNN_VERSION="8" ARG UBUNTU_VERSION="22.04" ARG MAX_JOBS=4 +ARG TARGETARCH FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder ENV PATH="/root/miniconda3/bin:${PATH}" +ARG TARGETARCH ARG PYTHON_VERSION="3.11" ARG PYTORCH_VERSION="2.1.2" ARG CUDA="128" @@ -22,11 +24,17 @@ RUN apt-get update \ librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \ && rm -rf /var/cache/apt/archives \ && rm -rf /var/lib/apt/lists/* \ - && wget \ - https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && if [ "$TARGETARCH" = "amd64" ]; then \ + MINICONDA_ARCH="x86_64"; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + MINICONDA_ARCH="aarch64"; \ + else \ + echo "Unsupported architecture: $TARGETARCH"; exit 1; \ + fi \ + && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \ && mkdir /root/.conda \ - && bash Miniconda3-latest-Linux-x86_64.sh -b \ - && rm -f Miniconda3-latest-Linux-x86_64.sh \ + && bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \ + && rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" @@ -54,13 +62,31 @@ RUN git lfs install --skip-repo && \ RUN case "$PYTORCH_VERSION" in \ 2.9.[0-9]*) \ if [ "$CUDA" = "128" ]; then \ - wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ - pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ - rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ + if [ "$TARGETARCH" = "amd64" ]; then \ + WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \ + WHL_VERSION="v0.5.4"; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \ + WHL_VERSION="v0.6.4"; \ + else \ + echo "Unsupported architecture: $TARGETARCH"; exit 1; \ + fi; \ + wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \ + pip3 install --no-cache-dir ${WHL_FILE}; \ + rm ${WHL_FILE}; \ elif [ "$CUDA" = "130" ]; then \ - wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ - pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ - rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ + if [ "$TARGETARCH" = "amd64" ]; then \ + WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \ + WHL_VERSION="v0.5.4"; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \ + WHL_VERSION="v0.6.4"; \ + else \ + echo "Unsupported architecture: $TARGETARCH"; exit 1; \ + fi; \ + wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \ + pip3 install --no-cache-dir ${WHL_FILE}; \ + rm ${WHL_FILE}; \ fi \ ;; \ esac