diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index 260215ef6..dd294ecd9 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -21,6 +21,8 @@ jobs:
     timeout-minutes: 480
     # this job needs to be run on self-hosted GPU runners...
     runs-on: ubuntu-latest-m
+    env:
+      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
     strategy:
       fail-fast: false
       matrix:
@@ -32,6 +34,7 @@ jobs:
             pytorch: 2.8.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -39,6 +42,7 @@ jobs:
             pytorch: 2.9.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -46,6 +50,7 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "130"
             cuda_version: 13.0.0
             cudnn_version: ""
@@ -53,6 +58,7 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "9.0+PTX"
             dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64,linux/arm64"
           # - cuda: "128"
           #   cuda_version: 12.8.1
           #   cudnn_version: ""
@@ -79,7 +85,7 @@ jobs:
             axolotlai/axolotl-base
       - name: Login to Docker Hub
         uses: docker/login-action@v2
-        if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
+        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -90,7 +96,7 @@ jobs:
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
-          platforms: linux/amd64,linux/arm64
+          platforms: ${{ matrix.platforms }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
@@ -105,6 +111,8 @@ jobs:
     if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
     timeout-minutes: 480
     runs-on: ubuntu-latest-m
+    env:
+      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
     strategy:
       fail-fast: false
       matrix:
@@ -116,6 +124,7 @@ jobs:
             pytorch: 2.8.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -123,6 +132,7 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -130,6 +140,7 @@ jobs:
             pytorch: 2.9.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "130"
             cuda_version: 13.0.0
             cudnn_version: ""
@@ -137,6 +148,7 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64,linux/arm64"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -148,6 +160,7 @@ jobs:
             axolotlai/axolotl-base-uv
       - name: Login to Docker Hub
         uses: docker/login-action@v2
+        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -158,6 +171,7 @@ jobs:
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
+          platforms: ${{ matrix.platforms }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
diff --git a/docker/Dockerfile-uv-base b/docker/Dockerfile-uv-base
index 1b54c05e6..d28b27ad2 100644
--- a/docker/Dockerfile-uv-base
+++ b/docker/Dockerfile-uv-base
@@ -2,6 +2,7 @@ ARG CUDA_VERSION="12.6.3"
 ARG CUDNN_VERSION=""
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4
+ARG TARGETARCH
 
 FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
 
@@ -31,20 +32,35 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
 
 RUN uv pip install packaging setuptools wheel psutil \
     && uv pip install torch==${PYTORCH_VERSION} torchvision \
-    && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
-    && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
     && uv pip install awscli pydantic
 
+RUN if [ "$TARGETARCH" = "amd64" ]; then \
+        uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
+        uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
+    fi
+
 RUN case "$PYTORCH_VERSION" in \
     2.9.[0-9]*) \
-        if [ "$CUDA" = "128" ]; then \
-            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-            uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-            rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-        elif [ "$CUDA" = "130" ]; then \
-            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
-            uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
-            rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+        if [ "$TARGETARCH" = "amd64" ]; then \
+            if [ "$CUDA" = "128" ]; then \
+                wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+                uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+                rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+            elif [ "$CUDA" = "130" ]; then \
+                wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+                uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+                rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+            fi; \
+        elif [ "$TARGETARCH" = "arm64" ]; then \
+            if [ "$CUDA" = "128" ]; then \
+                wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
+                uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
+                rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
+            elif [ "$CUDA" = "130" ]; then \
+                wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
+                uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
+                rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
+            fi; \
         fi \
     ;; \
     esac
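For reference on the mechanism the Dockerfile changes rely on: in a `docker buildx` multi-platform build, BuildKit automatically supplies `TARGETARCH` (`amd64`, `arm64`, ...) as a build argument wherever `ARG TARGETARCH` is declared, which is what lets the per-architecture branches above pick the matching prebuilt flash-attention wheel without any explicit `--build-arg`. Below is a minimal local sketch, assuming a configured buildx builder; `example.com/axolotl-base-test` is a purely illustrative tag (multi-platform output must be pushed to a registry rather than loaded into the local image store), and CUDA/PYTORCH_VERSION are the build args the Dockerfile already consumes.

# Build both platforms in one invocation; BuildKit sets TARGETARCH per platform.
docker buildx build \
  --platform linux/amd64,linux/arm64 \
  --build-arg CUDA=128 \
  --build-arg PYTORCH_VERSION=2.9.1 \
  -f docker/Dockerfile-uv-base \
  -t example.com/axolotl-base-test \
  --push .

# Confirm the pushed tag is a manifest list covering both architectures.
docker buildx imagetools inspect example.com/axolotl-base-test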