Compare commits


2 Commits

Author      SHA1         Message                   Date
Wing Lian   d260eeb57d   match protected method    2026-02-15 07:55:55 -05:00
Wing Lian   5a7f007d20   cleanup ao fp8 patching   2026-02-13 17:02:23 -05:00
72 changed files with 331 additions and 6758 deletions

View File

@@ -51,22 +51,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "129"
cuda_version: 12.8.1 cuda_version: 12.9.1
cudnn_version: "" cudnn_version: ""
python_version: "3.12" python_version: "3.12"
pytorch: 2.10.0 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -83,14 +75,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -173,22 +157,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "129"
cuda_version: 12.8.1 cuda_version: 12.9.1
cudnn_version: "" cudnn_version: ""
python_version: "3.12" python_version: "3.12"
pytorch: 2.10.0 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -205,14 +181,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -34,28 +34,16 @@ jobs:
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
is_latest: true is_latest: true
- cuda: 128 - cuda: 129
cuda_version: 12.8.1 cuda_version: 12.9.1
python_version: "3.12" python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.12" python_version: "3.11"
pytorch: 2.10.0 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
@@ -98,77 +86,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud: build-axolotl-cloud:
needs: build-axolotl needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }} if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -195,28 +112,16 @@ jobs:
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 129
cuda_version: 12.8.1 cuda_version: 12.9.1
python_version: "3.12" python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.12" python_version: "3.11"
pytorch: 2.10.0 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64" platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
@@ -254,73 +159,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux: build-axolotl-cloud-no-tmux:
needs: build-axolotl needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }} if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3 id: hf-cache-restore-s3
run: | run: |
mkdir -p /home/runner/.cache/huggingface/hub mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5

View File

@@ -75,7 +75,7 @@ jobs:
id: hf-cache-restore-s3 id: hf-cache-restore-s3
run: | run: |
mkdir -p ~/.cache/huggingface/hub mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1 curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/ ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python - name: Setup Python
@@ -170,7 +170,7 @@ jobs:
id: hf-cache-restore-s3 id: hf-cache-restore-s3
run: | run: |
mkdir -p ~/.cache/huggingface/hub mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1 curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/ ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python - name: Setup Python
@@ -264,8 +264,8 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 130 - cuda: 129
cuda_version: 13.0.0 cuda_version: 12.9.1
python_version: "3.12" python_version: "3.12"
pytorch: 2.9.1 pytorch: 2.9.1
num_gpus: 1 num_gpus: 1

View File

@@ -59,18 +59,34 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312) RUN case "$PYTORCH_VERSION" in \
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \ 2.9.[0-9]*) \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10) if [ "$CUDA" = "128" ]; then \
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \ if [ "$TARGETARCH" = "amd64" ]; then \
# Map architecture WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
case "$TARGETARCH" in \ WHL_VERSION="v0.5.4"; \
amd64) ARCH_TAG="x86_64" ;; \ elif [ "$TARGETARCH" = "arm64" ]; then \
arm64) ARCH_TAG="aarch64" ;; \ WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \ WHL_VERSION="v0.6.4"; \
esac && \ else \
WHL_VERSION="v0.7.16" && \ echo "Unsupported architecture: $TARGETARCH"; exit 1; \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \ fi; \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir "${WHL_FILE}" && \ pip3 install --no-cache-dir ${WHL_FILE}; \
rm "${WHL_FILE}" rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac
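
One side of this hunk derives the prebuilt flash-attn wheel filename generically from the build args (`PYTHON_CP`, `TORCH_TAG`, `ARCH_TAG`), while the other hard-codes per-CUDA, per-architecture wheel files. A minimal Python sketch of the generic naming scheme, purely illustrative (the helper name is not part of the Dockerfile):

```python
# Illustrative helper (not in the Dockerfile): the generic wheel-naming scheme,
# "flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl".
def flash_attn_wheel_name(
    python_version: str, pytorch_version: str, cuda: str, targetarch: str,
    flash_attn_version: str = "2.8.3",
) -> str:
    python_cp = "cp" + python_version.replace(".", "")              # 3.12 -> cp312
    torch_tag = "torch" + ".".join(pytorch_version.split(".")[:2])  # 2.9.1 -> torch2.9
    arch_tag = {"amd64": "x86_64", "arm64": "aarch64"}[targetarch]  # buildx TARGETARCH
    return (
        f"flash_attn-{flash_attn_version}+cu{cuda}{torch_tag}"
        f"-{python_cp}-{python_cp}-linux_{arch_tag}.whl"
    )


print(flash_attn_wheel_name("3.12", "2.9.1", "128", "amd64"))
# flash_attn-2.8.3+cu128torch2.9-cp312-cp312-linux_x86_64.whl
```

The named wheel is then fetched from the mjun0812/flash-attention-prebuild-wheels release tagged `WHL_VERSION` and installed with pip, as shown in the hunk above.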

View File

@@ -1,30 +0,0 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,47 +0,0 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -6,7 +6,6 @@ ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11" ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0" ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126" ARG CUDA="126"
@@ -40,18 +39,28 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \ uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi fi
# Map Python version (e.g., 3.12 -> cp312) RUN case "$PYTORCH_VERSION" in \
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \ 2.9.[0-9]*) \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10) if [ "$TARGETARCH" = "amd64" ]; then \
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \ if [ "$CUDA" = "128" ]; then \
# Map architecture wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
case "$TARGETARCH" in \ uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
amd64) ARCH_TAG="x86_64" ;; \ rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
arm64) ARCH_TAG="aarch64" ;; \ elif [ "$CUDA" = "130" ]; then \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
esac && \ uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
WHL_VERSION="v0.7.16" && \ rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \ fi \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \ elif [ "$TARGETARCH" = "arm64" ]; then \
uv pip install --no-cache-dir "${WHL_FILE}" && \ if [ "$CUDA" = "128" ]; then \
rm "${WHL_FILE}" wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options: Configuration options:
```yaml ```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate # List of tasks to evaluate
lm_eval_tasks: lm_eval_tasks:
- arc_challenge - arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results output_dir: # Directory to save evaluation results
``` ```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details. See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4 ### delinearize-llama4

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
] ]
}, },
{ {

View File

@@ -2,25 +2,25 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1 bitsandbytes==0.49.1
triton>=3.4.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
liger-kernel==0.7.0 liger-kernel==0.6.4
# END section # END section
packaging==26.0 packaging==26.0
huggingface_hub>=1.1.7 huggingface_hub>=1.1.7
peft>=0.18.1 peft>=0.18.1
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==5.2.0 transformers==5.0.0
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.5.0
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.28.0 trl==0.27.1
hf_xet==1.2.0 hf_xet==1.2.0
kernels==0.12.1 kernels==0.11.5
trackio>=0.16.1 trackio>=0.13.0
typing-extensions>=4.15.0 typing-extensions>=4.15.0
optimum==1.16.2 optimum==1.16.2
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.16.0 torchao==0.13.0
openenv-core==0.1.0 openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
) )

View File

@@ -26,11 +26,6 @@ def parse_requirements(extras_require_map):
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64" install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# skip packages not compatible with OSX # skip packages not compatible with OSX
skip_packages = [ skip_packages = [

View File

@@ -5,7 +5,7 @@ import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
@@ -32,63 +32,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__) LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"} API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance() TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -265,37 +208,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already # If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict: if key in cfg_keys or not cfg.strict:
cfg[key] = _coerce_value(value, cfg.get(key)) if isinstance(cfg[key], bool):
cfg[key] = bool(value)
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta) else:
for parent, children in nested_kwargs.items(): cfg[key] = value
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
try: try:
device_props = torch.cuda.get_device_properties("cuda") device_props = torch.cuda.get_device_properties("cuda")
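
One side of this config-loading hunk only casts boolean overrides, while the other adds `_coerce_value` type inference and splits double-underscore keys (e.g. `trl__beta`) into nested config updates. A self-contained sketch of that override handling, assuming a plain `dict` config (the real config object is richer):

```python
# Self-contained sketch of the dot-notation CLI override handling in this hunk.
from typing import Any


def coerce_value(value: Any, existing: Any = None) -> Any:
    """Cast a CLI string to the existing value's type, else YAML-style inference."""
    if not isinstance(value, str):
        return value
    if isinstance(existing, bool):
        return value.lower() in ("true", "1", "yes")
    if isinstance(existing, (int, float)):
        try:
            return type(existing)(value)
        except ValueError:
            return value
    lowered = value.lower()
    if lowered in ("true", "yes"):
        return True
    if lowered in ("false", "no"):
        return False
    if lowered in ("null", "none", "~"):
        return None
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def apply_overrides(cfg: dict, overrides: dict) -> None:
    """Apply flat and double-underscore keys (trl__beta -> cfg['trl']['beta'])."""
    for key, value in overrides.items():
        if "__" in key:
            parent, child = key.split("__", 1)
            section = cfg.get(parent)
            if not isinstance(section, dict):
                section = {}
                cfg[parent] = section
            section[child] = coerce_value(value, section.get(child))
        else:
            cfg[key] = coerce_value(value, cfg.get(key))


cfg = {"learning_rate": 0.0002, "trl": {"beta": 0.05}}
apply_overrides(cfg, {"learning_rate": "1e-4", "trl__beta": "0.1", "sample_packing": "true"})
print(cfg)  # {'learning_rate': 0.0001, 'trl': {'beta': 0.1}, 'sample_packing': True}
```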

View File

@@ -2,7 +2,7 @@
import dataclasses import dataclasses
from functools import wraps from functools import wraps
from types import NoneType, UnionType from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin from typing import Any, Callable, Type, Union, get_args, get_origin
import click import click
@@ -20,8 +20,7 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged. returns the input type unchanged.
""" """
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType) if get_origin(field_type) is Union and type(None) in get_args(field_type):
if is_union and type(None) in get_args(field_type):
field_type = next( field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType) t for t in get_args(field_type) if not isinstance(t, NoneType)
) )
@@ -88,70 +87,10 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable: def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
""" """
Create Click options from the fields of a Pydantic model. Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args: Args:
config_class: PyDantic model with fields to parse from the CLI config_class: PyDantic model with fields to parse from the CLI
@@ -164,11 +103,6 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation) field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool: if field_type is bool:
field_name = name.replace("_", "-") field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -258,6 +258,11 @@ class TrainerBuilderBase(abc.ABC):
bf16 = bf16 if bf16 is not None else False bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16 training_args_kwargs["bf16"] = bf16
if self.cfg.fp8:
training_args_kwargs["fp8"] = True
if self.cfg.fp8_enable_fsdp_float8_all_gather:
training_args_kwargs["enable_fsdp_float8_all_gather:"] = True
def _configure_scheduler(self, training_args_kwargs: dict): def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]: if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["lr_scheduler_type"] = "cosine"

View File

@@ -122,12 +122,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ColabCallback = colab_inference_post_train_callback(trainer) ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg)) callbacks.append(ColabCallback(self.cfg))
if getattr(self.cfg, "generate_samples", False):
from axolotl.utils.callbacks.generation import SFTGenerationCallback
callbacks.append(SFTGenerationCallback(trainer))
LOG.info("SFT sample generation enabled")
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer)) callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks return callbacks
@@ -252,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters ddp_find_unused_parameters
) )
if self.cfg.group_by_length: training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
) )
from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
@@ -52,8 +53,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
@@ -134,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length") # Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO: elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length") blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
@@ -154,8 +157,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
) )
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -584,11 +584,9 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
def additional_accelerator_args( def build_fp8_accelerator_args(self) -> dict[str, Any]:
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs args = {}
) -> dict[str, Any]: if self.args.fp8:
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig from torchao.float8 import Float8LinearConfig
@@ -596,15 +594,22 @@ class AxolotlTrainer(
# scaling strategy. See more details here: # scaling strategy. See more details here:
# https://github.com/pytorch/ao/tree/main/torchao/float8. # https://github.com/pytorch/ao/tree/main/torchao/float8.
config = Float8LinearConfig( config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, enable_fsdp_float8_all_gather=self.args.enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, force_recompute_fp8_weight_in_bwd=self.args.enable_fsdp_float8_all_gather
is True,
) )
ret_kwargs["mixed_precision"] = "fp8" args["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore args["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs return args
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
args = super().build_accelerator_args(**kwargs)
fp8_args = self.build_fp8_accelerator_args()
args.update(fp8_args)
return args
def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
""" """
@@ -719,8 +724,6 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}") LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save
if state_dict is None: if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model) state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None: if state_dict is not None:
@@ -728,7 +731,6 @@ class AxolotlTrainer(
k: v.clone() if isinstance(v, torch.Tensor) else v k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items() for k, v in state_dict.items()
} }
supported_classes = ( supported_classes = (
(PreTrainedModel,) (PreTrainedModel,)
if not is_peft_available() if not is_peft_available()
@@ -739,7 +741,6 @@ class AxolotlTrainer(
if not isinstance(self.model, supported_classes): if not isinstance(self.model, supported_classes):
if state_dict is None: if state_dict is None:
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
if isinstance( if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False), self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes, supported_classes,
@@ -749,7 +750,6 @@ class AxolotlTrainer(
).save_pretrained( ).save_pretrained(
output_dir, output_dir,
state_dict=state_dict, state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
) )
else: else:
LOG.info( LOG.info(
@@ -761,7 +761,11 @@ class AxolotlTrainer(
metadata={"format": "pt"}, metadata={"format": "pt"},
) )
else: else:
self.model.save_pretrained(output_dir, state_dict=state_dict) self.model.save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None: if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir) self.processing_class.save_pretrained(output_dir)
@@ -773,7 +777,11 @@ class AxolotlTrainer(
LOG.info( LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`" "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
) )
self.data_collator.tokenizer.save_pretrained(output_dir) save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model # Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row( def tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length: int | None = None, max_prompt_length,
max_completion_length: int | None = None, max_completion_length,
add_special_tokens: bool = True, add_special_tokens,
is_chat: bool = False,
) -> Dict: ) -> Dict:
res = DPOTrainer.tokenize_row( res = DPOTrainer.tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length=max_prompt_length, max_prompt_length,
max_completion_length=max_completion_length, max_completion_length,
add_special_tokens=add_special_tokens, add_special_tokens,
is_chat=is_chat,
) )
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
) -> LRScheduler: ) -> LRScheduler:
""" """
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,13 +45,6 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None and self.args.cosine_min_lr_ratio is not None
) )
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore if self.lr_scheduler is None: # type: ignore
# fmt: on # fmt: on

View File

@@ -263,3 +263,13 @@ class AxolotlTrainingMixins:
dion_rank_multiple_of: int | None = field( dion_rank_multiple_of: int | None = field(
default=None, default=None,
) )
fp8: bool | None = field(
default=None,
metadata={"help": "Whether to use FP8 precision for training"},
)
enable_fsdp_float8_all_gather: bool | None = field(
default=None,
metadata={"help": "Whether to use FSDP with FP8 precision for all_gather"},
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
``` ```
## Usage ## Usage
@@ -31,7 +31,6 @@ plugins:
## Supported Models ## Supported Models
- afmoe
- apertus - apertus
- arcee - arcee
- cohere - cohere
@@ -52,7 +51,6 @@ plugins:
- glm4v - glm4v
- glm4v_moe - glm4v_moe
- glm_image - glm_image
- glm_moe_dsa
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
@@ -78,19 +76,14 @@ plugins:
- olmo - olmo
- olmo2 - olmo2
- olmo3 - olmo3
- olmoe
- phi - phi
- phi3 - phi3
- phi4_multimodal - phi4_multimodal
- qwen2 - qwen2
- qwen2_5_vl
- qwen2_moe - qwen2_moe
- qwen2_vl - qwen2_vl
- qwen2_5_vl
- qwen3 - qwen3
- qwen3_5
- qwen3_5_moe
- qwen3_5_moe_vl
- qwen3_5_vl
- qwen3_moe - qwen3_moe
- qwen3_next - qwen3_next
- qwen3_vl - qwen3_vl

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
) )

View File

@@ -1,44 +0,0 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -33,16 +33,3 @@ class KernelsArgs(BaseModel):
data["experts_implementation"] = "eager" data["experts_implementation"] = "eager"
return data return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
return data

View File

@@ -1,18 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import lora_ops, ops
__all__ = ["ops", "lora_ops"]

View File

@@ -1,645 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import triton
import triton.language as tl
BLOCK_M = 128
ALLOW_TF32 = True
@triton.jit
def _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=True,
):
K_block = tl.arange(0, BLOCK_K)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = (
W_ptr
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
+ E_idx * stride_we
)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
return acc
def _scatter2scatter_configs():
return [
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
]
@triton.autotune(
configs=_scatter2scatter_configs(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _scatter2scatter(
X_ptr,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
W_ptr,
stride_we,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
Y_ptr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
B_ptr,
stride_be: tl.constexpr,
stride_bn: tl.constexpr,
grouped_idx_ptr,
expert_idxs_ptr,
# block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
# OUT_M,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr,
y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
no_k_mask = K % BLOCK_K == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT
acc = _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=allow_tf32,
)
if B_ptr is not None:
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def scatter2scatter(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
b=None,
x_grouped=False,
y_grouped=False,
out=None,
):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
output = out
scatter2scatter_compileable(
output,
W,
X,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
b,
x_grouped,
y_grouped,
)
return output
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
output: torch.Tensor,
W: torch.Tensor,
X: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
b: Optional[torch.Tensor],
x_grouped: bool,
y_grouped: bool,
) -> None:
def grid(META):
grid_num = (
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
* triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid_num
if b is None:
b = None
stride_be = stride_bn = 0
else:
stride_be, stride_bn = b.stride()
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X,
X.stride(0),
X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W,
W.stride(0),
W.stride(1),
W.stride(2),
# Y_ptr, stride_ym, stride_yn,
output,
output.stride(0),
output.stride(1),
# B_ptr, stride_be, stride_bn
b,
stride_be,
stride_bn,
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
# block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=output.size(1),
E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
x_grouped=x_grouped,
y_grouped=y_grouped,
)
def _config_XtY():
return [
triton.Config(
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
),
]
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
if has_bias:
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
else:
Db = None
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
return DW, Db
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
E: int,
DW: torch.Tensor,
Db: Optional[torch.Tensor],
DY: torch.Tensor,
X: torch.Tensor,
expert_offsets: torch.Tensor,
) -> None:
def grid(META):
grid = (
E * triton.cdiv(META["K"], META["BLOCK_K"]),
triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid
if Db is None:
stride_dbe = 0
stride_dbn = 0
else:
stride_dbe, stride_dbn = Db.stride()
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY,
DY.stride(0),
DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X,
X.stride(0),
X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW,
DW.stride(0),
DW.stride(1),
DW.stride(2),
# Db_ptr, stride_dwe, stride_dbn,
Db,
stride_dbe,
stride_dbn,
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0),
N=DY.size(-1),
K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
@triton.autotune(
configs=_config_XtY(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _groupXtY(
DY_ptr,
stride_dym,
stride_dyk,
X_ptr,
stride_xm,
stride_xn,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
expert_offsets_ptr,
M,
K: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = (
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
)
if (Db_ptr is not None) and (K_block_id == 0):
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=True,
)
else:
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=False,
)
@triton.jit
def _xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias: tl.constexpr,
):
if compute_bias:
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
else:
db_acc = None
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
if compute_bias:
db_acc += tl.sum(dy, axis=0)
DW_blk_ptrs = (
DW_ptr
+ E_idx * stride_dwe
+ K_block[:, None] * stride_dwk
+ N_block[None, :] * stride_dwn
)
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
if compute_bias:
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
def _config_grouping():
return [
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
return Y
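# Worked sketch (hedged; requires a CUDA device for the Triton kernel):
# group() gathers rows of A into expert-sorted order so tokens routed to the
# same expert become contiguous. With fan_out=1, row i of the result is
# A[idx[i]] for the index tensor passed as the second argument (callers pass
# sorted_scattered_idxs).
def _example_group_usage():
    A = torch.arange(6.0, device="cuda").view(3, 2)
    idxs = torch.tensor([2, 0, 1], device="cuda")
    grouped = group(A, idxs)  # rows reordered to [A[2], A[0], A[1]]
    return grouped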
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
A: torch.Tensor,
K: int,
N: int,
Y: torch.Tensor,
coeff: Optional[torch.Tensor],
has_coeff: bool,
fan_out: int,
sorted_expert_idxs: torch.Tensor,
) -> None:
def grid(META):
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A,
A.stride(0),
A.stride(1),
has_coeff,
coeff,
fan_out,
# Y_ptr, stride_yn, stride_yk,
Y,
Y.stride(0),
Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N,
K,
)
@triton.autotune(configs=_config_grouping(), key=["K"])
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
@triton.jit
def _group(
src_ptr,
stride_sn,
stride_sk,
has_coeff: tl.constexpr,
coeff_ptr,
FAN_OUT: tl.constexpr,
tgt_ptr,
stride_tn,
stride_ti,
grouped_idx_ptr,
N,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = (
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
)
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK or i < iters - 1:
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti

View File

@@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
import torch
import triton
import triton.language as tl
@triton.jit
def _single2scatter(
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
Y_ptr,
stride_ym,
stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
BLOCK_N,
)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = (
W_ptr
+ E_idx * stride_we
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
)
N_mask = N_block < N
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
K_mask = K_block < K
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
K_block += BLOCK_K
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = triton.cdiv(ydim, BLOCK_N), k
_single2scatter[grid](
X,
X.stride(0),
X.stride(1),
W,
W.stride(0),
W.stride(1),
W.stride(2),
Y,
Y.stride(0),
Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim,
N=ydim,
E=E,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32,
)
return Y
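# Hedged usage sketch (illustrative shapes, CUDA device assumed): one input
# row broadcast to its top-k experts, producing one output row per selected
# expert (FAN_OUT = k when X has a single row).
def _example_single2scatter():
    E, K, N, k = 8, 64, 128, 2
    X = torch.randn(1, K, device="cuda")
    W = torch.randn(E, K, N, device="cuda")
    expert_idxs = torch.tensor([[3, 5]], device="cuda")  # experts chosen for the token
    Y = single2scatter(X, W, expert_idxs)  # shape [k, N]
    return Y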

View File

@@ -1,439 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE layer replacements for HuggingFace MoE architectures.
Provides drop-in forward replacements that use ScatterMoE kernels for
acceleration. When used via the HF ``kernels`` library
(``replace_kernel_forward_from_hub``), these classes replace the forward
method of the original MoE block.
LoRA support
------------
When peft wraps parameters via ``target_parameters``, the ``self.experts``
submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``
router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP``
forward detects this and automatically:
1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta
2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module
3. Extracts LoRA A/B weights and scaling from each wrapper
4. Converts B layout from peft rank-major to scattermoe expert-major
5. Routes to ``parallel_linear_lora`` for fused LoRA computation
6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate``
(peft wraps their linear layers with standard LoRA, no special handling)
"""
import torch
from torch import nn
from torch.nn import functional as F
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
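# Tiny layout sketch (hedged, illustrative values): with E=2 experts and r=2,
# the rank-major peft columns (e0r0, e1r0, e0r1, e1r1) are reordered into
# expert-major scattermoe columns (e0r0, e0r1, e1r0, e1r1).
def _example_peft_B_reorder():
    cols = torch.tensor([[0.0, 1.0, 2.0, 3.0]])  # labelled e0r0, e1r0, e0r1, e1r1
    reordered = peft_lora_B_to_scattermoe(cols, num_experts=2, rank=2)
    return reordered  # tensor([[0., 2., 1., 3.]])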
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
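# Shape sanity-check sketch (hedged; E, r, dim1, dim2 are illustrative):
# after conversion, A indexes scattermoe K (= peft dim2) and B indexes
# scattermoe N (= peft dim1) in expert-major order.
def _example_peft_lora_shapes():
    E, r, dim1, dim2 = 4, 8, 16, 32
    peft_A = torch.randn(E * r, dim1)
    peft_B = torch.randn(dim2, E * r)
    smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
    assert smoe_A.shape == (E * r, dim2)
    assert smoe_B.shape == (dim1, E * r)
    return smoe_A, smoe_B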
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
# =============================================================================
# ParamWrapper unwrapping
# =============================================================================
def _unwrap_gate_lora(gate_module):
"""Unwrap peft ``ParamWrapper`` on the router gate.
When peft targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: OlmoeTopKRouter (the real module)
This function detects the wrapping and returns the base router, its
weight tensor, and an optional LoRA delta tensor.
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``,
``.num_experts``, ``.norm_topk_prob``).
``gate_weight`` is the base router weight (may be a DTensor under FSDP).
``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,
else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add.
"""
if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
# gate weight: [num_experts, hidden_size]
# lora_A: [r, hidden_size], lora_B: [num_experts, r]
# delta = scaling * B @ A = [num_experts, hidden_size]
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
else:
return base_gate, base_gate.weight, None
else:
# No wrapping — gate is the original module
return gate_module, gate_module.weight, None
def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):
"""Convert peft LoRA weights to scattermoe layout."""
smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
return (smoe_A, smoe_B, scaling)
def _unwrap_experts_lora(experts_module):
"""Walk a peft ``ParamWrapper`` chain on ``self.experts``.
When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via
``target_parameters``, ``self.experts`` becomes a nested chain::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
Returns:
(base_experts, gup_lora, down_lora)
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
if not wrappers:
return base_experts, None, None
# Determine num_experts from base module
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
return base_experts, gup_lora, down_lora
# =============================================================================
# Layer classes
# =============================================================================
class ScatterMoEGatedMLP(nn.Module):
def forward(self, layer_input):
"""
Forward pass of the mixture of experts layer.
Args:
layer_input (Tensor):
Input tensor.
Returns:
Tensor:
Output tensor.
"""
bsz, length, emb_size = layer_input.size()
layer_input = layer_input.reshape(-1, emb_size)
# compute the top_k routing decision
router_logits = self.router.layer(layer_input)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.router.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(layer_input.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=self.router.num_experts
)
# compute experts
gates, h = parallel_linear(
layer_input,
self.input_linear.weight.transpose(2, 1),
self.router.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
layer_output = parallel_linear(
h,
self.output_linear.weight.transpose(2, 1),
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
layer_output = layer_output.view(bsz, length, emb_size)
return layer_output
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
Supports both full-parameter training and LoRA fine-tuning:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
def forward(self: nn.Module, layer_input: torch.Tensor):
"""
Forward pass using ScatterMoE kernels.
Args:
            self: The sparse MoE block module (e.g. ``OlmoeSparseMoeBlock``) containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
Returns:
Tensor: [batch_size, seq_len, hidden_size]
"""
batch_size, sequence_length, hidden_dim = layer_input.shape
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=num_experts
)
# ====================================================================
# Detect LoRA (peft ParamWrapper) and extract adapter weights
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
grouped_in=False,
grouped_out=True,
use_fused_dX=True,
use_fused_gather=True,
)
else:
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
expert_output = parallel_linear_lora(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
gates=routing_weights,
grouped_in=True,
grouped_out=False,
use_fused_dX=True,
use_fused_gather=True,
)
else:
expert_output = parallel_linear(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================
if shared_expert_output is not None:
expert_output = expert_output + shared_expert_output
expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
return expert_output

View File

@@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ParallelExperts module with LoRA support.
Provides a drop-in replacement for ScatterMoE's ParallelExperts that
uses the fused LoRA kernel when adapter weights are attached.
"""
from typing import Optional
import torch
import torch.nn as nn
from .parallel_linear_lora import parallel_linear_lora
class ParallelExperts(nn.Module):
"""
Parallel Experts with fused LoRA support.
Drop-in replacement for the original ParallelExperts. When LoRA parameters
are attached via set_lora(), the forward pass uses a fused kernel:
Y = X @ W + scaling * (X @ A^T) @ B^T
"""
def __init__(
self,
num_experts: int,
input_size: int,
output_size: int,
bias: bool = False,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self._lora_A: torch.Tensor | None = None
self._lora_B: torch.Tensor | None = None
self._lora_scaling: float | None = None
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def extra_repr(self) -> str:
return (
f"num_experts={self.num_experts}, "
f"input_size={self.input_size}, "
f"output_size={self.output_size}"
)
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
"""Attach LoRA parameters for fused computation."""
self._lora_A = lora_A
self._lora_B = lora_B
self._lora_scaling = scaling
def clear_lora(self):
"""Remove LoRA parameters."""
self._lora_A = None
self._lora_B = None
self._lora_scaling = None
def forward(
self,
inputs: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
) -> torch.Tensor:
return parallel_linear_lora(
inputs,
self.weight.permute(0, 2, 1), # [E, input, output]
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=self._lora_A,
lora_B=self._lora_B,
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
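# Hedged usage sketch (illustrative sizes; the routing tensors are assumed to
# come from flatten_sort_count in the base scattermoe ops, and the LoRA
# tensors are already in scattermoe layout):
def _example_parallel_experts_lora(x, k, s_expert_idxs, s_scattered_idxs, offsets):
    experts = ParallelExperts(4, x.size(-1), 128).to(x.device, x.dtype)
    rank = 8
    lora_A = torch.randn(4 * rank, x.size(-1), device=x.device, dtype=x.dtype)
    lora_B = torch.randn(128, 4 * rank, device=x.device, dtype=x.dtype)  # expert-major
    experts.set_lora(lora_A, lora_B, scaling=0.5)
    return experts(x, k, s_expert_idxs, s_scattered_idxs, offsets)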

View File

@@ -1,253 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import torch.nn as nn
from . import kernels
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength)
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, dtype=torch.long, device=x.device)
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
with torch.no_grad():
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
expert_counts = compileable_bincount(
flattened_expert_idxs, minlength=num_experts
)
expert_offsets = expert_counts.cumsum(-1)
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
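# Worked sketch (hedged; tie order within an expert is not guaranteed by
# torch.sort, so the scattered indices below are one valid outcome):
#   expert_idxs = [[0, 2], [2, 1], [0, 3]] with num_experts=4
#   flattened               -> [0, 2, 2, 1, 0, 3]
#   sorted_expert_idxs      -> [0, 0, 1, 2, 2, 3]
#   sorted_scattered_idxs   -> [0, 4, 3, 1, 2, 5]   (positions in the flat order)
#   expert_offsets          -> [2, 3, 5, 6]          (cumulative token counts)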
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
):
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,
W=expert_weights,
b=expert_biases,
k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
if gates is not None:
# calculate gates gradient
# d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
grouped_grad_out = output_expanded.flatten(
0, 1
) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = kernels.ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights, d_biases = kernels.ops.group_bwd_W(
DY=grouped_grad_out,
X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0),
has_bias=expert_biases is not None,
)
d_expanded_input = kernels.ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1),
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input, # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
return (
# x, expert_weights,
d_input,
d_weights,
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
None,
None,
None,
None,
# bias, gates
d_biases,
d_gates,
# grouped_in, grouped_out,
None,
None,
)
def parallel_linear(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=None,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self.reset_parameters()
def extra_repr(self):
return "num_experts={}, input_size={}, output_size={}".format(
self.num_experts, self.input_size, self.output_size
)
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(
self,
inputs,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = parallel_linear(
inputs,
self.weight.permute(0, 2, 1),
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
return results
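# Hedged end-to-end sketch (illustrative sizes; CUDA device assumed):
# route 8 tokens to their top-2 of 4 experts, then mix the expert outputs
# with the routing weights via the gates argument.
def _example_parallel_experts():
    E, K, N, top_k = 4, 64, 128, 2
    experts = ParallelExperts(E, K, N).cuda()
    x = torch.randn(8, K, device="cuda")
    routing = torch.randn(8, E, device="cuda").softmax(dim=-1)
    weights, selected = torch.topk(routing, top_k, dim=-1)
    s_e, s_s, offsets = flatten_sort_count(selected, num_experts=E)
    return experts(x, top_k, s_e, s_s, offsets, gates=weights)  # [8, N]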

View File

@@ -1,480 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE + LoRA Autograd Function
====================================
Provides the autograd function and Python interface for fused ScatterMoE + LoRA.
Key design for LoRA training:
- Expert weights W are FROZEN (no gradient computed for W).
- Only LoRA adapter weights (A, B) receive gradients.
- The input gradient dX is still computed (needed for upstream layers).
- This avoids the expensive group_bwd_W computation entirely.
Forward:
Y = X @ W + scaling * (X @ A^T) @ B^T
Backward (W frozen):
dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA)
dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data)
dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data)
"""
from typing import Optional
import torch
from .kernels import ops as base_ops
from .kernels.lora_ops import (
group_bwd_lora,
group_bwd_lora_fused,
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
class ScatterMoELoRA(torch.autograd.Function):
"""
Autograd function for fused ScatterMoE + LoRA with frozen expert weights.
This function is optimized for the LoRA fine-tuning scenario where:
- Expert weights W are frozen (requires_grad=False)
- Only LoRA A and B matrices receive gradients
- Input gradients are computed for upstream layer backprop
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(
X=x,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
b=expert_biases,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
# Handle gating (weighted combination of top-k expert outputs)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
# Store frozen weights as plain Python attributes instead of
# save_for_backward. This avoids:
# 1. Version-check conflicts with FSDP unshard/reshard
# 2. Pinning all-gathered parameters via saved_tensors hooks
# 3. Interfering with activation offloading pack/unpack hooks
# Safe because expert_weights are frozen (requires_grad=False).
ctx.expert_weights = expert_weights
ctx.expert_biases = expert_biases
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
ctx.scaling = scaling
ctx.use_fused_dX = use_fused_dX
ctx.use_fused_gather = use_fused_gather
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
expert_weights = ctx.expert_weights
k = ctx.k
scaling = ctx.scaling
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
E = expert_weights.size(0)
# ------------------------------------------------------------------
# Gate gradients (if using top-k gating with routing weights)
# ------------------------------------------------------------------
if gates is not None:
# d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :]
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# Reuse output_expanded buffer for grouped_grad_out
grouped_grad_out = output_expanded.flatten(0, 1)
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
# ------------------------------------------------------------------
# LoRA gradients (dA, dB) and setup for dX
# ------------------------------------------------------------------
# Fused gather uses sorted_scattered_idxs for indirect X access
# in the Triton kernel, avoiding the group(x) allocation.
#
# can_fuse_gather: X is ungrouped and not too large for scatter loads
# - When gates is None and grouped_out=False: both DY and X ungrouped
# - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped
# -> use dy_grouped=True in the fused kernel
M_total = sorted_scattered_idxs.size(0)
K_dim = x.size(-1)
N_dim = expert_weights.size(-1)
fuse_gather_workload = M_total * max(K_dim, N_dim)
_FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements
can_fuse_gather = (
ctx.use_fused_gather
and not grouped_in # X must be ungrouped for scatter access
and gates is None # gate coeff requires multiplicative gather
and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
)
if can_fuse_gather:
# ------------------------------------------------------------------
# Fused path: skip group(x) entirely
# ------------------------------------------------------------------
d_expanded_input = None
d_lora_A, d_lora_B = group_bwd_lora_fused(
DY=grad_out,
X=x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
sorted_scattered_idxs=sorted_scattered_idxs,
E=E,
k=k,
scaling=scaling,
dy_grouped=grouped_out,
)
# Prepare grouped_grad_out for the dX path (needed by both
# the fused dX kernel when grouped_out=True, and the non-fused path)
if grouped_out:
grouped_grad_out = grad_out
elif not ctx.use_fused_dX:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
else:
# ------------------------------------------------------------------
# Original path: explicit group() calls
# ------------------------------------------------------------------
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x # Will be overwritten; reuse buffer
d_lora_A, d_lora_B = group_bwd_lora(
DY=grouped_grad_out,
X=grouped_x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
E=E,
scaling=scaling,
)
# ------------------------------------------------------------------
# Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
# ------------------------------------------------------------------
if ctx.use_fused_dX:
if can_fuse_gather and not grouped_out:
# Fully fused: read ungrouped DY via scatter pattern
d_expanded_input = scatter2scatter_lora_dX(
DY=grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=False,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Fused dX only: read from pre-grouped DY
d_expanded_input = scatter2scatter_lora_dX(
DY=grouped_grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=True,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Original path: separate base scatter2scatter + LoRA Python loop
d_expanded_input = base_ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1), # [E, N, K]
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input,
)
# LoRA part: dX_lora = scaling * (dY @ B) @ A
if scaling != 0.0:
d_input_lora_grouped = _compute_lora_input_grad(
grouped_grad_out,
lora_A,
lora_B,
expert_offsets,
E,
scaling,
)
if grouped_in:
d_expanded_input.add_(d_input_lora_grouped)
else:
# Scatter-add LoRA gradient directly into d_expanded_input.
# Avoids allocating a zeros_like + add result
d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped
# Reduce over top-k if k > 1
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
# W is frozen during LoRA training -- skip weight gradient
d_weights = (
torch.zeros_like(expert_weights)
if expert_weights.requires_grad
else None
)
d_biases = None
return (
d_input,
d_weights,
None,
None,
None,
None, # k, sorted indices, offsets
d_lora_A,
d_lora_B,
None, # lora_A, lora_B, scaling
d_biases,
d_gates,
None,
None, # grouped_in, grouped_out
None, # use_fused_dX
None, # use_fused_gather
)
def _compute_lora_input_grad(
grouped_grad_out: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
expert_offsets: torch.Tensor,
E: int,
scaling: float,
) -> torch.Tensor:
"""
Compute the LoRA contribution to the input gradient:
dX_lora = scaling * (dY @ B) @ A
Uses PyTorch ops on expert-grouped data.
Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e
"""
R = lora_A.size(0) // E
K = lora_A.size(1)
M_total = grouped_grad_out.size(0)
d_input_lora = torch.zeros(
(M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype
)
compute_dtype = grouped_grad_out.dtype
prev_offset = 0
for e in range(E):
curr_offset = expert_offsets[e].item()
if curr_offset > prev_offset:
dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N]
a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K]
b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r]
# dX_e = scaling * (dY_e @ B_e) @ A_e
dy_b = dy_e @ b_e # [M_e, r]
dx_e = scaling * (dy_b @ a_e) # [M_e, K]
d_input_lora[prev_offset:curr_offset] = dx_e
prev_offset = curr_offset
return d_input_lora
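# Equivalence note (hedged): for expert e occupying rows sl = slice(prev, curr)
# of the grouped tensors, the loop above computes
#   d_input_lora[sl] = scaling * grouped_grad_out[sl] @ lora_B[:, e*R:(e+1)*R] @ lora_A[e*R:(e+1)*R]
# one expert at a time; a single batched matmul over experts would require
# every expert to receive the same number of tokens.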
# =============================================================================
# Helper: Extract LoRA params from PEFT ParamWrapper
# =============================================================================
def get_lora_params_from_wrapper(module) -> tuple:
"""
Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if adapter_name not in lora_A_dict:
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: Optional[torch.Tensor] = None,
lora_B: Optional[torch.Tensor] = None,
scaling: float = 1.0,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
"""
Drop-in replacement for parallel_linear that supports LoRA.
    If lora_A and lora_B are provided, the fused LoRA kernel is used;
    otherwise it falls back to the standard scatter2scatter path.
"""
if lora_A is not None and lora_B is not None:
return ScatterMoELoRA.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A,
lora_B,
scaling,
expert_biases,
gates,
grouped_in,
grouped_out,
use_fused_dX,
use_fused_gather,
)
else:
from .parallel_experts import ParallelLinear
return ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
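# Hedged usage sketch (illustrative sizes; CUDA device assumed, routing
# tensors produced by flatten_sort_count from .parallel_experts):
def _example_parallel_linear_lora(x, top_k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets):
    E, K, N, rank = 4, x.size(-1), 128, 8
    W = torch.randn(E, K, N, device=x.device, dtype=x.dtype)  # frozen expert weights
    lora_A = torch.randn(E * rank, K, device=x.device, dtype=x.dtype, requires_grad=True)
    lora_B = torch.zeros(N, E * rank, device=x.device, dtype=x.dtype, requires_grad=True)
    return parallel_linear_lora(
        x, W, top_k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
        lora_A=lora_A, lora_B=lora_B, scaling=0.5,
    )  # [num_tokens * top_k, N]; pass gates=routing_weights to reduce over top_k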

View File

@@ -1,7 +1,5 @@
from pathlib import Path
from kernels import ( from kernels import (
LocalLayerRepository, LayerRepository,
Mode, Mode,
register_kernel_mapping, register_kernel_mapping,
replace_kernel_forward_from_hub, replace_kernel_forward_from_hub,
@@ -21,19 +19,16 @@ class KernelsPlugin(BasePlugin):
self._kernelize_model(cfg.model_config_type) self._kernelize_model(cfg.model_config_type)
def _register_kernels(self): def _register_kernels(self):
plugin_root = Path(__file__).parent
register_kernel_mapping( register_kernel_mapping(
{ {
"HFScatterMoEParallelExperts": { "HFScatterMoEParallelExperts": {
"cuda": { "cuda": {
Mode.TRAINING: LocalLayerRepository( Mode.TRAINING: LayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora", repo_id="axolotl-ai-co/scattermoe",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP", layer_name="HFScatterMoEGatedMLP",
), ),
Mode.INFERENCE: LocalLayerRepository( Mode.INFERENCE: LayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora", repo_id="axolotl-ai-co/scattermoe",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP", layer_name="HFScatterMoEGatedMLP",
), ),
}, },

View File

@@ -6,12 +6,6 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage ## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml ```yaml
plugins: plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin - axolotl.integrations.lm_eval.LMEvalPlugin
@@ -22,50 +16,9 @@ lm_eval_tasks:
- arc_easy - arc_easy
lm_eval_batch_size: # Batch size for evaluation lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
``` ```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation ## Citation
```bib ```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs as LMEvalArgs from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=get_model_path(cfg), model=cfg.lm_eval_model or cfg.hub_model_id,
): ):
subprocess.run( # nosec subprocess.run( # nosec
lm_eval_args, lm_eval_args,

View File

@@ -13,21 +13,6 @@ import yaml
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command( def build_lm_eval_command(
tasks: list[str], tasks: list[str],
bfloat16=True, bfloat16=True,
@@ -123,7 +108,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=get_model_path(cfg), model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision, revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template, apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn, fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -10,7 +10,6 @@ from functools import cached_property
import addict import addict
import transformers import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -101,7 +100,6 @@ class PatchManager:
self._apply_fsdp_patches() self._apply_fsdp_patches()
self._apply_adapter_patches() self._apply_adapter_patches()
self._apply_model_specific_patches() self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches() self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches() self._apply_gradient_checkpointing_patches()
self._patch_attention() self._patch_attention()
@@ -236,17 +234,6 @@ class PatchManager:
patch_kimi_model() patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
def _apply_flash_attention_peft_patches(self): def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT.""" """Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter: if self.cfg.adapter:
@@ -329,7 +316,7 @@ class PatchManager:
else: else:
has_remote_code = False has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is not None: if has_remote_code and self.cfg.trust_remote_code is False:
# If explicitly set in YAML, prefer that # If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code
@@ -501,7 +488,6 @@ class PatchManager:
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
and self.cfg.flash_attention and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference and not self.inference
): ):
# TODO(MengqingCao): split these patches separately # TODO(MengqingCao): split these patches separately

View File

@@ -19,11 +19,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type: if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type) processor_cls = getattr(transformers, cfg.processor_type)
# Build common kwargs for processor loading
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.tokenizer_use_mistral_common: if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer(): def _patch_mistralcommontokenizer():
@@ -45,7 +40,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if processor_cls == VoxtralProcessor: if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained( return VoxtralProcessor.from_pretrained(
cfg.processor_config, cfg.processor_config,
**processor_kwargs,
) )
from axolotl.utils.mistral import Mistral3Processor from axolotl.utils.mistral import Mistral3Processor
@@ -54,12 +48,10 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
tokenizer=tokenizer, tokenizer=tokenizer,
) )
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained( processor = processor_cls.from_pretrained(
cfg.processor_config, cfg.processor_config,
**processor_kwargs, trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
) )
# Attempt to load image size from processor if available # Attempt to load image size from processor if available

View File

@@ -28,10 +28,7 @@ PLUGIN_MANAGER = PluginManager.get_instance()
def modify_tokenizer_files( def modify_tokenizer_files(
tokenizer_path: str, tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
token_mappings: dict[int, str],
output_dir: str,
revision: str = "main",
) -> str: ) -> str:
""" """
Modify tokenizer files to replace added_tokens strings, save to output directory, Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -44,7 +41,6 @@ def modify_tokenizer_files(
tokenizer_path: Path or name of the original tokenizer tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string} token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer output_dir: Directory to save the modified tokenizer
revision: Model revision/branch/tag/commit to load from (HF Hub)
Returns: Returns:
Path to the modified tokenizer directory Path to the modified tokenizer directory
@@ -57,9 +53,7 @@ def modify_tokenizer_files(
if is_local_main_process(): if is_local_main_process():
# Load the tokenizer # Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained( temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
tokenizer_path, use_fast=True, revision=revision
)
# Save the tokenizer to the output directory # Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir) temp_tokenizer.save_pretrained(tokenizer_dir)
@@ -140,10 +134,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
from axolotl.utils.mistral import HFMistralTokenizer from axolotl.utils.mistral import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer # Load the HF-compatible wrapper around MistralTokenizer
kwargs = {} tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
if cfg.revision_of_model:
kwargs["revision"] = cfg.revision_of_model
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
return tokenizer return tokenizer
@@ -159,8 +150,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if cfg.tokenizer_legacy is not None: if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224 # True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if cfg.revision_of_model:
tokenizer_kwargs["revision"] = cfg.revision_of_model
tokenizer_cls = AutoTokenizer tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type: if cfg.tokenizer_type:
@@ -172,11 +161,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Apply token string overrides if specified # Apply token string overrides if specified
if cfg.added_tokens_overrides: if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer # Modify tokenizer files and get path to modified tokenizer
modify_kwargs = {"output_dir": cfg.output_dir}
if cfg.revision_of_model:
modify_kwargs["revision"] = cfg.revision_of_model
tokenizer_path = modify_tokenizer_files( tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
) )
tokenizer = tokenizer_cls.from_pretrained( tokenizer = tokenizer_cls.from_pretrained(

View File

@@ -59,12 +59,7 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach() hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True hidden_states.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args) (output,) = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY) torch.autograd.backward(output, dY)
return ( return (
None, None,

View File

@@ -1,83 +0,0 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
"""
Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = create_code
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
patched_trainer_code = PATCHED_TRAINER_CODE.format(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
)
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals())
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = (
fixed_create_accelerator_and_postprocess
)

View File

@@ -28,12 +28,8 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
} }
ORIGINAL_MAYBE_CODE = ( ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()" PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool: def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
): ):
# check if message_property_mappings is None or empty dict # check if message_property_mappings is None or empty dict
if message_property_mappings is None or (not message_property_mappings): if message_property_mappings is None or (not message_property_mappings):
default_message_property_mappings_keys = ["role", "content", "tool"]
message_property_mappings = { message_property_mappings = {
prop: prop for prop in default_message_property_mappings_keys "role": "role",
"content": "content",
} }
if template_thinking_key and field_thinking: if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking message_property_mappings[template_thinking_key] = field_thinking

View File

@@ -156,10 +156,6 @@ class TelemetryManager:
Returns: Returns:
Boolean denoting whether telemetry is enabled or not. Boolean denoting whether telemetry is enabled or not.
""" """
# Only rank 0 will send telemetry
if not is_main_process():
return False
# Parse relevant env vars # Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK") axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK") do_not_track = os.getenv("DO_NOT_TRACK")
@@ -173,6 +169,10 @@ class TelemetryManager:
): ):
return True return True
# Only rank 0 will send telemetry
if not is_main_process():
return False
if do_not_track is None: if do_not_track is None:
do_not_track = "0" do_not_track = "0"

View File

@@ -1,84 +0,0 @@
"""Callback for generating samples during SFT/Pretrain training."""
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.generation.sft import generate_samples
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SFTGenerationCallback(TrainerCallback):
"""Callback for generating samples during SFT/Pretrain training."""
def __init__(self, trainer):
self.trainer = trainer
def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
cfg = self.trainer.axolotl_cfg
if not getattr(cfg, "generate_samples", False):
return
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
LOG.info(
f"Using eval dataloader for generation at step {state.global_step}"
)
except Exception as e:
LOG.warning(f"Could not get eval dataloader: {e}")
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
LOG.info(
f"Using train dataloader for generation at step {state.global_step}"
)
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
temperature=getattr(cfg, "generation_temperature", 0.7),
top_p=getattr(cfg, "generation_top_p", None),
top_k=getattr(cfg, "generation_top_k", None),
do_sample=getattr(cfg, "generation_do_sample", True),
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
)
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples to console and W&B."""
from axolotl.utils.generation.sft import format_generation_for_logging
for i, sample in enumerate(samples):
console_text, wandb_text = format_generation_for_logging(sample, i, step)
LOG.info(console_text)
try:
import wandb
if wandb.run is not None:
wandb.log(
{
f"samples/sample_{i + 1}": wandb.Html(
f"<pre>{wandb_text}</pre>"
)
},
step=step,
)
        except Exception:  # wandb unavailable or logging failed; ignore
pass
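
A hedged usage sketch for the deleted callback above (assumes an existing Hugging Face Trainer instance; how the repo wired it up is not shown in this diff):

callback = SFTGenerationCallback(trainer)
trainer.add_callback(callback)  # samples are then generated on each evaluation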

View File

@@ -54,19 +54,15 @@ class FileLockLoader:
def cleanup(self): def cleanup(self):
"""Clean up ready flag when last process is done.""" """Clean up ready flag when last process is done."""
try: with FileLock(str(self.lock_file_path)):
with FileLock(str(self.lock_file_path)): counter_content = self.counter_path.read_text().strip()
counter_content = self.counter_path.read_text().strip() count = int(counter_content) if counter_content else 0
count = int(counter_content) if counter_content else 0 count -= 1
count -= 1
if count <= 0: if count <= 0:
# Last process cleans everything up # Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True) self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True) self.counter_path.unlink(missing_ok=True)
else: else:
# Still have active processes # Still have active processes
self.counter_path.write_text(str(count)) self.counter_path.write_text(str(count))
except FileNotFoundError:
# Lock file might have already been deleted by another process
pass
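
The hunk above drops the FileNotFoundError guard around the lock-and-decrement pattern. A minimal sketch of that pattern with hypothetical paths (not the loader's exact code):

from pathlib import Path
from filelock import FileLock

def decrement_refcount(lock_path: Path, counter_path: Path) -> int:
    # The last process to decrement the shared counter removes it; earlier
    # processes just write the reduced count back.
    with FileLock(str(lock_path)):
        count = int(counter_path.read_text().strip() or 0) - 1
        if count <= 0:
            counter_path.unlink(missing_ok=True)
        else:
            counter_path.write_text(str(count))
        return count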

View File

@@ -246,10 +246,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
dataset = merge_datasets(split_datasets, cfg) dataset = merge_datasets(split_datasets, cfg)
if not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save preprocessed dataset # Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config( dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path cfg, datasets_configs, tokenizer.name_or_path

View File

@@ -351,10 +351,6 @@ def _load_raw_datasets(
if cfg.sample_packing: if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save the prepared dataset # Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config( dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path cfg, datasets_configs, tokenizer.name_or_path
@@ -442,8 +438,25 @@ def _handle_train_dataset_split(
) )
return train_dataset, eval_dataset return train_dataset, eval_dataset
# No validation split - deduplication already applied during preprocessing # No validation split - apply deduplication if needed and return as train dataset
return dataset, None if cfg.dataset_exact_deduplication:
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication:
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
eval_dataset = dataset
return None, eval_dataset
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
@@ -502,7 +515,6 @@ def _load_and_prepare_datasets(
if split == "train": if split == "train":
train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg) train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
else: else:
# Deduplication already applied during preprocessing train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
train_dataset, eval_dataset = None, dataset
return train_dataset, eval_dataset, prompters return train_dataset, eval_dataset, prompters
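
For reference, "exact deduplication" in these hunks means dropping byte-identical rows. A hedged, hypothetical sketch of that idea (exact_dedup is illustrative; the repo calls deduplicate_and_log_datasets):

import hashlib
from datasets import Dataset

def exact_dedup(dataset: Dataset) -> Dataset:
    # Keep the first occurrence of each identical row.
    seen, keep = set(), []
    for row in dataset:
        key = hashlib.sha256(repr(sorted(row.items())).encode()).hexdigest()
        keep.append(key not in seen)
        seen.add(key)
    return dataset.select([i for i, k in enumerate(keep) if k])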

View File

@@ -520,8 +520,7 @@ def generate_dataset_hash_from_config(
""" """
config_str = ( config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@" f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@" f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
f"{cfg.dataset_exact_deduplication or False}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}" f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}" f"|{tokenizer_name}"
) )
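
A short sketch of how a config string like the one above typically becomes a cache key (illustrative; the hash function used by generate_dataset_hash_from_config is not shown in this hunk):

import hashlib

def config_fingerprint(config_str: str) -> str:
    # Any stable digest works; the point is that changing any field that
    # affects preprocessing changes the prepared-dataset cache key.
    return hashlib.md5(config_str.encode("utf-8")).hexdigest()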

View File

@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import filter_sequences_by_length from axolotl.utils.trainer import drop_long_seq
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -148,33 +148,22 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset return dataset, other_dataset
def keep_min_len(sample, min_sequence_len=2): def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
""" """
Batched filter function that keeps only samples with sequence length >= min_sequence_len. Truncate samples whose sequence length is too long (> sequence_len)
Returns a list of booleans indicating which samples to keep. or drop those too short (< min_sequence_len).
""" """
min_sequence_len = min_sequence_len or 2 min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"] input_ids = sample["input_ids"]
# Batched (input_ids is a list of lists)
results = [] results = []
for seq in input_ids:
results.append(len(seq) >= min_sequence_len)
return results
def truncate_long_seq(sample, sequence_len=2048):
"""
Truncate samples whose sequence length is too long (> sequence_len).
Modifies the sample in-place and returns the modified sample.
"""
input_ids = sample["input_ids"]
# Batched (input_ids is a list of lists) # Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids): for i, seq in enumerate(input_ids):
length = len(seq) length = len(seq)
if length > sequence_len: if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len] sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample: if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len] sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
@@ -182,133 +171,10 @@ def truncate_long_seq(sample, sequence_len=2048):
sample["labels"][i] = sample["labels"][i][:sequence_len] sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample: if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
return sample results.append(True)
else:
results.append(True)
def _should_skip_processing(dataset: Dataset) -> bool: return results
"""Check if dataset should skip long sequence handling."""
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return True
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
return True
return False
def _log_dataset_stats(dataset: Dataset) -> None:
"""Log min/max sequence lengths for debugging."""
with contextlib.suppress(AttributeError, ValueError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
LOG.info(f"min_input_len: {np.min(ds_lengths)}")
LOG.info(f"max_input_len: {np.max(ds_lengths)}")
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
"""Build kwargs for dataset filter/map operations."""
kwargs = {}
if not isinstance(dataset, IterableDataset):
kwargs["num_proc"] = cfg.dataset_num_proc
kwargs["load_from_cache_file"] = not cfg.is_preprocess
return kwargs
def _filter_short_sequences(
dataset: Dataset, min_len: int, filter_kwargs: dict
) -> tuple[Dataset, int]:
"""Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
desc_kwargs = {}
if filter_kwargs:
desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
dataset = dataset.filter(
functools.partial(keep_min_len, min_sequence_len=min_len),
batched=True,
**filter_kwargs,
**desc_kwargs,
)
dropped = 0
if prior_len:
dropped = prior_len - len(dataset)
if dropped > 0:
LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
return dataset, dropped
def _truncate_long_sequences(
dataset: Dataset, max_len: int, map_kwargs: dict
) -> Dataset:
"""Truncate sequences longer than max_len."""
desc_kwargs = {}
if map_kwargs:
desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
dataset = dataset.map(
functools.partial(truncate_long_seq, sequence_len=max_len),
batched=True,
**map_kwargs,
**desc_kwargs,
)
LOG.info(f"Truncated long sequences to max length {max_len}")
return dataset
def _drop_outside_range(
dataset: Dataset,
max_len: int,
min_len: int,
raise_on_long: bool,
filter_kwargs: dict,
) -> tuple[Dataset, int]:
"""Drop sequences outside valid length range [min_len, max_len].
Returns (dataset, num_dropped)."""
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
desc_kwargs = {}
if filter_kwargs:
action = (
"Checking Sequence Lengths"
if raise_on_long
else "Dropping Invalid Sequences"
)
desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
dataset = dataset.filter(
functools.partial(
filter_sequences_by_length,
sequence_len=max_len,
min_sequence_len=min_len,
raise_on_drop=raise_on_long,
),
batched=True,
**filter_kwargs,
**desc_kwargs,
)
dropped = 0
if not raise_on_long and prior_len:
dropped = prior_len - len(dataset)
if dropped > 0:
LOG.info(
f"Dropped {dropped} sequences outside valid range "
f"([{min_len}, {max_len}])"
)
return dataset, dropped
def handle_long_seq_in_dataset( def handle_long_seq_in_dataset(
@@ -327,25 +193,80 @@ def handle_long_seq_in_dataset(
'truncate' truncates them down to sequence_len 'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len 'raise' raises a ValueError if any sequence was found that was longer than sequence_len
""" """
# Early returns for special cases if (
if _should_skip_processing(dataset): hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
return dataset return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
_log_dataset_stats(dataset) drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
# Setup kwargs with contextlib.suppress(AttributeError):
filter_kwargs = _build_filter_kwargs(dataset, cfg) ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
# Handle sequences based on strategy prior_len = len(dataset) if hasattr(dataset, "__len__") else None
if excess_length_strategy == "truncate":
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs) filter_map_kwargs = {}
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs) if not isinstance(dataset, IterableDataset):
else: filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
raise_on_long = excess_length_strategy == "raise" filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
dataset, _ = _drop_outside_range(
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs drop_long_kwargs = {}
if filter_map_kwargs:
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
) )
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
else:
process_fn = drop_long
dataset = dataset.filter(
process_fn,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset return dataset
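
A self-contained sketch of the batched datasets.filter pattern both versions of this module rely on (hypothetical helper name, toy data):

from datasets import Dataset

def keep_within_range(batch, max_len=2048, min_len=2):
    # Batched filter: return one boolean per sample in the batch.
    return [min_len <= len(ids) <= max_len for ids in batch["input_ids"]]

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [1], list(range(4096))]})
ds = ds.filter(keep_within_range, batched=True)  # keeps only the first row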

View File

@@ -1,5 +0,0 @@
"""Generation utilities for monitoring during training."""
from .sft import format_generation_for_logging, generate_samples
__all__ = ["generate_samples", "format_generation_for_logging"]

View File

@@ -1,174 +0,0 @@
"""Sample generation utilities for SFT/Pretrain training."""
from typing import Any, List, Optional
import torch
from accelerate.utils import extract_model_from_parallel
from colorama import Fore, Style
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Any,
num_generation_samples: int = 3,
max_new_tokens: int = 50,
temperature: float = 0.7,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
prompt_ratio: float = 0.5,
) -> List[dict]:
"""
Generate samples from the model during training for monitoring.
Args:
model: The model to generate from
tokenizer: The tokenizer to use for encoding/decoding
dataloader: Dataloader to sample prompts from
num_generation_samples: Number of samples to generate
max_new_tokens: Maximum new tokens to generate
temperature: Sampling temperature (0.0 = greedy)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
do_sample: Whether to use sampling vs greedy decoding
prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0)
Returns:
List of dicts with 'prompt', 'generated', and 'full_text' keys
"""
unwrapped_model = extract_model_from_parallel(model)
training = unwrapped_model.training
unwrapped_model.eval()
device = next(unwrapped_model.parameters()).device
generations = []
try:
with torch.no_grad():
samples_collected = 0
for batch in dataloader:
if samples_collected >= num_generation_samples:
break
input_ids = batch["input_ids"].to(device)
attention_mask = batch.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(device)
batch_size = input_ids.shape[0]
indices = torch.randperm(batch_size)[
: num_generation_samples - samples_collected
]
for idx in indices:
if samples_collected >= num_generation_samples:
break
sequence = input_ids[idx]
if attention_mask is not None:
seq_len = attention_mask[idx].sum().item()
else:
seq_len = sequence.shape[0]
if seq_len < 5:
continue
prompt_len = max(1, int(seq_len * prompt_ratio))
prompt_ids = sequence[:prompt_len].unsqueeze(0)
try:
generation_config = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"pad_token_id": tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id,
}
if do_sample:
generation_config["temperature"] = temperature
if top_p is not None:
generation_config["top_p"] = top_p
if top_k is not None:
generation_config["top_k"] = top_k
generated_ids = unwrapped_model.generate(
prompt_ids, **generation_config
)
prompt_text = tokenizer.decode(
prompt_ids[0], skip_special_tokens=True
)
generated_text = tokenizer.decode(
generated_ids[0][prompt_len:], skip_special_tokens=True
)
full_text = tokenizer.decode(
generated_ids[0], skip_special_tokens=True
)
generations.append(
{
"prompt": prompt_text,
"generated": generated_text,
"full_text": full_text,
}
)
samples_collected += 1
except Exception as e:
LOG.warning(f"Failed to generate sample: {e}", exc_info=True)
continue
except Exception as e:
LOG.warning(f"Error during sample generation: {e}", exc_info=True)
if training:
unwrapped_model.train()
else:
unwrapped_model.eval()
return generations
def format_generation_for_logging(
sample: dict, sample_idx: int, step: int
) -> tuple[str, str]:
"""
Format a generation sample for pretty logging.
Args:
sample: Dict with 'prompt', 'generated', and 'full_text' keys
sample_idx: Index of the sample
step: Current training step
Returns:
Tuple of (console_text, wandb_text)
"""
console_text = (
f"\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\n{sample['prompt']}\n\n"
f"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\n{sample['generated']}\n"
f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
)
wandb_text = (
f"\n{'=' * 80}\n"
f"Sample {sample_idx + 1} (Step {step})\n"
f"{'=' * 80}\n"
f"[PROMPT]\n{sample['prompt']}\n\n"
f"[GENERATED]\n{sample['generated']}\n"
f"{'=' * 80}\n"
)
return console_text, wandb_text
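
A hedged usage sketch for the two deleted helpers above (assumes a model, tokenizer, and dataloader prepared elsewhere):

samples = generate_samples(model, tokenizer, dataloader,
                           num_generation_samples=2, max_new_tokens=32)
for i, sample in enumerate(samples):
    console_text, wandb_text = format_generation_for_logging(sample, i, step=0)
    print(console_text)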

View File

@@ -30,8 +30,18 @@ class Mistral3Processor(ProcessorMixin):
Wraps HFMistralTokenizer and adds image processing capabilities. Wraps HFMistralTokenizer and adds image processing capabilities.
""" """
# TODO(nano): This should be removed in transformers V5
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"
def __init__(self, tokenizer: HFMistralTokenizer): def __init__(self, tokenizer: HFMistralTokenizer):
super().__init__(tokenizer) # Don't call super().__init__ to avoid the class validation issue
self.tokenizer = tokenizer
@property
def chat_template(self) -> None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
@property @property
def audio_tokenizer(self) -> None: def audio_tokenizer(self) -> None:

View File

@@ -338,6 +338,18 @@ class AxolotlInputConfig(
) )
ddp_find_unused_parameters: bool | None = None ddp_find_unused_parameters: bool | None = None
eval_table_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
},
)
eval_max_new_tokens: int | None = Field(
default=None,
json_schema_extra={
"description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
},
)
do_causal_lm_eval: bool | None = Field( do_causal_lm_eval: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -434,16 +446,7 @@ class AxolotlInputConfig(
}, },
) )
unfrozen_parameters: list[str] | None = Field( unfrozen_parameters: list[str] | None = None
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
sequence_len: int = Field( sequence_len: int = Field(
default=512, default=512,
@@ -1094,46 +1097,6 @@ class AxolotlInputConfig(
"description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html" "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
}, },
) )
generate_samples: bool | None = Field(
default=False,
json_schema_extra={
"description": "Enable sample generation during training for monitoring"
},
)
num_generation_samples: int | None = Field(
default=3,
json_schema_extra={
"description": "Number of samples to generate at each interval"
},
)
generation_max_new_tokens: int | None = Field(
default=50,
json_schema_extra={"description": "Maximum new tokens to generate per sample"},
)
generation_temperature: float | None = Field(
default=0.7,
json_schema_extra={
"description": "Temperature for sample generation (0.0 = greedy)"
},
)
generation_top_p: float | None = Field(
default=None,
json_schema_extra={"description": "Nucleus sampling parameter for generation"},
)
generation_top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling parameter for generation"},
)
generation_prompt_ratio: float | None = Field(
default=0.5,
json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"},
)
generation_do_sample: bool | None = Field(
default=True,
json_schema_extra={
"description": "Whether to use sampling (vs greedy decoding)"
},
)
@field_serializer("datasets") @field_serializer("datasets")
def datasets_serializer( def datasets_serializer(
@@ -1509,16 +1472,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"dataset_exact_deduplication is not available for streaming datasets. " "dataset_exact_deduplication is not available for streaming datasets. "
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_deduplication_with_skip_prepare(cls, data):
if data.get("dataset_exact_deduplication") and data.get("skip_prepare_dataset"):
raise ValueError(
"dataset_exact_deduplication=True has no effect when "
"skip_prepare_dataset=True. Deduplication runs as part of the "
"prepare pipeline, which is skipped. Either set "
"skip_prepare_dataset: false or disable "
"dataset_exact_deduplication."
)
return data
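
For context, the removed fields above correspond to config keys like the following (values are the defaults declared in the fields; the mapping is shown for illustration only):

generation_cfg = {
    "generate_samples": True,          # enable sampling during training
    "num_generation_samples": 3,
    "generation_max_new_tokens": 50,
    "generation_temperature": 0.7,
    "generation_do_sample": True,
    "generation_prompt_ratio": 0.5,
}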

View File

@@ -17,8 +17,6 @@ class DeprecatedParameters(BaseModel):
noisy_embedding_alpha: float | None = None noisy_embedding_alpha: float | None = None
dpo_beta: float | None = None dpo_beta: float | None = None
evaluation_strategy: str | None = None evaluation_strategy: str | None = None
eval_table_size: int | None = None
eval_max_new_tokens: int | None = None
@field_validator("max_packed_sequence_len") @field_validator("max_packed_sequence_len")
@classmethod @classmethod
@@ -57,27 +55,6 @@ class DeprecatedParameters(BaseModel):
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy return evaluation_strategy
@field_validator("eval_table_size")
@classmethod
def validate_eval_table_size(cls, eval_table_size):
if eval_table_size is not None:
LOG.warning(
"eval_table_size is deprecated and superseded by generate_samples config. "
"Please use generate_samples: true and num_generation_samples instead. "
"The LogPredictionCallback is replaced by the new sample generation feature."
)
return eval_table_size
@field_validator("eval_max_new_tokens")
@classmethod
def validate_eval_max_new_tokens(cls, eval_max_new_tokens):
if eval_max_new_tokens is not None:
LOG.warning(
"eval_max_new_tokens is deprecated and superseded by generate_samples config. "
"Please use generation_max_new_tokens instead."
)
return eval_max_new_tokens
class RemappedParameters(BaseModel): class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names""" """Parameters that have been remapped to other names"""

View File

@@ -1,6 +1,6 @@
"""Pydantic models for PEFT-related configuration""" """Pydantic models for PEFT-related configuration"""
from typing import Any, Literal from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"} default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
) )
adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field( adapter: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model" "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
}, },
) )
lora_model_dir: str | None = Field( lora_model_dir: str | None = Field(
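
The Literal annotation in the left column rejects unknown adapter names at validation time, while the plain str accepts any string. A tiny, hypothetical sketch of that difference:

from typing import Literal
from pydantic import BaseModel

class AdapterCfg(BaseModel):
    # Hypothetical stand-in for the field above; with Literal, an unknown
    # adapter name fails validation instead of being accepted silently.
    adapter: Literal["lora", "qlora", "llama-adapter"] | None = None

AdapterCfg(adapter="qlora")    # validates
# AdapterCfg(adapter="bogus")  # would raise pydantic.ValidationError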

View File

@@ -205,13 +205,10 @@ def add_length(sample):
return sample return sample
def filter_sequences_by_length( def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
):
""" """
Filter sequences outside valid length range [min_sequence_len, sequence_len]. Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]). Works for both single-example (list[int]) or batched (list[list[int]]).
@@ -386,10 +383,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing( def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
): ):
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len) drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_outside_range, drop_long,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
load_from_cache_file=False, load_from_cache_file=False,
) )
@@ -483,7 +480,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size, bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially, sequential=cfg.sample_packing_sequentially,
drop_last=True, drop_last=True,
 num_processes=cfg.dataset_num_proc, num_processes=cfg.dataset_num_proc,
mp_start_method=cfg.sample_packing_mp_start_method or "fork", mp_start_method=cfg.sample_packing_mp_start_method or "fork",
) )

View File

@@ -1,227 +0,0 @@
"""Tests for nested config option handling via CLI dot-notation."""
import click
from click.testing import CliRunner
from pydantic import BaseModel, Field
from axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs
class InnerConfig(BaseModel):
"""A nested config model for testing."""
beta: float | None = Field(
default=None,
description="Beta parameter.",
)
host: str | None = Field(
default=None,
description="Server host.",
)
use_feature: bool = Field(
default=False,
description="Whether to use the feature.",
)
class OuterConfig(BaseModel):
"""A top-level config model for testing."""
learning_rate: float | None = Field(
default=None,
description="Learning rate.",
)
inner: InnerConfig | None = Field(
default=None,
description="Inner config.",
)
name: str | None = Field(
default=None,
description="Model name.",
)
class TestAddOptionsFromConfigNested:
"""Test that add_options_from_config handles nested BaseModel fields."""
def setup_method(self):
self.runner = CliRunner()
def test_nested_dot_notation_options_are_registered(self):
"""Nested model fields should create --parent.child CLI options."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.beta=0.5", "--inner.host=localhost"])
assert result.exit_code == 0, result.output
assert "inner__beta=0.5" in result.output
assert "inner__host=localhost" in result.output
def test_nested_bool_option(self):
"""Nested bool fields should support --parent.field/--no-parent.field."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.use-feature"])
assert result.exit_code == 0, result.output
assert "inner__use_feature=True" in result.output
def test_flat_and_nested_options_together(self):
"""Flat and nested options should work together."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(
cmd, ["--learning-rate=0.001", "--inner.beta=0.1", "--name=test"]
)
assert result.exit_code == 0, result.output
assert "learning_rate=0.001" in result.output
assert "inner__beta=0.1" in result.output
assert "name=test" in result.output
def test_no_nested_options_passed(self):
"""When no nested options are passed, they should not appear in kwargs."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
click.echo(f"keys={sorted(kwargs.keys())}")
result = self.runner.invoke(cmd, ["--learning-rate=0.01"])
assert result.exit_code == 0, result.output
assert "inner__" not in result.output
class TestLoadCfgNestedKwargs:
"""Test that load_cfg correctly applies nested (double-underscore) kwargs."""
@staticmethod
def _apply_nested_kwargs(cfg, kwargs):
"""Helper that mirrors the nested kwargs handling from load_cfg,
including type coercion for string CLI values."""
from axolotl.cli.config import _coerce_value
nested_kwargs: dict = {}
flat_kwargs: dict = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
cfg_keys = cfg.keys()
for key, value in flat_kwargs.items():
if key in cfg_keys:
cfg[key] = _coerce_value(value, cfg.get(key))
for parent, children in nested_kwargs.items():
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
cfg[parent] = {}
for child_key, child_value in children.items():
existing = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing)
return cfg
def test_nested_kwargs_applied_to_cfg(self, tmp_path):
"""Double-underscore kwargs should set nested config values."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": {"beta": 0.1}, "learning_rate": 0.01})
# CLI passes strings, so simulate that
kwargs = {
"trl__beta": "0.5",
"trl__host": "192.168.1.1",
"learning_rate": "0.02",
}
cfg = self._apply_nested_kwargs(cfg, kwargs)
assert cfg["learning_rate"] == 0.02
assert isinstance(cfg["learning_rate"], float)
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
assert cfg["trl"]["host"] == "192.168.1.1"
def test_nested_kwargs_creates_parent_if_none(self):
"""If the parent key is None, nested kwargs should create the dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": None, "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
# No existing value, YAML-style inference: "0.5" -> 0.5
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
def test_nested_kwargs_overwrites_string_parent(self):
"""If the parent key is a string, it should be replaced with a dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": "some_string", "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
assert cfg["trl"]["beta"] == 0.5
class TestCoerceValue:
"""Test YAML-style type coercion for CLI string values."""
def test_coerce_with_existing_float(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("0.5", 0.1) == 0.5
assert isinstance(_coerce_value("0.5", 0.1), float)
def test_coerce_with_existing_int(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("42", 10) == 42
assert isinstance(_coerce_value("42", 10), int)
def test_coerce_with_existing_bool(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", False) is True
assert _coerce_value("false", True) is False
assert _coerce_value("1", False) is True
assert _coerce_value("0", True) is False
def test_coerce_yaml_inference_no_existing(self):
"""Without an existing value, use YAML-style inference."""
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", None) is True
assert _coerce_value("false", None) is False
assert _coerce_value("42", None) == 42
assert isinstance(_coerce_value("42", None), int)
assert _coerce_value("3.14", None) == 3.14
assert isinstance(_coerce_value("3.14", None), float)
assert _coerce_value("null", None) is None
assert _coerce_value("hello", None) == "hello"
def test_coerce_non_string_passthrough(self):
"""Non-string values should pass through unchanged."""
from axolotl.cli.config import _coerce_value
assert _coerce_value(0.5, 0.1) == 0.5
assert _coerce_value(True, False) is True
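
A hedged sketch of the coercion behavior these tests describe (coerce_like_yaml is hypothetical, not the repo's _coerce_value; it assumes PyYAML for scalar inference):

import yaml

def coerce_like_yaml(value, existing=None):
    if not isinstance(value, str):
        return value  # non-string CLI values pass through unchanged
    if isinstance(existing, bool):
        return value.lower() in ("1", "true", "yes")
    if isinstance(existing, int):
        return int(value)
    if isinstance(existing, float):
        return float(value)
    return yaml.safe_load(value)  # "true" -> True, "42" -> 42, "null" -> None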

View File

@@ -300,6 +300,7 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl) self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific # ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer): def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer) builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -186,7 +186,6 @@ class TestFSDP1:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -365,7 +365,6 @@ class TestFSDP2:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0 @require_torch_2_7_0
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
@@ -423,7 +422,6 @@ class TestFSDP2:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0 @require_torch_2_7_0
def test_dpo_lora(self, temp_dir): def test_dpo_lora(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -1,323 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
"""
from unittest.mock import patch
import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
"""
def test_disables_lora_mlp_kernel_when_scattermoe(self):
"""lora_mlp_kernel=True gets set to False when use_scattermoe=True."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
def test_mlp_kernel_disabled_without_lora(self):
"""Even without lora_mlp_kernel, mlp_kernel should be disabled."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
def test_lora_mlp_kernel_false_unchanged(self):
"""lora_mlp_kernel=False should stay False (no warning, no change)."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
"""When use_scattermoe is not True, nothing should be changed."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is True
class TestParallelExpertsScaling:
"""Test that scaling=0.0 is preserved and not overridden to 1.0."""
def test_scaling_zero_preserved(self):
"""scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
pe.set_lora(
lora_A=torch.randn(4, 4),
lora_B=torch.randn(4, 4),
scaling=0.0,
)
assert pe._lora_scaling == 0.0
# Patch parallel_linear_lora to capture the scaling arg
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
# Create dummy routing tensors
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
# Check that scaling=0.0 was passed, not 1.0
call_kwargs = mock_pll.call_args
assert (
call_kwargs.kwargs.get("scaling") == 0.0
or call_kwargs[1].get("scaling") == 0.0
), f"Expected scaling=0.0 but got {call_kwargs}"
def test_scaling_none_defaults_to_one(self):
"""scaling=None (no LoRA attached) should default to 1.0."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
# No set_lora called, so _lora_scaling is None
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
call_kwargs = mock_pll.call_args
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
"scaling"
)
assert scaling_val == 1.0, (
f"Expected scaling=1.0 for None but got {scaling_val}"
)
def test_scaling_positive_preserved(self):
"""Normal positive scaling should be preserved."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
pe.set_lora(
lora_A=torch.randn(4, 4),
lora_B=torch.randn(4, 4),
scaling=0.5,
)
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
call_kwargs = mock_pll.call_args
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
"scaling"
)
assert scaling_val == 0.5
# ============================================================================
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSingle2ScatterBounds:
"""Test single2scatter with non-aligned dimensions."""
def test_non_aligned_k(self):
"""K not a multiple of BLOCK_K should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 100, 128 # K=100 not a multiple of 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
# Verify against manual computation
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
def test_non_aligned_n(self):
"""N not a multiple of BLOCK_N should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 128, 100 # N=100 not a multiple of 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
def test_non_aligned_both(self):
"""Both K and N not aligned should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 100, 100 # Neither aligned to 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
# ============================================================================
# 5. group_compileable: coeff=None accepted
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGroupCoeffNone:
"""Test that group() works with coeff=None."""
def test_group_with_none_coeff(self):
"""group() should accept coeff=None without errors."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
M, K = 4, 32
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
# This should not raise a TypeError
Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1)
assert Y.shape == (M, K)
def test_group_with_coeff(self):
"""group() should also work with actual coeff values."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
M, K = 4, 32
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
coeff = torch.ones(M, device="cuda", dtype=torch.float32) * 0.5
Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1)
assert Y.shape == (M, K)
# ============================================================================
# 6. Layer return value contracts
# ============================================================================
class TestLayerReturnValues:
"""Test that layer forward methods return the correct types."""
def test_hf_scatter_moe_returns_single_tensor(self):
"""HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple."""
pytest.importorskip("triton")
# Verify the forward method signature and return annotation
import inspect
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
)
sig = inspect.signature(HFScatterMoEGatedMLP.forward)
# It's a staticmethod taking (self, layer_input)
params = list(sig.parameters.keys())
assert "self" in params
assert "layer_input" in params
def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
"""ScatterMoEGatedMLP.forward docstring should not mention router logits as return."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
ScatterMoEGatedMLP,
)
docstring = ScatterMoEGatedMLP.forward.__doc__
assert docstring is not None
# The docstring should mention output tensor but NOT router logits
assert "Output tensor" in docstring or "output tensor" in docstring.lower()
assert "Router logits" not in docstring, (
"Docstring should not mention 'Router logits' in Returns section"
)
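
The single2scatter checks above compare against a per-expert matmul. A hedged reference implementation of that comparison (single2scatter_reference is illustrative, not the kernel):

import torch

def single2scatter_reference(X, W, expert_idxs):
    # Route the single input row through each selected expert's weights:
    # X is (1, K), W is (E, K, N), expert_idxs is (1, n_selected).
    return torch.stack([X[0] @ W[e] for e in expert_idxs[0]])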

View File

@@ -7,7 +7,7 @@ import unittest
from transformers import LlamaTokenizer from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5 from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import filter_sequences_by_length from axolotl.utils.trainer import drop_long_seq
from tests.hf_offline_utils import enable_hf_offline from tests.hf_offline_utils import enable_hf_offline
@@ -70,19 +70,17 @@ class TestEncodePretraining(unittest.TestCase):
# -- single sequence -- # -- single sequence --
# This should work # This should work
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]} data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
filter_sequences_by_length(data, 32, raise_on_drop=True) drop_long_seq(data, 32, raise_on_drop=True)
# This should return True, since data fits # This should return True, since data fits
dropped = filter_sequences_by_length(data, 32) dropped = drop_long_seq(data, 32)
self.assertTrue(dropped) self.assertTrue(dropped)
# This should raise # This should raise
self.assertRaises( self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should return False, since data doesn't fit # This should return False, since data doesn't fit
dropped = filter_sequences_by_length(data, 15) dropped = drop_long_seq(data, 15)
self.assertFalse(dropped) self.assertFalse(dropped)
# -- batch sequence -- # -- batch sequence --
@@ -93,15 +91,13 @@ class TestEncodePretraining(unittest.TestCase):
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
] ]
} }
filter_sequences_by_length(data, 32, raise_on_drop=True) drop_long_seq(data, 32, raise_on_drop=True)
# This should raise # This should raise
self.assertRaises( self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should keep the first but drop the second entry # This should keep the first but drop the second entry
dropped = filter_sequences_by_length(data, 15) dropped = drop_long_seq(data, 15)
self.assertEqual(dropped, [True, False]) self.assertEqual(dropped, [True, False])

View File

@@ -1,135 +0,0 @@
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
from unittest.mock import MagicMock, patch
from transformers import PreTrainedTokenizerBase
from axolotl.utils.dict import DictDefault
class TestRevisionParameter:
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
@patch("axolotl.loaders.tokenizer.load_model_config")
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch(
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
)
def test_load_tokenizer_passes_revision(
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
):
mock_tokenizer = MagicMock()
mock_tokenizer.__class__.__name__ = "MockTokenizer"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
cfg = DictDefault(
{
"tokenizer_config": "some-model",
"revision_of_model": "abc123",
}
)
from axolotl.loaders.tokenizer import load_tokenizer
load_tokenizer(cfg)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.tokenizer.load_model_config")
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch(
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
)
def test_load_tokenizer_omits_revision_when_unset(
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
):
mock_tokenizer = MagicMock()
mock_tokenizer.__class__.__name__ = "MockTokenizer"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
cfg = DictDefault(
{
"tokenizer_config": "some-model",
}
)
from axolotl.loaders.tokenizer import load_tokenizer
load_tokenizer(cfg)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
@patch("axolotl.loaders.tokenizer.barrier")
def test_modify_tokenizer_files_passes_revision(
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
):
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
from axolotl.loaders.tokenizer import modify_tokenizer_files
modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
@patch("axolotl.loaders.tokenizer.barrier")
def test_modify_tokenizer_files_defaults_revision_to_main(
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
):
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
from axolotl.loaders.tokenizer import modify_tokenizer_files
modify_tokenizer_files("some-model", {}, output_dir=temp_dir)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "main"
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_passes_revision(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"revision_of_model": "abc123",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs

View File

@@ -1,210 +0,0 @@
"""Tests to verify that deduplication runs before dataset saving during preprocessing.
This addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.
"""
from unittest.mock import MagicMock, patch
from datasets import Dataset
from axolotl.utils.dict import DictDefault
class TestSFTSaveDeduplicatedBeforeSave:
"""Verify that in SFT data loading, deduplication occurs before saving."""
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.sft.deduplicate_and_log_datasets")
@patch("axolotl.utils.data.sft.merge_datasets")
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
def test_dedup_called_before_save_sft(
self,
mock_datasets_gen,
mock_load_single,
mock_merge,
mock_dedup,
mock_gen_hash,
mock_save,
):
"""Deduplication should be called before save_preprocessed_dataset in SFT."""
from axolotl.utils.data.sft import _load_raw_datasets
# Set up mock data
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
deduped_dataset = Dataset.from_dict({"text": ["a", "b"], "label": [1, 2]})
mock_datasets_gen.return_value = [
DictDefault({"path": "test", "type": "alpaca"})
]
mock_load_single.return_value = (dataset, None)
mock_merge.return_value = dataset
mock_dedup.return_value = (deduped_dataset, None)
mock_gen_hash.return_value = "testhash"
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": True,
"sequence_len": 1024,
"eval_sequence_len": None,
"sample_packing": False,
"is_preprocess": False,
"seed": 42,
"datasets": [{"path": "test", "type": "alpaca"}],
}
)
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
# Track call order
call_order = []
mock_dedup.side_effect = lambda **kwargs: (
call_order.append("dedup") or (deduped_dataset, None)
)
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
_load_raw_datasets(
cfg=cfg,
datasets_configs=cfg.datasets,
tokenizer=tokenizer,
split="train",
)
# Verify dedup was called
assert "dedup" in call_order, "Deduplication should have been called"
# Verify save was called
assert "save" in call_order, "Save should have been called"
# Verify dedup happened before save
assert call_order.index("dedup") < call_order.index("save"), (
"Deduplication must occur before saving the dataset"
)
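The ordering assertion works because each patched function's side_effect appends a marker to a shared list before returning; a minimal sketch of that technique in isolation:
from unittest.mock import MagicMock
call_order = []
dedup = MagicMock(side_effect=lambda **kwargs: call_order.append("dedup") or "deduped")
save = MagicMock(side_effect=lambda *args, **kwargs: call_order.append("save"))
deduped = dedup(dataset="raw")  # list.append returns None, so `or` yields "deduped"
save(deduped)
assert call_order.index("dedup") < call_order.index("save")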
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.sft.merge_datasets")
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
def test_no_dedup_when_disabled_sft(
self,
mock_datasets_gen,
mock_load_single,
mock_merge,
mock_gen_hash,
mock_save,
):
"""Deduplication should not be called when dataset_exact_deduplication is False."""
from axolotl.utils.data.sft import _load_raw_datasets
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
mock_datasets_gen.return_value = [
DictDefault({"path": "test", "type": "alpaca"})
]
mock_load_single.return_value = (dataset, None)
mock_merge.return_value = dataset
mock_gen_hash.return_value = "testhash"
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": False,
"sequence_len": 1024,
"eval_sequence_len": None,
"sample_packing": False,
"is_preprocess": False,
"seed": 42,
"datasets": [{"path": "test", "type": "alpaca"}],
}
)
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
with patch("axolotl.utils.data.sft.deduplicate_and_log_datasets") as mock_dedup:
_load_raw_datasets(
cfg=cfg,
datasets_configs=cfg.datasets,
tokenizer=tokenizer,
split="train",
)
mock_dedup.assert_not_called()
class TestRLSaveDeduplicatedBeforeSave:
"""Verify that in RL data loading, deduplication occurs before saving."""
@patch.object(Dataset, "filter", lambda self, *args, **kwargs: self)
@patch("axolotl.utils.data.rl.save_preprocessed_dataset")
@patch("axolotl.utils.data.rl.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.rl.deduplicate_and_log_datasets")
@patch("axolotl.utils.data.rl.merge_datasets")
@patch("axolotl.utils.data.rl.load_dataset_with_config")
@patch("axolotl.utils.data.rl.datasets_with_name_generator")
@patch("axolotl.utils.data.rl.load_tokenizer")
def test_dedup_called_before_save_rl(
self,
mock_load_tokenizer,
mock_datasets_gen,
mock_load_dataset,
mock_merge,
mock_dedup,
mock_gen_hash,
mock_save,
):
"""Deduplication should be called before save_preprocessed_dataset in RL."""
from axolotl.utils.data.rl import _load_split
dataset = Dataset.from_dict(
{
"prompt": ["hi", "bye", "hi"],
"chosen": ["a", "b", "a"],
"rejected": ["c", "d", "c"],
}
)
deduped_dataset = Dataset.from_dict(
{
"prompt": ["hi", "bye"],
"chosen": ["a", "b"],
"rejected": ["c", "d"],
}
)
mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})]
mock_load_dataset.return_value = dataset
mock_merge.return_value = dataset
mock_dedup.return_value = (deduped_dataset, None)
mock_gen_hash.return_value = "testhash"
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
mock_load_tokenizer.return_value = tokenizer
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": True,
"sequence_len": 1024,
"rl": "dpo",
"datasets": [{"path": "test", "type": None}],
"hf_use_auth_token": False,
"dataset_num_proc": 1,
"is_preprocess": False,
}
)
call_order = []
mock_dedup.side_effect = lambda **kwargs: (
call_order.append("dedup") or (deduped_dataset, None)
)
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
_load_split(cfg, split="train")
assert "dedup" in call_order, "Deduplication should have been called"
assert "save" in call_order, "Save should have been called"
assert call_order.index("dedup") < call_order.index("save"), (
"Deduplication must occur before saving the dataset"
)
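The class-level patch at the top of this test swaps Dataset.filter for an identity function so RL-specific filtering cannot interfere with the ordering check; the same trick in isolation (illustrative only):
from unittest.mock import patch
from datasets import Dataset
ds = Dataset.from_dict({"x": [1, 2, 3]})
with patch.object(Dataset, "filter", lambda self, *args, **kwargs: self):
    # Any filter call now returns the dataset object unchanged.
    assert ds.filter(lambda row: row["x"] > 99) is ds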

View File

@@ -116,7 +116,6 @@ class TestTokenizers:
tokenizer.decode([128041, 128042]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
)
@pytest.mark.skip("FIXME slow test sdist py3.11 + torch2.8.0")
@enable_hf_offline
def test_added_tokens_overrides_gemma3(self, temp_dir):
cfg = DictDefault(

View File

@@ -1,545 +0,0 @@
"""
Unit tests for data utility functions
"""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
class TestHandleLongSeqInDataset(unittest.TestCase):
"""
Test class for handle_long_seq_in_dataset function
"""
def test_drop_strategy_removes_long_sequences(self):
"""Test that 'drop' strategy removes sequences longer than sequence_len"""
# Create dataset with mixed length sequences
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # length 3 - keep
[1, 2, 3, 4, 5], # length 5 - keep
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - drop
[1, 2], # length 2 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have dropped the sequence with length 11
self.assertEqual(len(result), 3)
self.assertEqual(len(result[0]["input_ids"]), 3)
self.assertEqual(len(result[1]["input_ids"]), 5)
self.assertEqual(len(result[2]["input_ids"]), 2)
def test_drop_strategy_is_default(self):
"""Test that 'drop' is the default strategy when not specified"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should drop
]
}
)
cfg = DictDefault(
{
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have dropped the long sequence
self.assertEqual(len(result), 1)
def test_truncate_strategy_truncates_long_sequences(self):
"""Test that 'truncate' strategy truncates sequences to sequence_len"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # length 3 - keep as is
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
], # length 12 - truncate to 10
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have 2 samples
self.assertEqual(len(result), 2)
# First sample unchanged
self.assertEqual(len(result[0]["input_ids"]), 3)
# Second sample truncated to 10
self.assertEqual(len(result[1]["input_ids"]), 10)
self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_truncate_strategy_truncates_all_auxiliary_fields(self):
"""Test that truncation applies to all auxiliary fields consistently"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"attention_mask": [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
"labels": [
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"position_ids": [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
],
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# All fields should be truncated to 10
self.assertEqual(len(result[0]["input_ids"]), 10)
self.assertEqual(len(result[0]["attention_mask"]), 10)
self.assertEqual(len(result[0]["labels"]), 10)
self.assertEqual(len(result[0]["position_ids"]), 10)
# Verify content is correct
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
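A hedged sketch of the truncation behaviour these assertions describe — every list-valued column is cut to the same length so input_ids, attention_mask, labels and position_ids stay aligned (truncate_example is a hypothetical helper, not the axolotl code):
def truncate_example(example: dict, sequence_len: int) -> dict:
    # Truncate all list-valued fields uniformly; scalar fields pass through.
    return {
        key: value[:sequence_len] if isinstance(value, list) else value
        for key, value in example.items()
    }
row = {
    "input_ids": list(range(1, 13)),
    "labels": [-100, -100] + list(range(3, 13)),
}
assert truncate_example(row, 10)["input_ids"] == list(range(1, 11))
assert truncate_example(row, 10)["labels"] == [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10]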
def test_raise_strategy_raises_on_long_sequences(self):
"""Test that 'raise' strategy raises ValueError when encountering long sequences"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should raise
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "raise",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
with self.assertRaises(ValueError):
handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
def test_min_sequence_len_filters_short_sequences(self):
"""Test that sequences shorter than min_sample_len are filtered out"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - drop (< min_sample_len=3)
[1, 2], # length 2 - drop
[1, 2, 3], # length 3 - keep
[1, 2, 3, 4, 5], # length 5 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 3,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should only keep sequences with length >= 3
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]["input_ids"]), 3)
self.assertEqual(len(result[1]["input_ids"]), 5)
def test_dataset_without_input_ids_column(self):
"""Test that datasets without 'input_ids' column are returned unchanged"""
dataset = Dataset.from_dict(
{
"chosen": [1, 2, 3],
"rejected": [4, 5, 6],
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Dataset should be unchanged
self.assertEqual(len(result), len(dataset))
self.assertListEqual(list(result.column_names), ["chosen", "rejected"])
def test_truncate_filters_short_before_truncating(self):
"""Test that truncate strategy filters short sequences before truncating long ones
This is important for efficiency - we should not waste time truncating
sequences that will be filtered out anyway.
"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - filter out first
[1, 2, 3], # length 3 - keep, no truncation needed
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
], # length 12 - keep and truncate
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have filtered out the first (short) sequence
self.assertEqual(len(result), 2)
# Second sample unchanged
self.assertEqual(len(result[0]["input_ids"]), 3)
# Third sample truncated to 10
self.assertEqual(len(result[1]["input_ids"]), 10)
def test_case_insensitive_strategy(self):
"""Test that excess_length_strategy is case-insensitive"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "TRUNCATE", # uppercase
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should still truncate
self.assertEqual(len(result[0]["input_ids"]), 10)
def test_raise_strategy_silently_drops_short_sequences(self):
"""Test that 'raise' strategy drops short sequences without raising"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - too short, should be dropped silently
[1, 2, 3, 4, 5], # length 5 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "raise",
"min_sample_len": 3,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
# Should NOT raise, just silently drop the short sequence
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 5)
def test_drop_boundary_sequence_equal_to_sequence_len(self):
"""Test that drop strategy keeps sequences with length exactly equal to sequence_len"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 > sequence_len
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Exactly equal should be kept, one over should be dropped
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 10)
def test_truncate_boundary_sequence_equal_to_sequence_len(self):
"""Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should be unchanged - not truncated
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_empty_dataset(self):
"""Test that an empty dataset is handled gracefully"""
dataset = Dataset.from_dict({"input_ids": []})
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 0)
def test_all_sequences_dropped_returns_empty_dataset(self):
"""Test that dropping all sequences results in an empty dataset"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # too short
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # too long
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 5,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 0)
def test_iterable_dataset_skips_processing(self):
"""Test that streaming datasets (column_names is None) are returned unchanged.
The skip check in _should_skip_processing triggers when column_names is
None, which happens with true streaming datasets loaded via
load_dataset(..., streaming=True).
"""
mock_dataset = MagicMock()
mock_dataset.column_names = None
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)
# Should be returned unchanged (same object)
self.assertIs(result, mock_dataset)
def test_truncate_with_partial_auxiliary_fields(self):
"""Test truncation when only some auxiliary fields are present"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"labels": [
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
# No attention_mask or position_ids
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result[0]["input_ids"]), 10)
self.assertEqual(len(result[0]["labels"]), 10)
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
# Confirm no extra columns were introduced
self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"])
def test_min_sample_len_defaults_to_two_when_not_set(self):
"""Test that min_sample_len defaults to 2 when not specified in config"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - should be dropped (< default 2)
[1, 2], # length 2 - should be kept (>= default 2)
[1, 2, 3], # length 3 - should be kept
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
# min_sample_len not set
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]["input_ids"]), 2)
self.assertEqual(len(result[1]["input_ids"]), 3)
def test_invalid_strategy_falls_through_to_drop(self):
"""Test that an unrecognized strategy value falls through to drop behavior"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # keep
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
], # length 11 - should be dropped
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "not_a_real_strategy",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should behave like 'drop'
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 3)
if __name__ == "__main__":
unittest.main()
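Taken together, the suite above fixes the contract for drop/truncate/raise handling; a compact sketch of a function satisfying it, written against datasets.Dataset (handle_long_seq_sketch is an assumed stand-in, not the real handle_long_seq_in_dataset):
from datasets import Dataset
def handle_long_seq_sketch(dataset, sequence_len, strategy="drop", min_sample_len=2):
    if dataset.column_names is None or "input_ids" not in dataset.column_names:
        return dataset  # streaming or non-tokenized dataset: leave untouched
    # Samples shorter than min_sample_len are always dropped, for every strategy.
    dataset = dataset.filter(lambda ex: len(ex["input_ids"]) >= min_sample_len)
    strategy = (strategy or "drop").lower()
    if strategy == "truncate":
        return dataset.map(
            lambda ex: {
                k: v[:sequence_len] if isinstance(v, list) else v for k, v in ex.items()
            }
        )
    if strategy == "raise":
        if any(len(ids) > sequence_len for ids in dataset["input_ids"]):
            raise ValueError(f"Found sample longer than sequence_len={sequence_len}")
        return dataset
    # "drop" and unrecognized strategies fall through to dropping over-long samples.
    return dataset.filter(lambda ex: len(ex["input_ids"]) <= sequence_len)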

View File

@@ -1,149 +0,0 @@
"""Tests for Mistral3Processor with transformers v5 ProcessorMixin integration"""
from unittest.mock import MagicMock
import pytest
import torch
from transformers.feature_extraction_utils import BatchFeature
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
@pytest.fixture()
def mock_tokenizer():
"""Create a mock HFMistralTokenizer that passes v5 ProcessorMixin isinstance checks."""
return MagicMock(spec=HFMistralTokenizer)
@pytest.fixture()
def processor(mock_tokenizer):
return Mistral3Processor(tokenizer=mock_tokenizer)
class TestMistral3ProcessorInit:
def test_tokenizer_is_set(self, processor, mock_tokenizer):
assert processor.tokenizer is mock_tokenizer
def test_chat_template_is_none(self, processor):
assert processor.chat_template is None
def test_audio_tokenizer_is_none(self, processor):
assert processor.audio_tokenizer is None
class TestApplyChatTemplateTokenized:
"""Test apply_chat_template with tokenize=True, return_dict=True"""
@pytest.fixture()
def batched_conversations(self):
return [
[
{"role": "user", "content": "Describe this image."},
{"role": "assistant", "content": "It is red."},
],
[
{"role": "user", "content": "What is this?"},
{"role": "assistant", "content": "A cat."},
],
]
def test_returns_batch_feature_with_pixel_values(
self, processor, mock_tokenizer, batched_conversations
):
pixel_values = torch.randn(2, 3, 224, 224, dtype=torch.float64)
mock_tokenizer.apply_chat_template.return_value = {
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
"attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
"pixel_values": pixel_values,
}
result = processor.apply_chat_template(
batched_conversations, tokenize=True, return_dict=True
)
assert isinstance(result, BatchFeature)
assert "pixel_values" in result
assert "image_sizes" in result
assert result["pixel_values"].dtype == torch.float32
assert result["image_sizes"].shape == (2, 2)
assert result["image_sizes"][0].tolist() == [224, 224]
def test_returns_batch_feature_without_pixel_values(
self, processor, mock_tokenizer, batched_conversations
):
mock_tokenizer.apply_chat_template.return_value = {
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
"attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
}
result = processor.apply_chat_template(
batched_conversations, tokenize=True, return_dict=True
)
assert isinstance(result, BatchFeature)
assert "input_ids" in result
assert "image_sizes" not in result
class TestApplyChatTemplateNotTokenized:
def test_single_conversation_returns_unwrapped(self, processor, mock_tokenizer):
"""Single conversation (not batched) should return unwrapped result."""
single_conversation = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
mock_tokenizer.apply_chat_template.return_value = [
"<s>[INST]Hello[/INST]Hi</s>"
]
result = processor.apply_chat_template(
single_conversation, tokenize=False, return_dict=False
)
assert result == "<s>[INST]Hello[/INST]Hi</s>"
def test_batched_conversations_returns_list(self, processor, mock_tokenizer):
batched = [
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
],
[
{"role": "user", "content": "Bye"},
{"role": "assistant", "content": "Bye"},
],
]
mock_tokenizer.apply_chat_template.return_value = ["text1", "text2"]
result = processor.apply_chat_template(
batched, tokenize=False, return_dict=False
)
assert result == ["text1", "text2"]
class TestCall:
def test_delegates_to_tokenizer(self, processor, mock_tokenizer):
mock_tokenizer.return_value = {
"input_ids": [1, 2, 3],
"attention_mask": [1, 1, 1],
}
result = processor("Hello world")
mock_tokenizer.assert_called_once()
assert isinstance(result, BatchFeature)
class TestReturnTensorsValidation:
def test_rejects_non_pt_return_tensors(self, processor):
conversation = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
with pytest.raises(ValueError, match=r"only supports.*return_tensors='pt'"):
processor.apply_chat_template(
conversation, tokenize=True, return_dict=True, return_tensors="np"
)
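A small sketch of the post-processing these assertions imply — casting pixel_values to float32 and deriving per-image (height, width) sizes before wrapping everything in a BatchFeature (wrap_outputs is illustrative; the real Mistral3Processor may organise this differently):
import torch
from transformers.feature_extraction_utils import BatchFeature
def wrap_outputs(outputs: dict) -> BatchFeature:
    data = dict(outputs)
    if "pixel_values" in data:
        pixels = data["pixel_values"].to(torch.float32)  # (batch, channels, H, W)
        data["pixel_values"] = pixels
        data["image_sizes"] = torch.tensor(
            [[img.shape[-2], img.shape[-1]] for img in pixels]  # (H, W) per image
        )
    return BatchFeature(data=data)
out = wrap_outputs({"pixel_values": torch.randn(2, 3, 224, 224, dtype=torch.float64)})
assert out["pixel_values"].dtype == torch.float32
assert out["image_sizes"].tolist() == [[224, 224], [224, 224]]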