Compare commits
60 Commits
lora_bf16
...
fix/hpc-ro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83ff8bfa1a | ||
|
|
c37decb073 | ||
|
|
01a346d86a | ||
|
|
26f05b6008 | ||
|
|
ed58fa8a75 | ||
|
|
633afffacb | ||
|
|
4b1b4fa6d8 | ||
|
|
0f7c886b7b | ||
|
|
a4b921135b | ||
|
|
98333e639a | ||
|
|
9d4d39e939 | ||
|
|
bb33fda44d | ||
|
|
4dc018992d | ||
|
|
243620394a | ||
|
|
3750fdcf79 | ||
|
|
613bcf90e5 | ||
|
|
383f220cfd | ||
|
|
8bb871b5cf | ||
|
|
87565ecc05 | ||
|
|
93ba57396f | ||
|
|
aa1240acd8 | ||
|
|
4cdfdfebb5 | ||
|
|
6e2f5ccf9f | ||
|
|
8c7f63cf97 | ||
|
|
cd856b45b1 | ||
|
|
143dea4753 | ||
|
|
bc2ffb8204 | ||
|
|
153edcfe79 | ||
|
|
08b8fa62cc | ||
|
|
3a5c97e6e5 | ||
|
|
37f78c8592 | ||
|
|
ab63b92c38 | ||
|
|
6f8ce024d1 | ||
|
|
d0e9c3c1c5 | ||
|
|
4c3488cc9f | ||
|
|
130637a3fa | ||
|
|
377c510e95 | ||
|
|
409cfb8a87 | ||
|
|
ce74c20109 | ||
|
|
a6bfbe3400 | ||
|
|
f4376748f3 | ||
|
|
740d5a1d31 | ||
|
|
850c1a5f8d | ||
|
|
7fa8ac40cd | ||
|
|
f9748c4dc5 | ||
|
|
33975ce4bc | ||
|
|
e8b962d47f | ||
|
|
856ff12171 | ||
|
|
6bc959342b | ||
|
|
b3b92687c4 | ||
|
|
55d1be2ae6 | ||
|
|
08d831c3d5 | ||
|
|
7be8740c5c | ||
|
|
c51d6b06c3 | ||
|
|
09959fac70 | ||
|
|
4065bc14c6 | ||
|
|
e5c427f6de | ||
|
|
86d6ee7c05 | ||
|
|
d4cff1b7bb | ||
|
|
1ef6c196f7 |
49
.github/workflows/base.yml
vendored
49
.github/workflows/base.yml
vendored
@@ -25,20 +25,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
@@ -67,6 +53,20 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -122,13 +122,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
@@ -150,6 +143,20 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
15
.github/workflows/main.yml
vendored
15
.github/workflows/main.yml
vendored
@@ -15,11 +15,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -88,11 +83,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -162,11 +152,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
|
||||
14
.github/workflows/multi-gpu-e2e.yml
vendored
14
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,13 +26,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -47,6 +40,13 @@ jobs:
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
|
||||
20
.github/workflows/nightlies.yml
vendored
20
.github/workflows/nightlies.yml
vendored
@@ -12,16 +12,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -65,16 +65,16 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
2
.github/workflows/precommit-autoupdate.yml
vendored
2
.github/workflows/precommit-autoupdate.yml
vendored
@@ -2,7 +2,7 @@ name: Pre-commit auto-update
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 0' # Run weekly
|
||||
- cron: '0 0 1 * *' # Run monthly
|
||||
workflow_dispatch: # Manual kickoff
|
||||
|
||||
jobs:
|
||||
|
||||
10
.github/workflows/tests-nightly.yml
vendored
10
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -102,14 +102,14 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
44
.github/workflows/tests.yml
vendored
44
.github/workflows/tests.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -81,12 +81,12 @@ jobs:
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
@@ -130,7 +130,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -152,17 +152,17 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
@@ -231,16 +231,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
@@ -289,15 +283,15 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
# - cuda: 128
|
||||
# cuda_version: 12.8.1
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.7.1
|
||||
# num_gpus: 1
|
||||
# axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -305,6 +299,12 @@ jobs:
|
||||
num_gpus: 1
|
||||
gpu_type: "B200"
|
||||
axolotl_extras: fbgemm-gpu
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -190,3 +190,6 @@ out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
# scm auto-versioning
|
||||
src/axolotl/_version.py
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.12
|
||||
rev: v0.14.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --select, I]
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.1
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
@@ -73,7 +73,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.7.1
|
||||
|
||||
### Google Colab
|
||||
|
||||
|
||||
@@ -267,6 +267,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
|
||||
@@ -32,6 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN uv pip install torchvision
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_PROCESSES="8"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -65,8 +65,13 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
sp_env = os.environ.copy()
|
||||
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
|
||||
exit(exit_code)
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
|
||||
@@ -13,7 +13,7 @@ datasets:
|
||||
val_set_size: 0
|
||||
output_dir: temp_debug/axolotl_outputs/model
|
||||
dataset_prepared_path: temp_debug/axolotl_outputs/data
|
||||
dataset_processes: 1
|
||||
dataset_num_proc: 1
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
@@ -24,29 +24,35 @@ RUN apt-get update \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
RUN if [ "$CUDA" != "130" ] ; then \
|
||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
python3 -m pip cache purge; \
|
||||
fi
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
||||
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="next"
|
||||
@@ -19,12 +19,12 @@ RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="nightly"
|
||||
@@ -19,14 +19,14 @@ RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -30,7 +30,13 @@ RUN uv venv --no-project --relocatable axolotl-venv
|
||||
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
|
||||
@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
|
||||
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
|
||||
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
|
||||
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
|
||||
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
|
||||
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
|
||||
|
||||
```yaml
|
||||
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
||||
// The flags below simplify debugging by overriding the axolotl config
|
||||
// with the debugging tips above. Modify as needed.
|
||||
"--dataset_processes=1", // limits data preprocessing to one process
|
||||
"--dataset_num_proc=1", // limits data preprocessing to one process
|
||||
"--max_steps=1", // limits training to just one step
|
||||
"--batch_size=1", // minimizes batch size
|
||||
"--micro_batch_size=1", // minimizes batch size
|
||||
|
||||
12
docs/faq.qmd
12
docs/faq.qmd
@@ -63,6 +63,14 @@ description: Frequently asked questions
|
||||
|
||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
||||
|
||||
**Q: Can we mix text and text+image datasets for VLM training?**
|
||||
|
||||
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
|
||||
|
||||
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
|
||||
|
||||
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
@@ -140,3 +148,7 @@ description: Frequently asked questions
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "FDSP + QLoRA"
|
||||
title: "FSDP + QLoRA"
|
||||
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
|
||||
format:
|
||||
html:
|
||||
@@ -23,6 +23,12 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Enabling Swap for FSDP2
|
||||
|
||||
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
|
||||
|
||||
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
|
||||
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
|
||||
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
|
||||
to improve speed and reduce memory usage during the forward and backward passes of
|
||||
these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
@@ -131,6 +132,5 @@ computation path.
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Additional operator fusions
|
||||
|
||||
@@ -27,3 +27,9 @@ learning_rate: 2e-5
|
||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
|
||||
self attention `q_proj` module.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
|
||||
|
||||
:::
|
||||
|
||||
@@ -88,6 +88,7 @@ fsdp_sync_module_states | **REMOVED**
|
||||
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
|
||||
fsdp_state_dict_type | state_dict_type
|
||||
fsdp_use_orig_params | **REMOVED**
|
||||
fsdp_activation_checkpointing | activation_checkpointing
|
||||
|
||||
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
|
||||
if you were using the following FSDP1 config:
|
||||
|
||||
@@ -13,6 +13,7 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
@@ -41,7 +42,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
# (optional) if doing lora, only finetune the Language model,
|
||||
# leave the vision model and vision tower frozen
|
||||
@@ -56,10 +56,14 @@ image_resize_algorithm: bilinear
|
||||
|
||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||
|
||||
::: {.callout-warning}
|
||||
::: {.callout-tip}
|
||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
|
||||
:::
|
||||
|
||||
### Mllama {#sec-mllama}
|
||||
|
||||
```yaml
|
||||
@@ -94,10 +98,22 @@ chat_template: llava
|
||||
|
||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
@@ -156,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3-VL {#sec-qwen3-vl}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
133
docs/optimizations.qmd
Normal file
133
docs/optimizations.qmd
Normal file
@@ -0,0 +1,133 @@
|
||||
---
|
||||
title: Optimizations Guide
|
||||
description: A guide to the performance and memory optimizations available in Axolotl.
|
||||
---
|
||||
|
||||
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
|
||||
|
||||
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
|
||||
|
||||
## Speed Optimizations
|
||||
|
||||
These optimizations focus on increasing training throughput and reducing total training time.
|
||||
|
||||
### Sample Packing
|
||||
|
||||
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
|
||||
|
||||
- **Config:** `sample_packing: true`
|
||||
- **Learn more:** [Sample Packing](multipack.qmd)
|
||||
|
||||
### Attention Implementations
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
|
||||
|
||||
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
|
||||
|
||||
## Memory Optimizations
|
||||
|
||||
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
|
||||
|
||||
### Parameter Efficient Finetuning (LoRA & QLoRA)
|
||||
|
||||
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
|
||||
|
||||
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
|
||||
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
|
||||
|
||||
### Gradient Checkpointing & Activation Offloading
|
||||
|
||||
These techniques save VRAM by changing how activations are handled.
|
||||
|
||||
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
|
||||
|
||||
### Liger Kernels
|
||||
|
||||
Provides efficient Triton kernels to improve training speed and reduce memory usage.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
|
||||
### RoPE Scaling
|
||||
|
||||
Extends a model's context window by interpolating its Rotary Position Embeddings.
|
||||
|
||||
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
|
||||
|
||||
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
|
||||
|
||||
### Artic Long Sequence Training (ALST)
|
||||
|
||||
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
|
||||
|
||||
- TiledMLP to reduce memory usage in MLP layers.
|
||||
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
|
||||
- Activation Offloading to CPU.
|
||||
|
||||
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
|
||||
|
||||
## Large Models (Distributed Training)
|
||||
|
||||
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
|
||||
|
||||
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
|
||||
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
|
||||
|
||||
### N-D Parallelism (Beta)
|
||||
|
||||
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
|
||||
|
||||
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Techniques to reduce the precision of model weights for memory savings.
|
||||
|
||||
### 4-bit Training (QLoRA)
|
||||
|
||||
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
|
||||
|
||||
### FP8 Training
|
||||
|
||||
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
|
||||
|
||||
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
|
||||
|
||||
### Quantization Aware Training (QAT)
|
||||
|
||||
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
|
||||
|
||||
- **Learn more:** [QAT Documentation](qat.qmd)
|
||||
|
||||
### GPTQ
|
||||
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
12
docs/qat.qmd
12
docs/qat.qmd
@@ -23,10 +23,18 @@ To enable QAT in axolotl, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
qat:
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
```
|
||||
|
||||
We support the following quantization schemas:
|
||||
|
||||
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
|
||||
- `Int8DynamicActivationInt4Weight`
|
||||
- `Float8DynamicActivationFloat8Weight`
|
||||
- `Float8DynamicActivationInt4Weight`
|
||||
- `NVFP4`
|
||||
|
||||
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
|
||||
|
||||
@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
|
||||
```yaml
|
||||
base_model: # The path to the model to quantize.
|
||||
quantization:
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
|
||||
|
||||
@@ -39,9 +39,8 @@ you used to train the model:
|
||||
# qat.yml
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
weight_dtype: int8
|
||||
weight_dtype: int4
|
||||
group_size: 256
|
||||
quantize_embedding: true
|
||||
|
||||
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
|
||||
```
|
||||
|
||||
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chat_template.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### chat_template.default
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -6,6 +6,8 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
|
||||
|
||||
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
@@ -31,6 +33,14 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
|
||||
```
|
||||
|
||||
**LFM2-MoE**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
|
||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
@@ -45,14 +55,13 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
|
||||
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
|
||||
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
base_model: LiquidAI/LFM2-350M
|
||||
|
||||
chunked_cross_entropy: true
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
|
||||
59
examples/LiquidAI/lfm2-8b-a1b-lora.yaml
Normal file
59
examples/LiquidAI/lfm2-8b-a1b-lora.yaml
Normal file
@@ -0,0 +1,59 @@
|
||||
base_model: LiquidAI/LFM2-8B-A1B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -3,6 +3,9 @@ trust_remote_code: true
|
||||
model_type: AutoModelForImageTextToText
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
|
||||
@@ -7,3 +7,24 @@ techniques. It is a combination of:
|
||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
||||
|
||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
tiled_mlp: true
|
||||
|
||||
# See Sequence Parallelism docs
|
||||
# https://docs.axolotl.ai/docs/sequence_parallelism.html
|
||||
context_parallel_size: int
|
||||
|
||||
plugins:
|
||||
# See Cut Cross Entropy docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# or Liger Kernel docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
# ...
|
||||
|
||||
```
|
||||
|
||||
110
examples/apertus/README.md
Normal file
110
examples/apertus/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
# Finetune Swiss-AI's Apertus with Axolotl
|
||||
|
||||
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. (Optional, highly recommended) Install XIELU CUDA
|
||||
|
||||
```bash
|
||||
## Recommended for reduced VRAM and faster speeds
|
||||
|
||||
# Point to CUDA toolkit directory
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/apertus/apertus-8b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.7 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
|
||||
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### XIELU Installation Issues
|
||||
|
||||
#### `ModuleNotFoundError: No module named 'torch'`
|
||||
|
||||
Please check these one by one:
|
||||
- Running in correct environment
|
||||
- Env has PyTorch installed
|
||||
- CUDA toolkit is at `CUDA_HOME`
|
||||
|
||||
If those didn't help, please try the below solutions:
|
||||
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/nickjbrowning/XIELU
|
||||
cd xielu
|
||||
git checkout 59d6031
|
||||
|
||||
cd xielu
|
||||
nano CMakeLists.txt # or vi depending on your preference
|
||||
```
|
||||
|
||||
```diff
|
||||
execute_process(
|
||||
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
|
||||
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
|
||||
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: swiss-ai/Apertus-8B-Instruct-2509
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -19,6 +19,9 @@ cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -251,10 +251,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from axolotl.utils import patch_optimized_env\n",
|
||||
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||
"\n",
|
||||
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"patch_optimized_env()"
|
||||
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"set_pytorch_cuda_alloc_conf()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
|
||||
pip3 install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + vision + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
|
||||
|
||||
In October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
@@ -64,6 +66,16 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
|
||||
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
|
||||
```
|
||||
|
||||
### How to set reasoning_effort in template?
|
||||
|
||||
The harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. If you would like to adjust this, you can add the following to your config:
|
||||
|
||||
```yaml
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: "high" # low | medium | high
|
||||
```
|
||||
|
||||
Currently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.
|
||||
|
||||
### Inferencing your fine-tuned model
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
base_model: openai/gpt-oss-safeguard-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-safeguard-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||
lora_target_linear: true
|
||||
|
||||
# TODO: not supported for now, see peft#2710
|
||||
#lora_target_parameters: # target the experts in the last two layers
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
@@ -66,6 +66,7 @@ fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -29,7 +29,7 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -12,15 +12,6 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
50
examples/llama-3/opentelemetry-qlora.yml
Normal file
@@ -0,0 +1,50 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/opentelemetry-example
|
||||
|
||||
adapter: qlora
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
# OpenTelemetry Configuration
|
||||
use_otel_metrics: true
|
||||
otel_metrics_host: "localhost"
|
||||
otel_metrics_port: 8000
|
||||
|
||||
# Disable WandB
|
||||
use_wandb: false
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: false
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -46,7 +46,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -45,7 +45,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Thinking
|
||||
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
Example format:
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
### Vision
|
||||
|
||||
Example config: `./magistral-small-think-qlora.yaml`.
|
||||
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
|
||||
|
||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
Limitations:
|
||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
||||
|
||||
### TIPS
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
@@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
|
||||
73
examples/magistral/think/README.md
Normal file
73
examples/magistral/think/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Magistral Small Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 19.1 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
60
examples/magistral/vision/README.md
Normal file
60
examples/magistral/vision/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Magistral Small Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 17GiB VRAM.
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- `max_tokens: 131072` for inference
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -0,0 +1,64 @@
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
51
examples/mistral/mistral-small/README.md
Normal file
51
examples/mistral/mistral-small/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Mistral Small 3.1/3.2 Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
|
||||
```
|
||||
|
||||
This config uses about 29.4 GiB VRAM.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
@@ -36,7 +39,7 @@ wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
@@ -48,8 +51,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
@@ -12,15 +12,6 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -45,8 +45,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
64
examples/qwen3-next/README.md
Normal file
64
examples/qwen3-next/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Finetune Qwen3-Next with Axolotl
|
||||
|
||||
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
- linear_attn.out_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3'
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -32,7 +32,7 @@ line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "C90", "B"]
|
||||
select = ["E", "F", "W", "C90", "B", "I"]
|
||||
ignore = [
|
||||
"E203", # Whitespace before ':'
|
||||
"E501", # Line too long
|
||||
|
||||
@@ -5,31 +5,30 @@ bitsandbytes==0.47.0
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.6.1
|
||||
liger-kernel==0.6.3
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft>=0.17.0
|
||||
transformers==4.56.1
|
||||
huggingface_hub>=0.36.0
|
||||
peft>=0.17.1
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
transformers==4.57.1
|
||||
accelerate==1.10.1
|
||||
datasets==4.3.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.21.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trl==0.24.0
|
||||
hf_xet==1.2.0
|
||||
kernels>=0.9.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.41.1
|
||||
gradio==5.49.1
|
||||
|
||||
modal==1.0.2
|
||||
pydantic==2.10.6
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
@@ -37,8 +36,8 @@ requests
|
||||
wandb
|
||||
einops
|
||||
colorama
|
||||
numba
|
||||
numpy>=1.24.4,<=2.0.1
|
||||
numba>=0.61.2
|
||||
numpy>=2.2.6
|
||||
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
@@ -51,7 +50,7 @@ python-dotenv==1.0.1
|
||||
|
||||
# remote filesystems
|
||||
s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
gcsfs>=2025.3.0
|
||||
adlfs>=2024.5.0
|
||||
ocifs==1.3.2
|
||||
|
||||
@@ -67,7 +66,7 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.13.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
mistral-common==1.8.5
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
|
||||
)
|
||||
|
||||
27
setup.py
27
setup.py
@@ -26,7 +26,6 @@ def parse_requirements(extras_require_map):
|
||||
_install_requires.append(line)
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
@@ -34,7 +33,6 @@ def parse_requirements(extras_require_map):
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
@@ -51,7 +49,7 @@ def parse_requirements(extras_require_map):
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.6.0" # default to torch 2.6
|
||||
torch_version = "2.8.0" # default to torch 2.8.0
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
@@ -64,8 +62,15 @@ def parse_requirements(extras_require_map):
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 8):
|
||||
pass
|
||||
if (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||
elif (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -74,7 +79,7 @@ def parse_requirements(extras_require_map):
|
||||
extras_require_map.pop("vllm")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
extras_require_map["vllm"] = ["vllm>=0.10.0"]
|
||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
@@ -87,7 +92,6 @@ def parse_requirements(extras_require_map):
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 4):
|
||||
extras_require_map.pop("vllm")
|
||||
@@ -124,7 +128,6 @@ extras_require = {
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.5",
|
||||
@@ -162,7 +165,13 @@ extras_require = {
|
||||
"llmcompressor": [
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
||||
"opentelemetry": [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
],
|
||||
}
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
|
||||
@@ -4,5 +4,7 @@ import os
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -23,7 +23,8 @@ from axolotl.utils.config import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -227,8 +228,11 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
|
||||
# have to wait for cfg.output to be resolved. We could call this earlier if we write
|
||||
# to a temporary file, and then move it later.
|
||||
prepare_debug_log(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
normalize_cfg_datasets(cfg)
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -241,7 +245,6 @@ def load_cfg(
|
||||
for k, v in cfg.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
"config:\n%s",
|
||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||
|
||||
@@ -85,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
unpatch_llama4 = patch_llama4_linearized_modeling()
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model, torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
processor.save_pretrained(output)
|
||||
|
||||
|
||||
@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.cli.utils.diffusion import (
|
||||
diffusion_inference,
|
||||
launch_diffusion_gradio_ui,
|
||||
render_html,
|
||||
run_diffusion,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
@@ -44,7 +44,7 @@ def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -69,7 +69,7 @@ def do_quantize(
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", torch_dtype=torch_dtype
|
||||
model_path, device_map="auto", dtype=torch_dtype
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
|
||||
@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
@@ -92,13 +92,14 @@ def ray_train_func(kwargs: dict):
|
||||
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
|
||||
# also renormalize the config now that TorchTrainer has spawned distributed workers
|
||||
cfg = DictDefault(kwargs["cfg"])
|
||||
prepare_optim_env(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
|
||||
resolve_dtype(cfg)
|
||||
|
||||
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
|
||||
if cfg.deepspeed:
|
||||
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
|
||||
cfg.deepspeed = cfg.deepspeed.to_dict()
|
||||
|
||||
# initialize accelerator before model instantiation
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from colorama import Fore, Style
|
||||
|
||||
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
||||
|
||||
@@ -12,6 +12,9 @@ MOE_ARCH_BLOCK = {
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
}
|
||||
|
||||
@@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils import (
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_opentelemetry_available,
|
||||
)
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
@@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OpenTelemetryMetricsCallback,
|
||||
)
|
||||
|
||||
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
|
||||
if self.cfg.save_first_step:
|
||||
callbacks.append(SaveModelOnFirstStepCallback())
|
||||
|
||||
@@ -435,7 +445,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing:
|
||||
elif self.cfg.gradient_checkpointing is not None:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
@@ -491,6 +501,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"dion_momentum",
|
||||
"dion_rank_fraction",
|
||||
"dion_rank_multiple_of",
|
||||
"dataset_num_proc",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
@@ -514,9 +525,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
Trainer,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
from trl.trainer.reward_trainer import DataCollatorForPreference
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
@@ -28,7 +28,6 @@ from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
@@ -63,12 +62,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.relora:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
# TODO: check if can move to base class
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
@@ -460,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorWithFlattening,
|
||||
RewardDataCollatorWithPadding,
|
||||
DataCollatorForPreference,
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
@@ -477,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if kwargs and isinstance(kwargs, dict):
|
||||
kwargs.update(collator_cls_and_kwargs[1])
|
||||
elif self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
collator = DataCollatorForPreference
|
||||
tokenizer = collator_args.pop(0)
|
||||
kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
kwargs.pop("padding")
|
||||
elif use_batch_sampler_collator:
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
# supported multipack models, or non-flash-attention llama
|
||||
|
||||
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Mapping
|
||||
|
||||
def chat_message_transform_builder(
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
conversations_field: str = "messages",
|
||||
message_field_role: str | list[str] | None = None, # commonly "role"
|
||||
message_field_content: str | list[str] | None = None, # commonly "content"
|
||||
message_field_training: str | list[str] | None = None, # commonly "weight"
|
||||
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
The field name of the conversations. Defaults to "messages".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
The field name of the role.
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
The field name of the message content.
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
The field name of the train/weight.
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
|
||||
@@ -225,17 +225,6 @@ class AxolotlTrainer(
|
||||
|
||||
data_collator = self.data_collator if is_training else self.eval_data_collator
|
||||
|
||||
if dataset.column_names and "length" in dataset.column_names:
|
||||
dataset = dataset.remove_columns(["length"])
|
||||
if (
|
||||
dataset.column_names
|
||||
and "position_ids" in dataset.column_names
|
||||
and "attention_mask" in dataset.column_names
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
dataset = dataset.remove_columns(["attention_mask"])
|
||||
|
||||
if isinstance(dataset, datasets.Dataset):
|
||||
if is_training:
|
||||
if not self.args.sample_packing or self.args.pretraining:
|
||||
@@ -294,6 +283,18 @@ class AxolotlTrainer(
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
|
||||
if dataset.column_names and "length" in dataset.column_names:
|
||||
dataset = dataset.remove_columns(["length"])
|
||||
|
||||
if (
|
||||
dataset.column_names
|
||||
and "position_ids" in dataset.column_names
|
||||
and "attention_mask" in dataset.column_names
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
dataset = dataset.remove_columns(["attention_mask"])
|
||||
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
# Accelerator.free_memory() will destroy the references, so
|
||||
@@ -560,13 +561,6 @@ class AxolotlTrainer(
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
"limit_all_gathers" in self.args.fsdp_config
|
||||
and self.args.fsdp_config["limit_all_gathers"]
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
def additional_accelerator_args(
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -27,7 +27,6 @@ class DPOStrategy:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
|
||||
@@ -52,6 +52,7 @@ class GRPOStrategy:
|
||||
if trl.vllm_mode:
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user