Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
09725be990 add support for CP + torch SDPA 2025-09-25 12:03:43 -04:00
Dan Saunders
f9bd6936c1 Merge branch 'main' into cp-fix 2025-09-24 14:01:23 -04:00
Dan Saunders
b9a3bfee5a only patch in CP > 1 case 2025-09-24 13:36:14 -04:00
Dan Saunders
08124a7c92 nits 2025-09-24 13:25:46 -04:00
Dan Saunders
56e0a77e0d patch transformers to allow CP + FA2 2025-09-24 13:08:38 -04:00
123 changed files with 688 additions and 2182 deletions

View File

@@ -25,6 +25,20 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126" - cuda: "126"
cuda_version: 12.6.3 cuda_version: 12.6.3
cudnn_version: "" cudnn_version: ""
@@ -53,20 +67,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -122,6 +122,13 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126" - cuda: "126"
cuda_version: 12.6.3 cuda_version: 12.6.3
cudnn_version: "" cudnn_version: ""
@@ -143,20 +150,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -15,6 +15,11 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -83,6 +88,11 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -152,6 +162,11 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"

View File

@@ -26,6 +26,13 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
@@ -40,13 +47,6 @@ jobs:
axolotl_extras: fbgemm-gpu axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 120 timeout-minutes: 120
steps: steps:

View File

@@ -15,12 +15,12 @@ jobs:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
- cuda: 128 - cuda: 126
cuda_version: 12.8.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.7.1
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
@@ -68,12 +68,12 @@ jobs:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
- cuda: 128 - cuda: 126
cuda_version: 12.8.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.7.1
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -2,7 +2,7 @@ name: Pre-commit auto-update
on: on:
schedule: schedule:
- cron: '0 0 1 * *' # Run monthly - cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff workflow_dispatch: # Manual kickoff
jobs: jobs:

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2 max-parallel: 2
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0"] pytorch_version: ["2.6.0", "2.7.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -102,14 +102,14 @@ jobs:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.6.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"
- cuda: 128 - cuda: 126
cuda_version: 12.8.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.7.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"] pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -81,12 +81,12 @@ jobs:
- name: Install PyTorch - name: Install PyTorch
run: | run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 show torch pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e . pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -130,7 +130,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"] pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -152,17 +152,17 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 show torch pip3 show torch
python -m build --no-isolation --sdist python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz pip3 install --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -231,10 +231,16 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 128 - cuda: 126
cuda_version: 12.8.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
dockerfile: "Dockerfile-uv.jinja" dockerfile: "Dockerfile-uv.jinja"
@@ -283,15 +289,15 @@ jobs:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.7.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
@@ -299,12 +305,6 @@ jobs:
num_gpus: 1 num_gpus: 1
gpu_type: "B200" gpu_type: "B200"
axolotl_extras: fbgemm-gpu axolotl_extras: fbgemm-gpu
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
num_gpus: 1
axolotl_extras:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch - id: no-commit-to-branch
args: ['--branch', 'main'] args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3 rev: v0.12.12
hooks: hooks:
- id: ruff - id: ruff
args: [--fix] args: [--fix]
- id: ruff-format - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2 rev: v1.17.1
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:

View File

@@ -73,7 +73,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11 - Python 3.11
- PyTorch ≥2.7.1 - PyTorch ≥2.6.0
### Google Colab ### Google Colab

View File

@@ -267,7 +267,6 @@ website:
- docs/dataset_loading.qmd - docs/dataset_loading.qmd
- docs/qat.qmd - docs/qat.qmd
- docs/quantize.qmd - docs/quantize.qmd
- docs/optimizations.qmd
- section: "Core Concepts" - section: "Core Concepts"
contents: contents:

View File

@@ -32,7 +32,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi fi
RUN uv pip install packaging==23.2 setuptools==75.8.0 RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \

View File

@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }} FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}" ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}" ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}" ENV CUDA="{{ CUDA }}"
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}" ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8" ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \ sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi fi
RUN pip install packaging==23.2 setuptools==75.8.0 psutil RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \

View File

@@ -65,13 +65,8 @@ def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec import subprocess # nosec
sp_env = os.environ.copy() sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8" sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess. # Propagate errors from subprocess.
try: if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec exit(exit_code)
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")

View File

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0 val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_num_proc: 1 dataset_processes: 1
sequence_len: 4096 sequence_len: 4096
sample_packing: false sample_packing: false

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/workspace/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10" ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.1.2"
@@ -24,35 +24,29 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir -p /workspace/.conda \ && mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \ && bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \ && rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \ RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge python3 -m pip cache purge
RUN if [ "$CUDA" != "130" ] ; then \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
python3 -m pip cache purge; \
fi
RUN git lfs install --skip-repo && \ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi fi

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/workspace/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11" ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next" ARG PYTORCH_VERSION="next"
@@ -19,12 +19,12 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir -p /workspace/.conda \ && mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \ && bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \ && rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/workspace/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11" ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly" ARG PYTORCH_VERSION="nightly"
@@ -19,14 +19,14 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir -p /workspace/.conda \ && mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \ && bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \ && rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace

View File

@@ -30,13 +30,7 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}" ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \ && uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \ && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \ && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

View File

@@ -212,14 +212,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step). Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
::: :::
::: {.callout-warning}
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
```
"arguments": "{\"...\": \"...\"}"
```
:::
Example config for Llama4: Example config for Llama4:
```yaml ```yaml
chat_template: llama4 chat_template: llama4

View File

@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
### Pre-training without streaming ### Pre-training without streaming
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming. On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs. One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.

View File

@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`. 1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing: 1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`. - Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`. - Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config): 2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml ```yaml
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"-m", "axolotl.cli.train", "dev_chat_template.yml", "-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config // The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed. // with the debugging tips above. Modify as needed.
"--dataset_num_proc=1", // limits data preprocessing to one process "--dataset_processes=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step "--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size "--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size "--micro_batch_size=1", // minimizes batch size

View File

@@ -63,14 +63,6 @@ description: Frequently asked questions
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717. > A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
**Q: Can we mix text and text+image datasets for VLM training?**
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
### Chat templates ### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`** **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -148,7 +140,3 @@ description: Frequently asked questions
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")` **Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML. > A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
**Q: `Error parsing tool_calls arguments as JSON.`
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.

View File

@@ -1,5 +1,5 @@
--- ---
title: "FSDP + QLoRA" title: "FDSP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs. description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format: format:
html: html:
@@ -23,12 +23,6 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp). 2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`. 3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Enabling Swap for FSDP2
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
## Example Config ## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl. [examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.

View File

@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU (in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order to leverage operator fusion and tensor re-use in order to improve speed and reduce
to improve speed and reduce memory usage during the forward and backward passes of memory usage during the forward and backward passes of these calculations.
these calculations.
We currently support several common model architectures, including (but not limited to): We currently support several common model architectures, including (but not limited to):
@@ -132,5 +131,6 @@ computation path.
## Future Work ## Future Work
- Support for additional model architectures - Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias - Support for dropout and bias
- Additional operator fusions - Additional operator fusions

View File

@@ -27,9 +27,3 @@ learning_rate: 2e-5
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module. self attention `q_proj` module.
::: {.callout-note}
We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
:::

View File

@@ -88,7 +88,6 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED** fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl, For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config: if you were using the following FSDP1 config:

View File

@@ -56,14 +56,10 @@ image_resize_algorithm: bilinear
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs. Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-tip} ::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs. Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
::: :::
::: {.callout-note}
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
:::
### Mllama {#sec-mllama} ### Mllama {#sec-mllama}
```yaml ```yaml
@@ -172,14 +168,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### Qwen3-VL {#sec-qwen3-vl}
```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2} ### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip} ::: {.callout-tip}

View File

@@ -1,133 +0,0 @@
---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
## Speed Optimizations
These optimizations focus on increasing training throughput and reducing total training time.
### Sample Packing
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)
### Attention Implementations
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
*Note: You should only enable one attention backend.*
### LoRA Optimizations
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
## Memory Optimizations
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
### Parameter Efficient Finetuning (LoRA & QLoRA)
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
### Gradient Checkpointing & Activation Offloading
These techniques save VRAM by changing how activations are handled.
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
### Liger Kernels
Provides efficient Triton kernels to improve training speed and reduce memory usage.
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
## Long Context Models
Techniques to train models on sequences longer than their original context window.
### RoPE Scaling
Extends a model's context window by interpolating its Rotary Position Embeddings.
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
### Sequence Parallelism
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
### Artic Long Sequence Training (ALST)
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
- TiledMLP to reduce memory usage in MLP layers.
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
## Large Models (Distributed Training)
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
### N-D Parallelism (Beta)
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
## Quantization
Techniques to reduce the precision of model weights for memory savings.
### 4-bit Training (QLoRA)
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
### FP8 Training
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
### Quantization Aware Training (QAT)
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
- **Learn more:** [QAT Documentation](qat.qmd)
### GPTQ
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)

View File

@@ -30,7 +30,6 @@ qat:
``` ```
We support the following quantization schemas: We support the following quantization schemas:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl) - `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight` - `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight` - `Float8DynamicActivationFloat8Weight`

View File

@@ -219,21 +219,6 @@ DPO supports the following types with the following dataset format:
} }
``` ```
#### chat_template.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### chat_template.default #### chat_template.default
```yaml ```yaml

View File

@@ -6,8 +6,6 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl. This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
## Getting Started ## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). 1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
@@ -33,14 +31,6 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
``` ```
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```
### TIPS ### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it: - **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
@@ -55,13 +45,14 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
## Optimization Guides ## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html) - [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources ## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) - [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models) - [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai) - [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) - [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) - [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,7 +1,6 @@
base_model: LiquidAI/LFM2-350M base_model: LiquidAI/LFM2-350M
plugins: chunked_cross_entropy: true
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
eot_tokens: eot_tokens:
- "<|im_end|>" - "<|im_end|>"

View File

@@ -1,59 +0,0 @@
base_model: LiquidAI/LFM2-8B-A1B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -3,9 +3,6 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText model_type: AutoModelForImageTextToText
processor_type: AutoProcessor processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# these 3 lines are needed for now to handle vision chat templates w images # these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true skip_prepare_dataset: true
remove_unused_columns: false remove_unused_columns: false

View File

@@ -7,24 +7,3 @@ techniques. It is a combination of:
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage - Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996). For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
## Usage
```yaml
tiled_mlp: true
# See Sequence Parallelism docs
# https://docs.axolotl.ai/docs/sequence_parallelism.html
context_parallel_size: int
plugins:
# See Cut Cross Entropy docs
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# or Liger Kernel docs
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
- axolotl.integrations.liger.LigerPlugin
# ...
```

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
] ]
}, },
{ {

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-1b-it base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: Gemma3ForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-270m-it base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: Gemma3ForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-4b-it base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
load_in_4bit: true load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp # gemma3 doesn't seem to play nice with ddp

View File

@@ -2,8 +2,6 @@
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B. [GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
In October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started ## Getting started
@@ -66,16 +64,6 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
``` ```
### How to set reasoning_effort in template?
The harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. If you would like to adjust this, you can add the following to your config:
```yaml
chat_template_kwargs:
reasoning_effort: "high" # low | medium | high
```
Currently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.
### Inferencing your fine-tuned model ### Inferencing your fine-tuned model

View File

@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-safeguard-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-safeguard-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -66,7 +66,6 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config # save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs: flex_attn_compile_kwargs:
dynamic: false dynamic: false
mode: max-autotune-no-cudagraphs mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true torch_compile: true
wandb_project: wandb_project:

View File

@@ -1,50 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
output_dir: ./outputs/opentelemetry-example
adapter: qlora
sequence_len: 512
sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# OpenTelemetry Configuration
use_otel_metrics: true
otel_metrics_host: "localhost"
otel_metrics_port: 8000
# Disable WandB
use_wandb: false
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
Run the thinking model fine-tuning: Run the thinking model fine-tuning:
```bash ```bash
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml axolotl train magistral-small-think-qlora.yaml
``` ```
This config uses about 19.1 GiB VRAM. This config uses about 19.1 GiB VRAM.

View File

@@ -21,7 +21,7 @@ Before starting, ensure you have:
3. Run the fine-tuning: 3. Run the fine-tuning:
```bash ```bash
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml axolotl train magistral-small-vision-24B-qlora.yml
``` ```
This config uses about 17GiB VRAM. This config uses about 17GiB VRAM.

View File

@@ -1,51 +0,0 @@
# Mistral Small 3.1/3.2 Fine-tuning
This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
## Getting Started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
```
This config uses about 29.4 GiB VRAM.
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
],
}
```
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -39,7 +39,7 @@ wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 1
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -38,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
``` ```
This config uses about 45.62 GiB VRAM. This config uses about 41.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀 Let us know how it goes. Happy finetuning! 🚀

View File

@@ -27,14 +27,6 @@ lora_r: 16
lora_alpha: 8 lora_alpha: 8
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_modules: lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
- linear_attn.out_proj
- shared_expert.up_proj
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj - q_proj
- v_proj - v_proj
- k_proj - k_proj

View File

@@ -5,30 +5,31 @@ bitsandbytes==0.47.0
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
liger-kernel==0.6.3 autoawq==0.2.7.post3
liger-kernel==0.6.1
# END section # END section
packaging==23.2 packaging==23.2
huggingface_hub>=0.36.0 huggingface_hub>=0.33.0
peft>=0.17.1 peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1 tokenizers>=0.21.1
transformers==4.57.1
accelerate==1.10.1 accelerate==1.10.1
datasets==4.3.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.24.0 trl==0.23.0
hf_xet==1.2.0 hf_xet==1.1.5
kernels>=0.9.0 kernels==0.9.0
trackio trackio
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
sentencepiece sentencepiece
gradio==5.49.1 gradio==5.41.1
modal==1.0.2 modal==1.0.2
pydantic>=2.10.6 pydantic==2.10.6
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
@@ -36,8 +37,8 @@ requests
wandb wandb
einops einops
colorama colorama
numba>=0.61.2 numba
numpy>=2.2.6 numpy>=1.24.4,<=2.0.1
# qlora things # qlora things
evaluate==0.4.1 evaluate==0.4.1
@@ -50,7 +51,7 @@ python-dotenv==1.0.1
# remote filesystems # remote filesystems
s3fs>=2024.5.0 s3fs>=2024.5.0
gcsfs>=2025.3.0 gcsfs>=2024.5.0
adlfs>=2024.5.0 adlfs>=2024.5.0
ocifs==1.3.2 ocifs==1.3.2
@@ -66,7 +67,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.13.0 torchao==0.13.0
schedulefree==1.4.1 schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5 axolotl-contribs-mit==0.0.5
mistral-common==1.8.5 mistral-common==1.8.5

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
) )

View File

@@ -26,6 +26,7 @@ def parse_requirements(extras_require_map):
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# skip packages not compatible with OSX # skip packages not compatible with OSX
skip_packages = [ skip_packages = [
@@ -33,6 +34,7 @@ def parse_requirements(extras_require_map):
"triton", "triton",
"mamba-ssm", "mamba-ssm",
"xformers", "xformers",
"autoawq",
"liger-kernel", "liger-kernel",
] ]
_install_requires = [ _install_requires = [
@@ -49,7 +51,7 @@ def parse_requirements(extras_require_map):
try: try:
torch_version = version("torch") torch_version = version("torch")
except PackageNotFoundError: except PackageNotFoundError:
torch_version = "2.8.0" # default to torch 2.8.0 torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}") _install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -62,15 +64,8 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 9): if (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu") pass
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
elif (major, minor) >= (2, 7): elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
@@ -79,7 +74,7 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
else: else:
_install_requires.append("xformers==0.0.31") _install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"] extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3") _install_requires.append("xformers==0.0.29.post3")
@@ -92,6 +87,7 @@ def parse_requirements(extras_require_map):
_install_requires.append("xformers==0.0.28.post2") _install_requires.append("xformers==0.0.28.post2")
else: else:
_install_requires.append("xformers>=0.0.28.post3") _install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4): elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
@@ -165,13 +161,7 @@ extras_require = {
"llmcompressor": [ "llmcompressor": [
"llmcompressor==0.5.1", "llmcompressor==0.5.1",
], ],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"], "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require extras_require

View File

@@ -85,7 +85,9 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
unpatch_llama4 = patch_llama4_linearized_modeling() unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16) model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model) processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output) processor.save_pretrained(output)

View File

@@ -69,7 +69,7 @@ def do_quantize(
config = AutoConfig.from_pretrained(model_path) config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", dtype=torch_dtype model_path, device_map="auto", torch_dtype=torch_dtype
) )
LOG.info( LOG.info(

View File

@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
resolve_dtype(cfg) resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"): if cfg.deepspeed:
cfg.deepspeed = cfg.deepspeed.to_dict() cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation # initialize accelerator before model instantiation

View File

@@ -12,9 +12,6 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock", "mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock", "qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE", "deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer", "gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
} }

View File

@@ -29,11 +29,7 @@ from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import ( from axolotl.utils import is_comet_available, is_mlflow_available
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
)
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
GCCallback, GCCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
@@ -138,12 +134,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
)
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
if self.cfg.save_first_step: if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback()) callbacks.append(SaveModelOnFirstStepCallback())
@@ -501,7 +491,6 @@ class TrainerBuilderBase(abc.ABC):
"dion_momentum", "dion_momentum",
"dion_rank_fraction", "dion_rank_fraction",
"dion_rank_multiple_of", "dion_rank_multiple_of",
"dataset_num_proc",
]: ]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg) training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -525,6 +514,9 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer # max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl: if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -12,7 +12,7 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
Trainer, Trainer,
) )
from trl.trainer.reward_trainer import DataCollatorForPreference from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import ( from axolotl.core.trainers import (
@@ -28,6 +28,7 @@ from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
LossWatchDogCallback, LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback, colab_inference_post_train_callback,
@@ -62,6 +63,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.relora: if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg)) callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class # TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None: if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg)) callbacks.append(LossWatchDogCallback(self.cfg))
@@ -453,7 +460,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
DataCollatorWithFlattening, DataCollatorWithFlattening,
DataCollatorForPreference, RewardDataCollatorWithPadding,
] ]
] ]
collator_args = [self.tokenizer] collator_args = [self.tokenizer]
@@ -470,10 +477,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if kwargs and isinstance(kwargs, dict): if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1]) kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model: elif self.cfg.reward_model:
collator = DataCollatorForPreference collator = RewardDataCollatorWithPadding
tokenizer = collator_args.pop(0)
kwargs["pad_token_id"] = tokenizer.pad_token_id
kwargs.pop("padding")
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama # supported multipack models, or non-flash-attention llama

View File

@@ -225,6 +225,17 @@ class AxolotlTrainer(
data_collator = self.data_collator if is_training else self.eval_data_collator data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
if isinstance(dataset, datasets.Dataset): if isinstance(dataset, datasets.Dataset):
if is_training: if is_training:
if not self.args.sample_packing or self.args.pretraining: if not self.args.sample_packing or self.args.pretraining:
@@ -283,18 +294,6 @@ class AxolotlTrainer(
): ):
self.accelerator.even_batches = False self.accelerator.even_batches = False
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
dataloader = DataLoader(dataset, **dataloader_params) dataloader = DataLoader(dataset, **dataloader_params)
# Accelerator.free_memory() will destroy the references, so # Accelerator.free_memory() will destroy the references, so
@@ -561,6 +560,13 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
def additional_accelerator_args( def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]: ) -> dict[str, Any]:

View File

@@ -52,7 +52,6 @@ class GRPOStrategy:
if trl.vllm_mode: if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate": if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization vllm_cfg.gpu_memory_utilization
) )

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
``` ```
## Usage ## Usage
@@ -31,7 +31,6 @@ plugins:
## Supported Models ## Supported Models
- apertus
- arcee - arcee
- cohere - cohere
- cohere2 - cohere2
@@ -45,22 +44,14 @@ plugins:
- glm - glm
- glm4 - glm4
- glm4_moe - glm4_moe
- glm4v
- glm4v_moe
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- lfm2
- lfm2_moe
- lfm2_vl
- llama - llama
- llama4 - llama4
- llama4_text - llama4_text
- llava
- mistral - mistral
- mistral3 - mistral3
- mixtral - mixtral
@@ -74,8 +65,6 @@ plugins:
- qwen2_5_vl - qwen2_5_vl
- qwen3 - qwen3
- qwen3_moe - qwen3_moe
- qwen3_vl
- qwen3_vl_moe
- qwen3_next - qwen3_next
- smollm3 - smollm3
- seed_oss - seed_oss

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
) )

View File

@@ -7,7 +7,7 @@ import torch
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -360,7 +360,7 @@ def _diffusion_step(
# Forward pass # Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask) outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = shift_logits_to_input_positions(outputs.logits) logits = outputs.logits
# Only sample at currently masked positions # Only sample at currently masked positions
if current_mask.any(): if current_mask.any():

View File

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
input_ids=noisy_batch.long(), input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask, attention_mask=bidirectional_mask,
) )
logits = shift_logits_to_input_positions(outputs.logits) logits = outputs.logits
if masked_indices.sum() > 0: if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices) valid_indices = torch.where(masked_indices)

View File

@@ -157,10 +157,3 @@ def create_bidirectional_attention_mask(
# Add head dimension: [batch_size, 1, seq_len, seq_len] # Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1) return bidirectional_mask.unsqueeze(1)
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
"""Align next-token logits with their input token positions for diffusion."""
if logits.size(1) <= 1:
return logits
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

View File

@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self._loss_function( loss = self.loss_function(
self.lm_head.weight, self.lm_head.weight,
hidden_states, hidden_states,
target_token_ids, target_token_ids,

View File

@@ -29,8 +29,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss self.args.kd_alpha, # kd loss
self.args.kd_temperature, self.args.kd_temperature,
@@ -38,14 +37,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
compute_ce_loss=bool(self.args.kd_ce_alpha), compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk, normalize_topk=self.args.kd_normalize_topk,
) )
target = self.model
# Unwrap PEFT wrapper
if hasattr(target, "get_base_model"):
target = target.get_base_model()
# Set on the actual model instance
target._loss_function = loss_fn
def _set_signature_columns_if_needed(self): def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed() super()._set_signature_columns_if_needed()

View File

@@ -515,6 +515,9 @@ class ModelLoader:
if self.cfg.model_quantization_config_kwargs: if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs) self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq: if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"): if not hasattr(self.model_config, "quantization_config"):
@@ -549,7 +552,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config **self.model_config.quantization_config
) )
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit: elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
"llm_int8_threshold": 6.0, "llm_int8_threshold": 6.0,
@@ -575,7 +580,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit: elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
bnb_config = { bnb_config = {
"load_in_8bit": True, "load_in_8bit": True,
} }
@@ -589,6 +596,11 @@ class ModelLoader:
**bnb_config, **bnb_config,
) )
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def _set_attention_config(self): def _set_attention_config(self):
"""Sample packing uses custom FA2 patch""" """Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation: if self.cfg.attn_implementation:

View File

@@ -84,7 +84,9 @@ class PatchManager:
patch_evaluation_loop() patch_evaluation_loop()
patch_maybe_log_save_evaluate() patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1: if self.cfg.context_parallel_size > 1 and getattr(
self.cfg, "flash_attention", False
):
from axolotl.monkeypatch.transformers.trainer_context_parallel import ( from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs, patch_prepare_context_parallel_inputs,
) )

View File

@@ -4,7 +4,6 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
import copy import copy
import functools import functools
import os
import sys import sys
import torch import torch
@@ -278,11 +277,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
mesh = getattr(accelerator.state, "device_mesh", None) mesh = getattr(accelerator.state, "device_mesh", None)
# Disable memory pinning if requested
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
fsdp2_plugin.cpu_offload.pin_memory = False
fsdp2_kwargs = { fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward, "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload, "offload_policy": fsdp2_plugin.cpu_offload,
@@ -347,6 +341,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
) )
if fsdp2_plugin.cpu_ram_efficient_loading: if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict( fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
) )

View File

@@ -134,11 +134,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Qwen2Attention return Qwen2Attention
if model_type == "qwen3_vl":
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention
return Qwen3VLTextAttention
if model_type == "mllama": if model_type == "mllama":
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention

View File

@@ -45,8 +45,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"gpt_oss", "gpt_oss",
"arcee", "arcee",
"seed_oss", "seed_oss",
"lfm2",
"lfm2_moe",
] ]

View File

@@ -13,21 +13,10 @@ from typing import Callable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers import transformers
import transformers.modeling_flash_attention_utils import transformers.modeling_flash_attention_utils as flash_utils
from ring_flash_attn import ring_flash_attn_func from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@@ -118,7 +107,7 @@ def create_flash_attn_forward_varlen_llama3(
# Handle sliding window # Handle sliding window
use_sliding_windows = ( use_sliding_windows = (
_flash_supports_window _flash_windows_supported()
and sliding_window is not None and sliding_window is not None
and key_states.shape[1] > sliding_window and key_states.shape[1] > sliding_window
) )
@@ -194,3 +183,18 @@ def substitute_hf_flash_attn(
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
def _flash_windows_supported() -> bool:
"""Return whether current transformers build advertises sliding-window support."""
support = getattr(flash_utils, "_flash_supports_window", None)
if support is None:
support = getattr(flash_utils, "_flash_supports_window_size", None)
if support is None:
return True
if callable(support):
return True
return bool(support)

View File

@@ -13,18 +13,9 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers.modeling_flash_attention_utils as flash_utils
from torch.distributed import DeviceMesh from torch.distributed import DeviceMesh
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@@ -83,7 +74,7 @@ def create_ring_flash_attention_forward(
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = ( use_sliding_windows = (
_flash_supports_window _flash_windows_supported()
and sliding_window is not None and sliding_window is not None
and key_states.shape[1] > sliding_window and key_states.shape[1] > sliding_window
) )
@@ -225,3 +216,19 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
def _flash_windows_supported() -> bool:
"""Best-effort check for FlashAttention sliding-window support."""
support = getattr(flash_utils, "_flash_supports_window", None)
if support is None:
support = getattr(flash_utils, "_flash_supports_window_size", None)
if support is None:
return True
if callable(support):
# Signature differs across versions; assume support when callable.
return True
return bool(support)

View File

@@ -13,7 +13,9 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":' GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):' PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None: def patch_prepare_context_parallel_inputs() -> None:

View File

@@ -6,10 +6,8 @@ from typing import Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Resampling from PIL.Image import Resampling
from torch import Tensor, zeros_like from torch import Tensor, zeros_like
from transformers import ProcessorMixin from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image from transformers.image_utils import load_image
from transformers.models.smolvlm import SmolVLMProcessor
from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger

View File

@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
] ]
return { return {
"chosen_input_ids": chosen_tokenized["input_ids"], "input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"], "attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0, "labels_chosen": 1.0,
"rejected_input_ids": rejected_tokenized["input_ids"], "input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"], "attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0, "labels_rejected": 0.0,
} }

View File

@@ -2,7 +2,6 @@
HF Chat Templates prompt strategy HF Chat Templates prompt strategy
""" """
import json
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
@@ -795,22 +794,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if val is not None: if val is not None:
transformed_message[key] = val transformed_message[key] = val
if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
for tool_call in transformed_message["tool_calls"]:
if "function" in tool_call and "arguments" in tool_call["function"]:
args = tool_call["function"]["arguments"]
if isinstance(args, str):
try:
tool_call["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool_calls arguments as JSON. "
f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
f"Arguments string: {args!r}, "
f"Error: {e}"
)
raise
return transformed_message return transformed_message
def _get_images(self, prompt): def _get_images(self, prompt):

View File

@@ -120,123 +120,3 @@ def default(cfg, dataset_idx=0, **kwargs):
return result return result
return transform_fn, {"remove_columns": [field_messages]} return transform_fn, {"remove_columns": [field_messages]}
def argilla_chat(cfg, dataset_idx=0, **kwargs):
"""
DPO chat template strategy for argilla-style datasets.
For argilla-style datasets where chosen/rejected contain full conversations
instead of single response messages. Extracts the conversation history from
the chosen field and formats both chosen/rejected responses using the
configured chat template.
Args:
cfg: Configuration object containing chat_template and dataset settings
dataset_idx: Index of the dataset in the config (default: 0)
**kwargs: Additional keyword arguments (unused)
Returns:
tuple: (transform_fn, dataset_kwargs) where:
- transform_fn: Function to transform dataset samples
- dataset_kwargs: Dict with 'remove_columns' specifying columns to drop
Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
"""
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
message_property_mappings = ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
)
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
chosen_raw = sample[field_chosen]
rejected_raw = sample[field_rejected]
# Extract messages (all but last) and responses (last message)
chosen_messages = [
{
"role": role_map[m[message_property_mappings["role"]]],
"content": m[message_property_mappings["content"]],
}
for m in chosen_raw[:-1]
]
chosen_response = {
"role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
"content": chosen_raw[-1][message_property_mappings["content"]],
}
rejected_response = {
"role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
"content": rejected_raw[-1][message_property_mappings["content"]],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
chosen_messages,
add_generation_prompt=True,
chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen_response["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected_response["content"])
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result
return transform_fn, {"remove_columns": [field_chosen, field_rejected]}

View File

@@ -40,6 +40,11 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -136,6 +141,8 @@ def setup_signal_handler(
def terminate_handler(_, __, model_weakref): def terminate_handler(_, __, model_weakref):
if model_weakref() is not None: if model_weakref() is not None:
_model = model_weakref() _model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained( _model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
@@ -172,7 +179,11 @@ def execute_training(
) )
) )
if cfg.context_parallel_size > 1: use_flash_cp = cfg.context_parallel_size > 1 and bool(
getattr(cfg, "flash_attention", False)
)
if use_flash_cp:
models = [trainer.model] models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model: if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model) models.append(trainer.ref_model)
@@ -314,6 +325,9 @@ def save_trained_model(
except FileNotFoundError: except FileNotFoundError:
pass pass
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained( trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
@@ -525,17 +539,6 @@ def setup_model_and_trainer(
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer) plugin_manager.post_trainer_create(cfg, trainer)
if cfg.use_ray:
try:
import ray.train.huggingface.transformers
trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
except ImportError:
LOG.warning(
"The Ray integration with Hugging Face Transformers is not available. "
"To use Ray, install the 'ray[train]' package."
)
return ( return (
trainer, trainer,
model, model,

View File

@@ -17,13 +17,6 @@ def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None return importlib.util.find_spec("comet_ml") is not None
def is_opentelemetry_available():
return (
importlib.util.find_spec("opentelemetry") is not None
and importlib.util.find_spec("prometheus_client") is not None
)
def get_pytorch_version() -> tuple[int, int, int]: def get_pytorch_version() -> tuple[int, int, int]:
""" """
Get Pytorch version as a tuple of (major, minor, patch). Get Pytorch version as a tuple of (major, minor, patch).

View File

@@ -16,8 +16,8 @@ import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb import wandb
import yaml
from datasets import load_dataset from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
GenerationConfig, GenerationConfig,
@@ -28,6 +28,8 @@ from transformers import (
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import ( from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
IntervalStrategy,
SaveStrategy, SaveStrategy,
) )
from trl.models import unwrap_model_for_generation from trl.models import unwrap_model_for_generation
@@ -54,6 +56,40 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__) LOG = get_logger(__name__)
class SaveBetterTransformerModelCallback(TrainerCallback):
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control
class LossWatchDogCallback(TrainerCallback): class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high""" """Callback to track loss and stop training if loss is too high"""
@@ -760,37 +796,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
try:
with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
chat_tpl = cfg.get("chat_template_jinja")
if chat_tpl:
with NamedTemporaryFile(
mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
) as temp_ct_file:
if (
isinstance(chat_tpl, str)
and os.path.exists(chat_tpl)
and os.path.isfile(chat_tpl)
):
copyfile(chat_tpl, temp_ct_file.name)
else:
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()
artifact = wandb.Artifact(
f"chat-template-{wandb.run.id}", type="jinja-template"
)
artifact.add_file(temp_ct_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_ct_file.name)
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")
if args.deepspeed: if args.deepspeed:
try: try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.

View File

@@ -1,238 +0,0 @@
"""OpenTelemetry metrics callback for Axolotl training"""
import threading
from typing import Dict, Optional
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from opentelemetry import metrics
from opentelemetry.exporter.prometheus import PrometheusMetricReader
from opentelemetry.metrics import set_meter_provider
from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
from prometheus_client import start_http_server
OPENTELEMETRY_AVAILABLE = True
except ImportError:
LOG.warning("OpenTelemetry not available. pip install [opentelemetry]")
OPENTELEMETRY_AVAILABLE = False
class OpenTelemetryMetricsCallback(TrainerCallback):
"""
TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.
This callback automatically tracks key training metrics including:
- Training loss
- Evaluation loss
- Learning rate
- Epoch progress
- Global step count
- Gradient norm
Metrics are exposed via HTTP endpoint for Prometheus scraping.
"""
def __init__(self, cfg):
if not OPENTELEMETRY_AVAILABLE:
LOG.warning("OpenTelemetry not available, metrics will not be collected")
self.metrics_enabled = False
return
self.cfg = cfg
self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
self.metrics_enabled = True
self.server_started = False
self.metrics_lock = threading.Lock()
try:
# Create Prometheus metrics reader
prometheus_reader = PrometheusMetricReader()
# Create meter provider with Prometheus exporter
provider = SDKMeterProvider(metric_readers=[prometheus_reader])
set_meter_provider(provider)
# Get meter for creating metrics
self.meter = metrics.get_meter("axolotl.training")
# Create metrics
self._create_metrics()
except Exception as e:
LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
self.metrics_enabled = False
def _create_metrics(self):
"""Create all metrics that will be tracked"""
self.train_loss_gauge = self.meter.create_gauge(
name="axolotl_train_loss",
description="Current training loss",
unit="1",
)
self.eval_loss_gauge = self.meter.create_gauge(
name="axolotl_eval_loss",
description="Current evaluation loss",
unit="1",
)
self.learning_rate_gauge = self.meter.create_gauge(
name="axolotl_learning_rate",
description="Current learning rate",
unit="1",
)
self.epoch_gauge = self.meter.create_gauge(
name="axolotl_epoch",
description="Current training epoch",
unit="1",
)
self.global_step_counter = self.meter.create_counter(
name="axolotl_global_steps",
description="Total training steps completed",
unit="1",
)
self.grad_norm_gauge = self.meter.create_gauge(
name="axolotl_gradient_norm",
description="Gradient norm",
unit="1",
)
self.memory_usage_gauge = self.meter.create_gauge(
name="axolotl_memory_usage",
description="Current memory usage in MB",
unit="MB",
)
def _start_metrics_server(self):
"""Start the HTTP server for metrics exposure"""
if self.server_started:
return
try:
start_http_server(self.metrics_port, addr=self.metrics_host)
self.server_started = True
LOG.info(
f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
)
except Exception as e:
LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not self.metrics_enabled:
return
self._start_metrics_server()
LOG.info("OpenTelemetry metrics collection started")
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called when logging occurs"""
if not self.metrics_enabled or not logs:
return
if "loss" in logs:
self.train_loss_gauge.set(logs["loss"])
if "eval_loss" in logs:
self.eval_loss_gauge.set(logs["eval_loss"])
if "learning_rate" in logs:
self.learning_rate_gauge.set(logs["learning_rate"])
if "epoch" in logs:
self.epoch_gauge.set(logs["epoch"])
if "grad_norm" in logs:
self.grad_norm_gauge.set(logs["grad_norm"])
if "memory_usage" in logs:
self.memory_usage_gauge.set(logs["memory_usage"])
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of each training step"""
if not self.metrics_enabled:
return
# Update step counter and epoch
self.global_step_counter.add(1)
if state.epoch is not None:
self.epoch_gauge.set(state.epoch)
def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called after evaluation"""
if not self.metrics_enabled or not metrics:
return
if "eval_loss" in metrics:
self.eval_loss_gauge.set(metrics["eval_loss"])
# Record any other eval metrics as gauges
for key, value in metrics.items():
if key.startswith("eval_") and isinstance(value, (int, float)):
# Create gauge for this metric if it doesn't exist
gauge_name = f"axolotl_{key}"
try:
gauge = self.meter.create_gauge(
name=gauge_name,
description=f"Evaluation metric: {key}",
unit="1",
)
gauge.set(value)
except Exception as e:
LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not self.metrics_enabled:
return
LOG.info("Training completed. OpenTelemetry metrics collection finished.")
LOG.info(
f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
)

View File

@@ -113,7 +113,7 @@ def _map_dataset(
dataset = dataset.map( dataset = dataset.map(
ds_transform_fn, ds_transform_fn,
num_proc=cfg.dataset_num_proc, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Mapping RL Dataset", desc="Mapping RL Dataset",
**map_kwargs, **map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
prior_len = len(split_datasets[i]) prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter( split_datasets[i] = split_datasets[i].filter(
drop_long, drop_long,
num_proc=cfg.dataset_num_proc, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
) )

View File

@@ -239,11 +239,6 @@ def _load_from_local_path(
return load_dataset(dataset_config.path, **load_dataset_kwargs) return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file(): elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config) dataset_type = get_dataset_type(dataset_config)
# For single file datasets, HF always creates only a "train" split
if dataset_type in ("json", "csv", "text"):
load_dataset_kwargs["split"] = "train"
return load_dataset( return load_dataset(
dataset_type, dataset_type,
data_files=dataset_config.path, data_files=dataset_config.path,
@@ -414,7 +409,7 @@ def save_preprocessed_dataset(
) -> None: ) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub.""" """Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_num_proc or get_default_process_count() num_workers = cfg.dataset_processes or get_default_process_count()
if isinstance(dataset, IterableDataset): if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator( ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset), functools.partial(_generate_from_iterable_dataset, dataset),

View File

@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(
filter_map_kwargs = {} filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset): if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {} drop_long_kwargs = {}

View File

@@ -80,7 +80,7 @@ def get_dataset_wrapper(
""" """
# Common parameters for dataset wrapping # Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = { dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_num_proc, "process_count": cfg.dataset_processes,
"keep_in_memory": cfg.dataset_keep_in_memory is True, "keep_in_memory": cfg.dataset_keep_in_memory is True,
} }

View File

@@ -4,8 +4,6 @@ import os
def get_default_process_count(): def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
return int(axolotl_dataset_processes) return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):

View File

@@ -3,46 +3,66 @@ utils to get GPU info for the current environment
""" """
import os import os
import subprocess # nosec B404
from importlib.metadata import version from importlib.metadata import version
import torch
from accelerate.utils.environment import ( from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
get_gpu_info,
) )
from packaging.version import Version, parse from packaging.version import Version, parse
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_cuda_p2p_ib_support(): def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support(): if not accelerate_check_cuda_p2p_ib_support():
return False return False
if not check_cuda_p2p_support(): if not check_runpod_p2p_support():
return False return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # nosec B110
pass
return True return True
def check_cuda_p2p_support() -> bool: def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
try: try:
world_size = int(os.environ.get("WORLD_SIZE", "1")) gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
except ValueError: except ValueError:
return True return True
if gpu_count >= 2:
if world_size > 1: # run `nvidia-smi topo -p2p n` and inspect the GPU0 row
node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8"))
local_other_rank = (local_rank // node_world_size) * node_world_size
local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0
try: try:
can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank) result = subprocess.run( # nosec B603 B607
except AssertionError as exc: ["nvidia-smi", "topo", "-p2p", "n"],
# some sort of logic error in indexing processes, assume p2p is fine for now check=True,
LOG.warning(exc) capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
return True return True
return can_p2p # consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return True return True

View File

@@ -148,7 +148,7 @@ def load_sharded_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_name, model_name,
use_cache=False, use_cache=False,
dtype=torch.float32, torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, _attn_implementation=model_config._attn_implementation,
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
@@ -158,7 +158,7 @@ def load_sharded_model(
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
model_config, model_config,
dtype=torch_dtype, torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
return model return model

View File

@@ -5,7 +5,6 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc import gc
import math import math
import os
import time import time
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context from multiprocessing import cpu_count, get_context
@@ -292,10 +291,7 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots = 0 self.total_token_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length # The number of times to calculate batches to determine minimum packed dataset length
world_size = int(os.environ.get("WORLD_SIZE", "1")) self.num_count_samples = num_count_samples
self.num_count_samples = (
1 if world_size >= num_count_samples else num_count_samples
)
if self.sequential and not isinstance(sampler, SequentialSampler): if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning( LOG.warning(

View File

@@ -24,13 +24,11 @@ from axolotl.utils.schemas.datasets import (
) )
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import ( from axolotl.utils.schemas.integrations import (
CometConfig, CometConfig,
GradioConfig, GradioConfig,
LISAConfig, LISAConfig,
MLFlowConfig, MLFlowConfig,
OpenTelemetryConfig,
RayConfig, RayConfig,
WandbConfig, WandbConfig,
) )
@@ -61,7 +59,6 @@ class AxolotlInputConfig(
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
CometConfig, CometConfig,
OpenTelemetryConfig,
LISAConfig, LISAConfig,
GradioConfig, GradioConfig,
RayConfig, RayConfig,
@@ -236,7 +233,6 @@ class AxolotlInputConfig(
) )
dataset_processes: int | None = Field( dataset_processes: int | None = Field(
default=None, default=None,
deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
json_schema_extra={ json_schema_extra={
"description": ( "description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -244,16 +240,6 @@ class AxolotlInputConfig(
) )
}, },
) )
dataset_num_proc: int | None = Field(
default=None,
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
},
)
dataset_exact_deduplication: bool | None = Field( dataset_exact_deduplication: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -681,7 +667,8 @@ class AxolotlInputConfig(
json_schema_extra={"description": "FSDP configuration"}, json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
) )
fsdp_config: FSDPConfig | None = Field( # TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"} default=None, json_schema_extra={"description": "FSDP configuration options"}
) )
fsdp_version: int | None = Field( fsdp_version: int | None = Field(
@@ -1327,22 +1314,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def default_dataset_num_proc(cls, data): def default_dataset_processes(cls, data):
if data.get("dataset_processes") is not None: if data.get("dataset_processes") is None:
if data.get("dataset_num_proc") is None: data["dataset_processes"] = get_default_process_count()
data["dataset_num_proc"] = data["dataset_processes"]
LOG.warning(
"dataset_processes is deprecated and will be removed in a future version. "
"Please use dataset_num_proc instead."
)
else:
LOG.warning(
"Both dataset_processes and dataset_num_proc are set. "
"Using dataset_num_proc and ignoring dataset_processes."
)
del data["dataset_processes"]
elif data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = get_default_process_count()
return data return data
@model_validator(mode="before") @model_validator(mode="before")

View File

@@ -1,71 +0,0 @@
"""
FSDP Configuration Schema
"""
from typing import Literal
from pydantic import BaseModel, Field
class FSDPConfig(BaseModel):
"""
FSDP Configuration Schema
"""
activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",
)
offload_params: bool | None = Field(
default=None,
description="Offload parameters to CPU to reduce GPU memory usage",
)
sync_module_states: bool | None = Field(
default=None,
description="Synchronize module states across all processes",
)
cpu_ram_efficient_loading: bool | None = Field(
default=None,
description="Enable CPU RAM efficient loading to reduce memory usage during model loading",
)
cpu_offload_pin_memory: bool | None = Field(
default=None,
description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.",
)
use_orig_params: bool | None = Field(
default=None,
description="Use original parameters instead of flattened parameters",
)
state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Type of state dict to use for saving/loading checkpoints",
)
final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Final state dict type to use after training completion",
)
auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
Field(
default=None,
description="Policy for automatically wrapping modules with FSDP",
)
)
transformer_layer_cls_to_wrap: str | None = Field(
default=None,
description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')",
)
reshard_after_forward: bool | None = Field(
default=None,
description="Reshard parameters after forward pass to save memory",
)
mixed_precision_policy: str | None = Field(
default=None,
description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')",
)

View File

@@ -176,27 +176,3 @@ class RayConfig(BaseModel):
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
}, },
) )
class OpenTelemetryConfig(BaseModel):
"""OpenTelemetry configuration subset"""
use_otel_metrics: bool | None = Field(
default=False,
json_schema_extra={
"description": "Enable OpenTelemetry metrics collection and Prometheus export"
},
)
otel_metrics_host: str | None = Field(
default="localhost",
json_schema_extra={
"title": "OpenTelemetry Metrics Host",
"description": "Host to bind the OpenTelemetry metrics server to",
},
)
otel_metrics_port: int | None = Field(
default=8000,
json_schema_extra={
"description": "Port for the Prometheus metrics HTTP server"
},
)

View File

@@ -167,9 +167,3 @@ class TRLConfig(BaseModel):
"description": "Whether to exclude truncated completions from loss calculation." "description": "Whether to exclude truncated completions from loss calculation."
}, },
) )
vllm_enable_sleep_mode: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)

View File

@@ -1,7 +1,6 @@
"""Module with validation methods for config pydantic model.""" """Module with validation methods for config pydantic model."""
import json import json
import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -783,6 +782,15 @@ class OptimizationValidationMixin:
return data return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_xentropy_patch_conflicts(cls, data): def check_xentropy_patch_conflicts(cls, data):
@@ -807,22 +815,21 @@ class OptimizationValidationMixin:
) )
return data return data
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_fsdp2_base_model_quant_ram_efficient_loading(self):
def check_fsdp2_cpu_offload_pin_memory(cls, data): fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
if not (fsdp_config := data.get("fsdp_config")): fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
return data load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
if fsdp_config.get("cpu_offload_pin_memory") is False: if fsdp_config and fsdp_version == 2:
if str(data.get("fsdp_version")) != "2": if fsdp_config.get("cpu_ram_efficient_loading") and (
load_in_8bit or load_in_4bit
):
raise ValueError( raise ValueError(
"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2" "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
) )
if not fsdp_config.get("offload_params"): return self
raise ValueError(
"disabling cpu_offload_pin_memory requires enabling offload_params"
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -881,7 +888,7 @@ class OptimizationValidationMixin:
and self.fsdp_config and self.fsdp_config
and self.optimizer and self.optimizer
and "8bit" in self.optimizer.value and "8bit" in self.optimizer.value
and self.fsdp_config.offload_params and self.fsdp_config["offload_params"]
and str(self.fsdp_version) != "2" and str(self.fsdp_version) != "2"
): ):
raise ValueError( raise ValueError(
@@ -1306,50 +1313,40 @@ class ComplexValidationMixin:
if not self.context_parallel_size: if not self.context_parallel_size:
self.context_parallel_size = 1 self.context_parallel_size = 1
elif self.context_parallel_size > 1: elif self.context_parallel_size > 1:
if not self.flash_attention: use_flash_attention = getattr(self, "flash_attention", False)
use_sdp_attention = getattr(self, "sdp_attention", False)
if not (use_flash_attention or use_sdp_attention):
raise ValueError( raise ValueError(
"flash_attention: true must be set with context_parallel_size > 1" "context_parallel_size > 1 requires either flash_attention: true "
"or sdp_attention: true"
) )
if self.sample_packing and self.micro_batch_size > 1: if use_flash_attention:
raise ValueError( if self.sample_packing and self.micro_batch_size > 1:
"micro_batch_size must be set to 1 when sample_packing is enabled " raise ValueError(
"due to a `ring-flash-attn` requirement" "micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"
)
try:
import ring_flash_attn # noqa: F401 # Required after monkey-patching
except ImportError as exception:
raise ImportError(
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
) )
try:
import transformers.modeling_flash_attention_utils
from transformers.utils import is_flash_attn_greater_or_equal
transformers.modeling_flash_attention_utils._flash_supports_window = (
True
)
sys.modules[
"transformers.modeling_flash_attention_utils"
]._flash_supports_window = True
sys.modules[
"transformers.modeling_flash_attention_utils"
]._flash_supports_window_size = True
sys.modules[
"transformers.modeling_flash_attention_utils"
].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal
import ring_flash_attn # noqa: F401 # Required after monkey-patching
except ImportError as exception:
raise ImportError(
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self return self
@model_validator(mode="after") @model_validator(mode="after")

View File

@@ -109,8 +109,8 @@ def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints") cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
) )
if not append: if not append and log_path.exists():
log_path.unlink(missing_ok=True) log_path.unlink()
fh = open(log_path, "a", encoding="utf-8") fh = open(log_path, "a", encoding="utf-8")
fh.flush() fh.flush()

View File

@@ -6,7 +6,6 @@ import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from tempfile import NamedTemporaryFile
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@@ -16,7 +15,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -278,7 +276,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
prior_len = None prior_len = None
filter_map_kwargs = {} filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset): if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {} drop_long_kwargs = {}
@@ -318,7 +316,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.group_by_length: if cfg.group_by_length:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_length, add_length,
num_proc=cfg.dataset_num_proc, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length", desc="Group By Length",
) )
@@ -335,7 +333,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
) )
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_num_proc, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
@@ -344,7 +342,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
pose_fn, pose_fn,
num_proc=cfg.dataset_num_proc, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
@@ -469,7 +467,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size, bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially, sequential=cfg.sample_packing_sequentially,
drop_last=True, drop_last=True,
num_processes=cfg.dataset_prcoesses, num_processes=cfg.dataset_processes,
mp_start_method=cfg.sample_packing_mp_start_method or "fork", mp_start_method=cfg.sample_packing_mp_start_method or "fork",
) )
@@ -542,13 +540,6 @@ def setup_deepspeed_env(cfg, stage=None):
) )
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if isinstance(cfg.deepspeed, DictDefault):
with NamedTemporaryFile(
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
) as temp_file:
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
temp_file.close()
cfg.deepspeed = str(temp_file.name)
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str( os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps cfg.gradient_accumulation_steps
@@ -571,7 +562,6 @@ def setup_deepspeed_env(cfg, stage=None):
if ( if (
int(os.environ.get("WORLD_SIZE", "1")) == 1 int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1" and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
and cfg.use_ray is not True
): ):
os.environ["WORLD_SIZE"] = "1" # force it in case not set os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set os.environ["LOCAL_RANK"] = "0" # force it in case not set
@@ -605,10 +595,6 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_USE_ORIG_PARAMS"] = "true" os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.state_dict_type: if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
cfg.fsdp_config.cpu_offload_pin_memory
).lower()
if cfg.fsdp_config.auto_wrap_policy: if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.transformer_layer_cls_to_wrap: if cfg.fsdp_config.transformer_layer_cls_to_wrap:
@@ -641,7 +627,6 @@ def setup_parallelism_envs(cfg):
def prepare_optim_env(cfg): def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support(): if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None: if os.getenv("NCCL_P2P_DISABLE") is None:
LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`")
os.environ["NCCL_P2P_DISABLE"] = "1" os.environ["NCCL_P2P_DISABLE"] = "1"
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12 # TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config: if cfg.fsdp or cfg.fsdp_config:
@@ -649,15 +634,11 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
stage = None stage = None
deepspeed_config = None
# check if the cfg.deepspeed is a file # check if the cfg.deepspeed is a file
if isinstance(cfg.deepspeed, DictDefault): if os.path.isfile(cfg.deepspeed):
deepspeed_config = cfg.deepspeed
elif os.path.isfile(cfg.deepspeed):
# parse with json # parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin: with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin) deepspeed_config = json.load(fin)
if deepspeed_config:
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage) setup_deepspeed_env(cfg, stage=stage)

View File

@@ -33,6 +33,7 @@ def parse_requirements():
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS # don't install xformers on MacOS
@@ -62,6 +63,7 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post2") _install_requires.append("xformers==0.0.28.post2")
else: else:
_install_requires.append("xformers==0.0.28.post3") _install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4): elif (major, minor) >= (2, 4):
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))

Some files were not shown because too many files have changed in this diff Show More