Compare commits

..

2 Commits

Author      SHA1         Message                                             Date
Wing Lian   208f8b253f   add validation for DFT                              2026-01-13 09:33:04 -05:00
Wing Lian   75ad1a9932   use dynamic finetuning with chunked cross entropy   2026-01-13 09:33:04 -05:00
49 changed files with 186 additions and 1007 deletions

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how --> <!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. --> <!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate) ## Screenshots (if appropriate)
## Types of changes ## Types of changes

View File

@@ -21,8 +21,6 @@ jobs:
timeout-minutes: 480 timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@@ -34,7 +32,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -42,7 +39,6 @@ jobs:
pytorch: 2.9.0 pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -50,15 +46,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -66,15 +53,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -101,7 +79,7 @@ jobs:
axolotlai/axolotl-base axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }} if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -112,7 +90,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
@@ -127,8 +105,6 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480 timeout-minutes: 480
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@@ -140,7 +116,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -148,7 +123,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -156,15 +130,6 @@ jobs:
pytorch: 2.9.0 pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -172,15 +137,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -192,7 +148,6 @@ jobs:
axolotlai/axolotl-base-uv axolotlai/axolotl-base-uv
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -203,7 +158,6 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,32 +20,22 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64" is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: vllm
platforms: "linux/amd64,linux/arm64"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -71,7 +61,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -98,32 +88,22 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64" is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -148,7 +128,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -169,11 +149,11 @@ jobs:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
is_latest: true is_latest:
- cuda: 130 - cuda: 128
cuda_version: 13.0.0 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:

View File

@@ -35,26 +35,21 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: fbgemm-gpu axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu" axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
- cuda: 129 nightly_build: "true"
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu,vllm"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras: fbgemm-gpu
# axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
@@ -76,8 +71,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run -m cicd.multigpu modal run -m cicd.multigpu

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install wheel packaging==26.0 pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e . pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -48,9 +48,9 @@ jobs:
id: tag id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3) run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in VERSION file - name: Update version in setup.py
run: | run: |
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a source dist - name: Build a source dist
run: | run: |

View File

@@ -48,7 +48,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |

View File

@@ -54,13 +54,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11", "3.12"] python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"] pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -87,7 +82,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -149,13 +144,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11", "3.12"] python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"] pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -182,7 +172,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -264,12 +254,12 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 129 - cuda: 128
cuda_version: 12.9.1 cuda_version: 12.8.1
python_version: "3.12" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.8.0
num_gpus: 1 num_gpus: 1
axolotl_extras: vllm axolotl_extras:
dockerfile: "Dockerfile-uv.jinja" dockerfile: "Dockerfile-uv.jinja"
steps: steps:
- name: Checkout - name: Checkout
@@ -369,9 +359,9 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 129 - cuda: 128
cuda_version: 12.9.1 cuda_version: 12.8.1
python_version: "3.12" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:

View File

@@ -88,7 +88,7 @@ Features:
#### Using pip #### Using pip
```bash ```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs # Download example axolotl configs, deepspeed configs

View File

@@ -1 +0,0 @@
0.13.2

View File

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \ sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi fi
RUN uv pip install packaging==26.0 setuptools==75.8.0 RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \ sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi fi
RUN pip install packaging==26.0 setuptools==75.8.0 psutil RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \

View File

@@ -17,8 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment( template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape() loader=template_loader, autoescape=select_autoescape()
) )
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja") df_template = template_env.get_template("Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = { df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""), "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -28,11 +27,8 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"), "CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
} }
dockerfile_contents = df_template.render(**df_args) dockerfile_contents = df_template.render(**df_args)

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS="" ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64 # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$TARGETARCH" = "arm64" ]; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \ fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \ python scripts/unsloth_install.py | sh && \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \ python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \ pip install pytest && \
pip cache purge pip cache purge

View File

@@ -43,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \ RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip cache purge python3 -m pip cache purge

View File

@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \ RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \ python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \ python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \

View File

@@ -2,7 +2,6 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION="" ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -32,35 +31,20 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \ && uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
RUN case "$PYTORCH_VERSION" in \ RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \ 2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \ if [ "$CUDA" = "128" ]; then \
if [ "$CUDA" = "128" ]; then \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ elif [ "$CUDA" = "130" ]; then \
elif [ "$CUDA" = "130" ]; then \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \ fi \
;; \ ;; \
esac esac

View File

@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto) - [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo) - [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl ## RLHF using Axolotl
@@ -721,102 +720,6 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types). For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing the rewards from each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. With a single reward, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
rl: gdpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: true
num_generations: 4
reward_funcs:
- rewards.format_reward
- rewards.correctness_reward
reward_weights: [1.0, 2.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
multi_objective_aggregation: normalize_then_sum # GDPO behavior
# or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
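The aggregation difference can be sketched numerically. This is a minimal sketch, not TRL's implementation: the group statistic is a plain z-score with a small epsilon, and the reward values below are made up so that the difference is visible (the two reward functions live on very different scales).
```python
import torch

rewards = torch.tensor([          # one row per completion: [format, correctness]
    [1.0,   0.0],
    [0.0,  10.0],
    [1.0,  10.0],
    [0.0,   0.0],
])
weights = torch.tensor([1.0, 1.0])

def group_normalize(x: torch.Tensor) -> torch.Tensor:
    # z-score within the generation group; epsilon guards against zero std
    return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-4)

# sum_then_normalize (GRPO default): weighted sum per completion, then normalize the totals
grpo_adv = group_normalize((rewards * weights).sum(dim=-1))

# normalize_then_sum (GDPO): normalize each reward column first, then take the weighted sum
gdpo_adv = (group_normalize(rewards) * weights).sum(dim=-1)

print(grpo_adv)   # dominated by the large-scale correctness reward
print(gdpo_adv)   # format and correctness contribute on comparable scales
```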
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(completions, answers, **kwargs) -> list[float]:
rewards = []
for completion, answer in zip(completions, answers):
# Your scoring logic here
rewards.append(score)
return rewards
```
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO ### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function. SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]' pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]' pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash ```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min) # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-1b-it base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-270m-it base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too # Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -33,8 +32,8 @@ sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:
wandb_entity: wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:

View File

@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash ```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min) # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```

View File

@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash ```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min) # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]' pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]' pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -1,68 +0,0 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3
rl: gdpo
trl:
beta: 0.001
max_completion_length: 128
num_generations: 2
temperature: 0.7
top_p: 0.95
use_vllm: false
multi_objective_aggregation: normalize_then_sum
reward_funcs:
- rwd.format_reward
- rwd.correctness_reward
reward_weights: [1.0, 2.0]
log_completions: true
num_completions_to_print: 3
scale_rewards: true
datasets:
- path: openai/gsm8k
name: main
split: train[:1000]
type: rwd.gsm8k_transform
val_set_size: 0.0
output_dir: ./outputs/llama3-gdpo-out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
flash_attention: true
logging_steps: 1
save_steps: 50
save_safetensors: true
special_tokens:
pad_token: "<|end_of_text|>"
seed: 42

View File

@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash ```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min) # Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]' pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash ```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min) # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```

View File

@@ -1,5 +1,5 @@
[build-system] [build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"] requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
@@ -24,9 +24,6 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
py-modules = ["setuptools_axolotl_dynamic_dependencies"] py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass] [tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand" build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"

View File

@@ -8,16 +8,16 @@ xformers>=0.0.23.post1
liger-kernel==0.6.4 liger-kernel==0.6.4
# END section # END section
packaging==26.0 packaging==23.2
huggingface_hub>=0.36.0 huggingface_hub>=0.36.0
peft>=0.18.1 peft>=0.18.0
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==4.57.6 transformers==4.57.1
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.4.2
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.27.0 trl==0.25.1
hf_xet==1.2.0 hf_xet==1.2.0
kernels==0.11.5 kernels==0.11.5
trackio>=0.13.0 trackio>=0.13.0
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6
# telemetry # telemetry
posthog==6.7.11 posthog==6.7.11
mistral-common==1.8.8 mistral-common==1.8.6

View File

@@ -1,5 +1,6 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
import ast
import os import os
import platform import platform
import re import re
@@ -25,7 +26,6 @@ def parse_requirements(extras_require_map):
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# skip packages not compatible with OSX # skip packages not compatible with OSX
skip_packages = [ skip_packages = [
@@ -62,68 +62,44 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9): if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu") extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [ extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"] extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8): elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu") extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"] extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7): elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
if install_xformers: _install_requires.append("xformers==0.0.30")
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers # vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
else: else:
if install_xformers: _install_requires.append("xformers==0.0.31")
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"] extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if install_xformers: _install_requires.append("xformers==0.0.29.post3")
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126 # since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126") _dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if install_xformers: if patch == 0:
if patch == 0: _install_requires.append("xformers==0.0.28.post2")
_install_requires.append("xformers==0.0.28.post2") else:
else: _install_requires.append("xformers>=0.0.28.post3")
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4): elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
if install_xformers: if patch == 0:
if patch == 0: _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27")
_install_requires.append("xformers>=0.0.27") else:
else: _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers==0.0.28.post1")
_install_requires.append("xformers==0.0.28.post1")
else: else:
raise ValueError("axolotl requires torch>=2.4") raise ValueError("axolotl requires torch>=2.4")
@@ -134,11 +110,15 @@ def parse_requirements(extras_require_map):
def get_package_version(): def get_package_version():
with open( with open(
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION", Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r", "r",
encoding="utf-8", encoding="utf-8",
) as fin: ) as fin:
version_ = fin.read().strip() version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_ return version_

View File

@@ -1,11 +1,7 @@
"""Axolotl - Train and fine-tune large language models""" """Axolotl - Train and fine-tune large language models"""
import pkgutil import pkgutil
from importlib.metadata import PackageNotFoundError, version
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package __path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
try: __version__ = "0.13.0.dev"
__version__ = version("axolotl")
except PackageNotFoundError:
__version__ = "unknown"

View File

@@ -52,11 +52,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = None trainer_cls = None
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]: elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
@@ -146,8 +147,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
@@ -156,14 +155,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0 self.cfg.kto_undesirable_weight or 1.0
) )
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"
)
elif self.cfg.rl in [RLType.DPO, RLType.IPO]: elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig training_args_cls = AxolotlDPOConfig

View File

@@ -129,11 +129,6 @@ class GRPOStrategy:
if trl.rollout_func: if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func) grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod

View File

@@ -153,9 +153,12 @@ class PatchManager:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks: if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) patch_chunked_ce_loss_fn(
self.cfg.chunked_cross_entropy_num_chunks,
use_dft=self.cfg.use_dynamic_finetuning,
)
else: else:
patch_chunked_ce_loss_fn() patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
def _apply_fsdp_patches(self): def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations.""" """Apply patches for FSDP configurations."""

View File

@@ -5,7 +5,6 @@ from typing import Type
import addict import addict
import torch import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config. necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels: if cfg.num_labels:
# num_labels is used to initialize classifier models # num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try: try:
model_config = config_cls.from_pretrained( model_config = AutoConfig.from_pretrained(
model_config_name, model_config_name,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**config_kwargs, **config_kwargs,

View File

@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
""" """
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): def __init__(
self,
num_output_chunks: int = 8,
ignore_index: int = -100,
use_dft: bool = False,
):
super().__init__() super().__init__()
self.num_output_chunks = num_output_chunks self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.use_dft = use_dft
def compute_cross_entropy( def compute_cross_entropy(
self, self,
@@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
""" """
Upcast logits to fp32 and compute cross entropy loss. Upcast logits to fp32 and compute cross entropy loss.
""" """
return F.cross_entropy( ce_loss = F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
) )
if self.use_dft:
# Compute probabilities and gather the ones corresponding to labels
with torch.no_grad(): # Stop gradient
probs = torch.softmax(logits.float(), dim=-1)
# Create mask for valid tokens (not ignore_index)
valid_mask = labels != self.ignore_index
# Gather probabilities for the correct tokens
label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Apply mask to only scale valid tokens
label_probs = label_probs * valid_mask
# Avoid multiplication by 0 for ignored tokens
label_probs = torch.where(
valid_mask, label_probs, torch.ones_like(label_probs)
)
# Scale the loss by the probability (DFT)
ce_loss = ce_loss * label_probs
return ce_loss.sum()
def forward( def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor: ) -> torch.Tensor:
@@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
return total_loss / total_elements return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): def _build_chunked_ce_loss_fn(
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
loss_fn_ce.compute_cross_entropy = torch.compile( loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor" loss_fn_ce.compute_cross_entropy, backend="inductor"
) )
return loss_fn_ce return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): def get_causal_lm_loss(
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index) num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)
def chunked_fix_cross_entropy( def chunked_fix_cross_entropy(
source, source,
@@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
return for_causal_lm_chunked_loss return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): def patch_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
import transformers.loss.loss_utils import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index) for_causal_lm_chunked_loss = get_causal_lm_loss(
num_output_chunks, ignore_index, use_dft
)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss for_causal_lm_chunked_loss
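A self-contained sketch of the DFT-scaled token loss patched in above: each token's cross entropy is multiplied by the model's detached probability of the gold token, so the scaling acts as a weight rather than a gradient path. The helper name, the clamp on ignored labels, and the tiny smoke test are illustrative, not part of the repository.
```python
import torch
import torch.nn.functional as F

def dft_cross_entropy(logits: torch.Tensor, labels: torch.Tensor,
                      ignore_index: int = -100) -> torch.Tensor:
    """Sum of per-token CE, each scaled by the detached probability of its label."""
    ce = F.cross_entropy(
        logits.float(), labels, ignore_index=ignore_index, reduction="none"
    )
    with torch.no_grad():  # stop-gradient on the scaling factor
        probs = torch.softmax(logits.float(), dim=-1)
        valid = labels != ignore_index
        safe_labels = labels.clamp_min(0)  # keep gather indices in range for ignored tokens
        label_probs = probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
        label_probs = torch.where(valid, label_probs, torch.ones_like(label_probs))
    return (ce * label_probs).sum()

# tiny smoke test with random logits over a 10-token vocab
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.tensor([1, 3, -100, 7])
loss = dft_cross_entropy(logits, labels)
loss.backward()
```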

View File

@@ -173,7 +173,7 @@ def _drop_long_sequences(
return (len_prompt + len_completion) <= sequence_len return (len_prompt + len_completion) <= sequence_len
if rl in {RLType.GRPO, RLType.GDPO}: if rl is RLType.GRPO:
return True return True
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")

View File

@@ -664,6 +664,13 @@ class AxolotlInputConfig(
}, },
) )
use_dynamic_finetuning: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
},
)
chunked_cross_entropy: bool | None = Field( chunked_cross_entropy: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
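A minimal config sketch, assuming a placeholder model and dataset, showing how the new flag is meant to be paired with chunked cross entropy (the DFT validator further down in this compare rejects one without the other):
```yaml
# illustrative values; only the two loss flags are the point here
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
chunked_cross_entropy: true        # DFT requires the chunked CE path
use_dynamic_finetuning: true       # per-token loss scaled by detached label probability
sequence_len: 2048
micro_batch_size: 1
num_epochs: 1
learning_rate: 2e-5
output_dir: ./outputs/dft-sft
```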

View File

@@ -26,7 +26,6 @@ class RLType(str, Enum):
"""RL trainer type configuration subset""" """RL trainer type configuration subset"""
DPO = "dpo" DPO = "dpo"
GDPO = "gdpo"
GRPO = "grpo" GRPO = "grpo"
IPO = "ipo" IPO = "ipo"
ORPO = "orpo" ORPO = "orpo"

View File

@@ -25,12 +25,7 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model" "description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
}, },
) )
cls_model_config: str | None = Field( cls_model_config: str | None = None
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
tokenizer_config: str | None = Field( tokenizer_config: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -179,13 +179,3 @@ class TRLConfig(BaseModel):
"description": "Path to custom rollout function. Must be importable from current dir." "description": "Path to custom rollout function. Must be importable from current dir."
}, },
) )
multi_objective_aggregation: (
Literal["sum_then_normalize", "normalize_then_sum"] | None
) = Field(
default=None,
json_schema_extra={
"description": "Multi-objective reward aggregation strategy. "
"'sum_then_normalize' (GRPO default): weights and sums rewards first, then normalizes. "
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)

View File

@@ -434,6 +434,18 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_ao_optim_fsdp2_offload(cls, data):
if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
"offload_params"
):
if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
raise ValueError(
"low bit ao optimizers is not supported with FSDP2 w/ offload_params."
)
return data
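As an illustration (hypothetical config, not from the diff), the new check rejects pairing FSDP2 parameter offload with the low-bit torchao optimizers:

bad_cfg = {
    "fsdp_config": {"offload_params": True},
    "optimizer": "adamw_torch_8bit",
}
# -> ValueError: Low-bit ao optimizers are not supported with FSDP2 w/ offload_params.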
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_use_reentrant_mismatch(cls, data): def check_use_reentrant_mismatch(cls, data):
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
return data
class CELossValidationMixin:
"""Validation methods related to CE loss configuration."""
@model_validator(mode="before")
@classmethod
def check_dft_loss_fn(cls, data):
if data.get("use_dynamic_finetuning"):
if not data.get("chunked_cross_entropy"):
raise ValueError(
"`use_dynamic_finetuning` requires `chunked_cross_entropy`"
)
return data
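And the analogous illustration for the DFT check (again a hypothetical raw config):

bad_cfg = {"use_dynamic_finetuning": True}  # chunked_cross_entropy missing
# -> ValueError: `use_dynamic_finetuning` requires `chunked_cross_entropy`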
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -746,19 +772,6 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_gdpo(cls, data):
if (
data.get("rl") == "gdpo"
and data.get("trl", {}).get("multi_objective_aggregation")
== "sum_then_normalize"
):
raise ValueError(
"`multi_objective_aggregation` value set as `sum_then_normalize` => GRPO, but GDPO was selected"
)
return data
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""
@@ -1477,6 +1490,7 @@ class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
TrainingValidationMixin,
CELossValidationMixin,
LoRAValidationMixin,
RLValidationMixin,
OptimizationValidationMixin,

View File

@@ -311,6 +311,7 @@ class TestHFRLTrainerBuilder:
# KTO specific
assert training_arguments.desirable_weight == 1.0
assert training_arguments.undesirable_weight == 1.0
assert training_arguments.max_prompt_length == 512
def _write_rewards_file(self, rewards_dir: Path):
"""

View File

@@ -1,538 +0,0 @@
"""
GDPO test suite
GDPO uses TRL's multi_objective_aggregation="normalize_then_sum" for
per-reward normalization in multi-reward RL training.
"""
import os
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.multigpu.solo.test_grpo import recursive_kill, start_vllm
from tests.e2e.utils import require_vllm
@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGDPO:
"""Test case for GDPO training using TRL's native multi-objective aggregation."""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_gdpo_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def format_reward(prompts, completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(prompts, completions, **kwargs) -> list[float]:
return [random.uniform(-1, 3) for _ in completions]
def safety_reward(prompts, completions, **kwargs) -> list[float]:
return [1.0 if 'error' not in c.lower() else 0.0 for c in completions]
def single_reward(prompts, completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize("num_gpus", [1, 2])
@require_vllm
def test_gdpo_multi_reward_lora(self, temp_dir, num_gpus):
"""Test GDPO with multiple reward functions using LoRA."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
"scale_rewards": True,
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_three_rewards(self, temp_dir):
"""Test GDPO with three reward functions (format, correctness, safety)."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
f"rewards_gdpo_{rnd_suffix}.safety_reward",
],
"reward_weights": [1.0, 2.0, 1.5],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_single_reward_fallback(self, temp_dir):
"""Test GDPO with single reward."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.single_reward",
],
"reward_weights": [1.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_fft(self, temp_dir):
"""Test GDPO with full fine-tuning (no adapter)."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
# No adapter - full fine-tuning
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@require_vllm
def test_gdpo_sequence_parallel(self, temp_dir):
"""Test GDPO with sequence parallelism."""
rnd_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "gdpo",
"context_parallel_size": 2,
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
f"rewards_gdpo_{rnd_suffix}.format_reward",
f"rewards_gdpo_{rnd_suffix}.correctness_reward",
],
"reward_weights": [1.0, 2.0],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_gdpo_{rnd_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)