From 684b543aa1899d2115a30d9c9f728d348743bb7b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 6 Dec 2024 11:07:27 -0500
Subject: [PATCH] experiment with nvcr pytorch image for torch 2.5.1

---
Review notes (free-form text in this position is ignored by `git am`):

* Purpose: let each matrix entry of the base-image workflow choose the
  image that docker/Dockerfile-base builds FROM. Entries with empty
  from_base_img / from_base_tag are meant to keep the previous
  nvidia/cuda:<cuda>-cudnn<cudnn>-devel-ubuntu<ver> base; the torch 2.5.1
  entry selects nvcr.io/nvidia/pytorch:24.10-py3.

* Fix relative to the first revision of this patch: the workflow passes
  BASE_IMAGE= (empty) for every entry whose from_base_img is "". A build
  arg that is passed explicitly, even when empty, replaces the ARG
  default, so `FROM $BASE_IMAGE:...` expanded to `:<tag>`, which is not a
  valid image reference. docker/Dockerfile-base now uses
  `${BASE_IMAGE:-nvidia/cuda}`, the same fallback form that BASE_TAG
  already used.

* The new matrix values are quoted so they stay strings, matching the
  other quoted values in the matrix.

* NOTE(review): this patch was recovered from a whitespace-collapsed copy.
  The indentation of YAML context lines and the positions of blank context
  lines were reconstructed from the hunk line counts. Apply with
  `git apply --ignore-whitespace`, or regenerate it from the branch.
  The `index` hashes are kept from the original mail and no longer
  describe the edited docker/Dockerfile-base and base.yml blobs.

* NOTE(review): confirm that the CUDA/Python toolchain shipped in
  nvcr.io/nvidia/pytorch:24.10-py3 is compatible with the cuda: "124" /
  cuda_version: 12.4.1 / python_version: "3.11" labels of that entry.

 .github/workflows/base.yml | 16 ++++++++++------
 docker/Dockerfile          |  3 ++-
 docker/Dockerfile-base     |  4 +++-
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index 640d2cd7a..ea3db0e2d 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -28,30 +28,32 @@ jobs:
             python_version: "3.10"
             pytorch: 2.3.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            from_base_img: ""
+            from_base_tag: ""
           - cuda: "121"
             cuda_version: 12.1.1
             cudnn_version: 8
             python_version: "3.11"
             pytorch: 2.3.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "124"
-            cuda_version: 12.4.1
-            cudnn_version: ""
-            python_version: "3.10"
-            pytorch: 2.4.1
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            from_base_img: ""
+            from_base_tag: ""
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.4.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            from_base_img: ""
+            from_base_tag: ""
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.5.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            from_base_img: "nvcr.io/nvidia/pytorch"
+            from_base_tag: "24.10-py3"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -84,3 +86,5 @@ jobs:
             PYTHON_VERSION=${{ matrix.python_version }}
             PYTORCH_VERSION=${{ matrix.pytorch }}
             TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
+            BASE_IMAGE=${{ matrix.from_base_img || '' }}
+            BASE_TAG=${{ matrix.from_base_tag || '' }}
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 261f0e12a..658a81943 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,5 +1,6 @@
+ARG BASE_IMAGE=axolotlai/axolotl-base
 ARG BASE_TAG=main-base
-FROM axolotlai/axolotl-base:$BASE_TAG
+FROM $BASE_IMAGE:$BASE_TAG
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 7eab3b3e4..b2308b02a 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -3,7 +3,9 @@ ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4
 
-FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
+ARG BASE_IMAGE=nvidia/cuda
+ARG BASE_TAG=""
+FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION} AS base-builder
 
 ENV PATH="/root/miniconda3/bin:${PATH}"
 