Compare commits
37 Commits
fix/hpc-ro...vendor-moe
| Author | SHA1 | Date |
|---|---|---|
|  | dd85358543 |  |
|  | 55d98db0d0 |  |
|  | d90ade3b1b |  |
|  | 824a641cee |  |
|  | e003a05177 |  |
|  | 91393c4dc8 |  |
|  | d578c53603 |  |
|  | 4db7a21ff7 |  |
|  | 3b2e05c563 |  |
|  | 1037ca3a97 |  |
|  | 6369dcd7b8 |  |
|  | a81612305c |  |
|  | d0da67eb17 |  |
|  | 8a1f5ae940 |  |
|  | 146ca48cba |  |
|  | fd312f6058 |  |
|  | ab8fa56b16 |  |
|  | 1640cd4006 |  |
|  | 3277d44d71 |  |
|  | d3e1b0ef1a |  |
|  | 5b97633faa |  |
|  | 94cbc6d42d |  |
|  | 493616fc3d |  |
|  | d2b25c7327 |  |
|  | b670c45276 |  |
|  | 61faf4cbe4 |  |
|  | 8d8fa834a2 |  |
|  | 9d69c6fb3e |  |
|  | 92f2f6e73c |  |
|  | e5d2aebe16 |  |
|  | 4ab9e3f58b |  |
|  | 5788832812 |  |
|  | db782430f8 |  |
|  | 5c74edeefe |  |
|  | 18269ee6a9 |  |
|  | 6a45d804f9 |  |
|  | 95e607574a |  |
.github/workflows/base.yml (49 changed lines, vendored)

@@ -25,6 +25,20 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: "124"
            cuda_version: 12.4.1
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.6.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-base"
          - cuda: "126"
            cuda_version: 12.6.3
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.6.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-base"
          - cuda: "126"
            cuda_version: 12.6.3
            cudnn_version: ""

@@ -53,20 +67,6 @@ jobs:
            pytorch: 2.8.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-base"
          - cuda: "128"
            cuda_version: 12.8.1
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.9.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-base"
          - cuda: "130"
            cuda_version: 13.0.0
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.9.0
            torch_cuda_arch_list: "9.0+PTX"
            dockerfile: "Dockerfile-base"
          # - cuda: "128"
          #   cuda_version: 12.8.1
          #   cudnn_version: ""

@@ -122,6 +122,13 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: "126"
            cuda_version: 12.6.3
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.6.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-uv-base"
          - cuda: "126"
            cuda_version: 12.6.3
            cudnn_version: ""

@@ -143,20 +150,6 @@ jobs:
            pytorch: 2.8.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-uv-base"
          - cuda: "128"
            cuda_version: 12.8.1
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.9.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-uv-base"
          - cuda: "130"
            cuda_version: 13.0.0
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.9.0
            torch_cuda_arch_list: "9.0+PTX"
            dockerfile: "Dockerfile-uv-base"
    steps:
      - name: Checkout
        uses: actions/checkout@v4
.github/workflows/main.yml (15 changed lines, vendored)

@@ -15,6 +15,11 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.6.0
            axolotl_extras:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"

@@ -83,6 +88,11 @@ jobs:
    strategy:
      matrix:
        include:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.6.0
            axolotl_extras:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"

@@ -152,6 +162,11 @@ jobs:
    strategy:
      matrix:
        include:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.6.0
            axolotl_extras:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
.github/workflows/multi-gpu-e2e.yml (14 changed lines, vendored)

@@ -26,6 +26,13 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.6.0
            axolotl_extras:
            num_gpus: 2
            nightly_build: "true"
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"

@@ -40,13 +47,6 @@ jobs:
            axolotl_extras: fbgemm-gpu
            num_gpus: 2
            nightly_build: "true"
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"
            pytorch: 2.9.0
            axolotl_extras: fbgemm-gpu
            num_gpus: 2
            nightly_build: "true"
    runs-on: [self-hosted, modal]
    timeout-minutes: 120
    steps:
.github/workflows/nightlies.yml (16 changed lines, vendored)

@@ -15,12 +15,12 @@ jobs:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.7.1
            pytorch: 2.6.0
            axolotl_extras:
          - cuda: 128
            cuda_version: 12.8.1
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.8.0
            pytorch: 2.7.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:

@@ -68,12 +68,12 @@ jobs:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.7.1
            pytorch: 2.6.0
            axolotl_extras:
          - cuda: 128
            cuda_version: 12.8.1
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.8.0
            pytorch: 2.7.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
.github/workflows/precommit-autoupdate.yml (2 changed lines, vendored)

@@ -2,7 +2,7 @@ name: Pre-commit auto-update

on:
  schedule:
    - cron: '0 0 1 * *' # Run monthly
    - cron: '0 0 * * 0' # Run weekly
  workflow_dispatch: # Manual kickoff

jobs:
.github/workflows/tests-nightly.yml (10 changed lines, vendored)

@@ -26,7 +26,7 @@ jobs:
      max-parallel: 2
      matrix:
        python_version: ["3.11"]
        pytorch_version: ["2.7.1", "2.8.0"]
        pytorch_version: ["2.6.0", "2.7.0"]
    timeout-minutes: 20

    steps:

@@ -102,14 +102,14 @@ jobs:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.7.1
            pytorch: 2.6.0
            num_gpus: 1
            axolotl_extras:
            nightly_build: "true"
          - cuda: 128
            cuda_version: 12.8.1
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.8.0
            pytorch: 2.7.1
            num_gpus: 1
            axolotl_extras:
            nightly_build: "true"
.github/workflows/tests.yml (44 changed lines, vendored)

@@ -55,7 +55,7 @@ jobs:
      fail-fast: false
      matrix:
        python_version: ["3.11"]
        pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
        pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
    timeout-minutes: 20

    steps:

@@ -81,12 +81,12 @@ jobs:

      - name: Install PyTorch
        run: |
          pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
          pip3 install torch==${{ matrix.pytorch_version }} torchvision

      - name: Install dependencies
        run: |
          pip3 show torch
          pip3 install --no-cache-dir --no-build-isolation -U -e .
          pip3 install --no-build-isolation -U -e .
          python scripts/unsloth_install.py | sh
          python scripts/cutcrossentropy_install.py | sh
          pip3 install -r requirements-dev.txt -r requirements-tests.txt

@@ -130,7 +130,7 @@ jobs:
      fail-fast: false
      matrix:
        python_version: ["3.11"]
        pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
        pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
    timeout-minutes: 20

    steps:

@@ -152,17 +152,17 @@ jobs:
      - name: upgrade pip
        run: |
          pip3 install --upgrade pip
          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel

      - name: Install PyTorch
        run: |
          pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
          pip3 install torch==${{ matrix.pytorch_version }} torchvision

      - name: Install dependencies
        run: |
          pip3 show torch
          python -m build --no-isolation --sdist
          pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
          pip3 install --no-build-isolation dist/axolotl*.tar.gz
          python scripts/unsloth_install.py | sh
          python scripts/cutcrossentropy_install.py | sh
          pip3 install -r requirements-dev.txt -r requirements-tests.txt

@@ -231,10 +231,16 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 128
            cuda_version: 12.8.1
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.8.0
            pytorch: 2.7.1
            num_gpus: 1
            axolotl_extras:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.7.1
            num_gpus: 1
            axolotl_extras:
            dockerfile: "Dockerfile-uv.jinja"

@@ -283,15 +289,15 @@ jobs:
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.6.0
            num_gpus: 1
            axolotl_extras:
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"
            pytorch: 2.7.1
            num_gpus: 1
            axolotl_extras:
          # - cuda: 128
          #   cuda_version: 12.8.1
          #   python_version: "3.11"
          #   pytorch: 2.7.1
          #   num_gpus: 1
          #   axolotl_extras:
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"

@@ -299,12 +305,6 @@ jobs:
            num_gpus: 1
            gpu_type: "B200"
            axolotl_extras: fbgemm-gpu
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"
            pytorch: 2.9.0
            num_gpus: 1
            axolotl_extras:
    steps:
      - name: Checkout
        uses: actions/checkout@v4
@@ -11,13 +11,13 @@ repos:
  - id: no-commit-to-branch
    args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.14.3
  rev: v0.12.12
  hooks:
  - id: ruff
    args: [--fix]
  - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
  rev: v1.18.2
  rev: v1.17.1
  hooks:
  - id: mypy
    additional_dependencies:
@@ -73,7 +73,7 @@ Features:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.7.1
- PyTorch ≥2.6.0

### Google Colab
@@ -32,7 +32,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
    fi

RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}

ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"

@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
ENV AXOLOTL_DATASET_PROCESSES="8"

RUN apt-get update && \
    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
    fi

RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
@@ -65,13 +65,8 @@ def run_cmd(cmd: str, run_folder: str):
    import subprocess  # nosec

    sp_env = os.environ.copy()
    sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
    sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"

    # Propagate errors from subprocess.
    try:
        exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env)  # nosec
        if exit_code:
            print(f"Command '{cmd}' failed with exit code {exit_code}")
        return exit_code
    except Exception as e:  # pylint: disable=broad-except
        print(f"Command '{cmd}' failed with exception {e}")
    if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env):  # nosec
        exit(exit_code)
@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_num_proc: 1
dataset_processes: 1

sequence_len: 4096
sample_packing: false
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/workspace/miniconda3/bin:${PATH}"
ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"

@@ -24,35 +24,29 @@ RUN apt-get update \
    && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir -p /workspace/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
    CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
    python3 -m pip cache purge

RUN if [ "$CUDA" != "130" ] ; then \
        CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
        python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
        python3 -m pip cache purge; \
    fi

RUN git lfs install --skip-repo && \
    pip3 install awscli && \
    # The base image ships with `pydantic==1.8.2` which is not working
    pip3 install -U --no-cache-dir pydantic==1.10.10 && \
    pip3 cache purge

RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
        wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
        FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
    fi
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/workspace/miniconda3/bin:${PATH}"
ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"

@@ -19,12 +19,12 @@ RUN apt-get update \
    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir -p /workspace/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/workspace/miniconda3/bin:${PATH}"
ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"

@@ -19,14 +19,14 @@ RUN apt-get update \
    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir -p /workspace/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace
@@ -30,13 +30,7 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"

RUN uv pip install packaging setuptools wheel psutil \
    && uv pip install torch==${PYTORCH_VERSION} torchvision \
    && uv pip install torch==${PYTORCH_VERSION} \
    && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
    && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
    && uv pip install awscli pydantic

RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
        wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
    fi
@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
    - Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
    - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
    - Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):

```yaml
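# (The diff hunk ends before this YAML body. As a hedged sketch only -- the
# option names below are assumptions about axolotl's dataset sharding keys,
# not taken from this diff:)
# dataset_shard_num: 20  # split the tokenized dataset into 20 shards
# dataset_shard_idx: 0   # train on just the first shard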
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
    "-m", "axolotl.cli.train", "dev_chat_template.yml",
    // The flags below simplify debugging by overriding the axolotl config
    // with the debugging tips above. Modify as needed.
    "--dataset_num_proc=1", // limits data preprocessing to one process
    "--dataset_processes=1", // limits data preprocessing to one process
    "--max_steps=1", // limits training to just one step
    "--batch_size=1", // minimizes batch size
    "--micro_batch_size=1", // minimizes batch size
@@ -63,14 +63,6 @@ description: Frequently asked questions

> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.

**Q: Can we mix text and text+image datasets for VLM training?**

> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!

**Q: Why is `memory/max_*` different from `nvidia-smi`?**

> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.

### Chat templates

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -1,5 +1,5 @@
---
title: "FSDP + QLoRA"
title: "FDSP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format:
  html:

@@ -23,12 +23,6 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.

## Enabling Swap for FSDP2

If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.

This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.

## Example Config

[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
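Put together, a minimal sketch of the swap-enabling settings described above; this assumes the FSDP2-style key names used elsewhere in this diff (`fsdp_version: 2` plus un-prefixed `fsdp_config` keys), with all other config omitted:

```yaml
fsdp_version: 2
fsdp_config:
  offload_params: true           # FSDP CPU offloading
  cpu_offload_pin_memory: false  # disable pinning so offloaded memory can spill to swap
```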
@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
to improve speed and reduce memory usage during the forward and backward passes of
these calculations.
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.

We currently support several common model architectures, including (but not limited to):

@@ -132,5 +131,6 @@ computation path.
## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions
@@ -27,9 +27,3 @@ learning_rate: 2e-5
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

::: {.callout-note}

We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17

:::
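For reference, the `lr_groups` block this paragraph describes (the example itself is cut off above the hunk boundary) would look roughly like the following sketch; the group names and module paths are illustrative assumptions, not copied from the diff:

```yaml
lr_groups:
  - name: o_proj
    modules:
      - self_attn.o_proj               # matches o_proj in every layer
    lr: 1e-6
  - name: q_proj_layer3
    modules:
      - model.layers.2.self_attn.q_proj  # 3rd layer (0-indexed)
    lr: 1e-5

learning_rate: 2e-5                    # default for all remaining parameters
```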
@@ -88,7 +88,6 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing

For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:
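Applied to a concrete config, the renames in the table above would look roughly like this before/after sketch (only keys from the table are shown; everything else is omitted, so treat it as an assumption-laden illustration rather than a complete config):

```yaml
# FSDP1 (before)
fsdp_config:
  fsdp_cpu_ram_efficient_loading: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true

# FSDP2 (after)
fsdp_version: 2
fsdp_config:
  cpu_ram_efficient_loading: true
  state_dict_type: FULL_STATE_DICT
  activation_checkpointing: true
```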
@@ -56,14 +56,10 @@ image_resize_algorithm: bilinear

Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.

::: {.callout-tip}
::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::

::: {.callout-note}
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
:::

### Mllama {#sec-mllama}

```yaml

@@ -172,14 +168,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```

### Qwen3-VL {#sec-qwen3-vl}

```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct

chat_template: qwen2_vl # same as qwen2-vl
```

### SmolVLM2 {#sec-smolvlm2}

::: {.callout-tip}
@@ -219,21 +219,6 @@ DPO supports the following types with the following dataset format:
}
```

#### chat_template.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chat_template.default

```yaml
@@ -6,8 +6,6 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-

This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.

Thanks to the team at LiquidAI for giving us early access to prepare for these releases.

## Getting Started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

@@ -33,14 +31,6 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```

**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6

# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```

### TIPS

- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:

@@ -55,13 +45,14 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r

## Optimization Guides

- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
@@ -1,7 +1,6 @@
base_model: LiquidAI/LFM2-350M

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
chunked_cross_entropy: true

eot_tokens:
  - "<|im_end|>"
@@ -1,59 +0,0 @@
base_model: LiquidAI/LFM2-8B-A1B

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true

eot_tokens:
  - "<|im_end|>"
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_field_role: from
    message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5

bf16: true
tf32: true

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1

weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -3,9 +3,6 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
@@ -40,7 +40,7 @@
    "%%capture\n",
    "# This step can take ~5-10 minutes to install dependencies\n",
    "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
  ]
},
{
@@ -1,7 +1,7 @@
base_model: google/gemma-3-1b-it

model_type: Gemma3ForCausalLM

# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -1,7 +1,7 @@
base_model: google/gemma-3-270m-it

model_type: Gemma3ForCausalLM

# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -1,8 +1,5 @@
base_model: google/gemma-3-4b-it

# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM

load_in_4bit: true

# gemma3 doesn't seem to play nice with ddp
@@ -2,8 +2,6 @@

[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.

In October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.

This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.

## Getting started

@@ -66,16 +64,6 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```

### How to set reasoning_effort in template?

The harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. If you would like to adjust this, you can add the following to your config:

```yaml
chat_template_kwargs:
  reasoning_effort: "high" # low | medium | high
```

Currently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.

### Inferencing your fine-tuned model
@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-safeguard-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-safeguard-out/

sequence_len: 4096
sample_packing: true

adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true

# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
#  - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
#  - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
#  - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
#  - "23._checkpoint_wrapped_module.mlp.experts.down_proj"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

special_tokens:
eot_tokens:
  - "<|end|>"
@@ -66,7 +66,6 @@ fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  # fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
save_strategy: no

torch_compile: true

wandb_project:
@@ -1,50 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_4bit: true

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca

output_dir: ./outputs/opentelemetry-example

adapter: qlora
sequence_len: 512
sample_packing: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

# OpenTelemetry Configuration
use_otel_metrics: true
otel_metrics_host: "localhost"
otel_metrics_port: 8000

# Disable WandB
use_wandb: false

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
logging_steps: 1
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
  pad_token: "<|end_of_text|>"
@@ -12,7 +12,7 @@ Before starting, ensure you have:
Run the thinking model fine-tuning:

```bash
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
axolotl train magistral-small-think-qlora.yaml
```

This config uses about 19.1 GiB VRAM.
@@ -21,7 +21,7 @@ Before starting, ensure you have:

3. Run the fine-tuning:
```bash
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
axolotl train magistral-small-vision-24B-qlora.yml
```

This config uses about 17GiB VRAM.
@@ -1,51 +0,0 @@
# Mistral Small 3.1/3.2 Fine-tuning

This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.

## Prerequisites

Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))

## Getting Started

1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```

2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```

3. Run the fine-tuning:
```bash
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
```

This config uses about 29.4 GiB VRAM.

## Dataset Format

The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).

One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.

Example:
```json
{
    "messages": [
        {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
        {"role": "user", "content": [
            { "type": "text", "text": "What's in this image?"},
            {"type": "image", "path": "path/to/image.jpg" }
        ]},
        {"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
    ],
}
```

## Limitations

- Sample Packing is not supported for multi-modality training currently.
@@ -39,7 +39,7 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
@@ -5,30 +5,31 @@ bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.3
autoawq==0.2.7.post3
liger-kernel==0.6.1
# END section

packaging==23.2

huggingface_hub>=0.36.0
peft>=0.17.1
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
transformers==4.57.1
accelerate==1.10.1
datasets==4.3.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.24.0
hf_xet==1.2.0
kernels>=0.9.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio

optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.49.1
gradio==5.41.1

modal==1.0.2
pydantic>=2.10.6
pydantic==2.10.6
addict
fire
PyYAML>=6.0

@@ -36,8 +37,8 @@ requests
wandb
einops
colorama
numba>=0.61.2
numpy>=2.2.6
numba
numpy>=1.24.4,<=2.0.1

# qlora things
evaluate==0.4.1

@@ -50,7 +51,7 @@ python-dotenv==1.0.1

# remote filesystems
s3fs>=2024.5.0
gcsfs>=2025.3.0
gcsfs>=2024.5.0
adlfs>=2024.5.0
ocifs==1.3.2

@@ -66,7 +67,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.13.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.7
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5

mistral-common==1.8.5
scripts/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
"""Utility scripts package."""
scripts/benchmarks/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
"""Benchmark helpers."""

from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3

__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]
scripts/benchmarks/build_deepseek_v3_8b.py (new executable file, 100 lines)

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.

Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:

    python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}


def build_config() -> DeepseekV3Config:
    """Return a DeepSeek V3 configuration totaling ~8.3B parameters."""

    return DeepseekV3Config(
        vocab_size=32_000,
        hidden_size=3_072,
        intermediate_size=8_192,
        moe_intermediate_size=2_560,
        num_hidden_layers=20,
        num_attention_heads=24,
        num_key_value_heads=24,
        n_routed_experts=18,
        num_experts_per_tok=4,
        n_group=6,
        topk_group=4,
        kv_lora_rank=192,
        q_lora_rank=384,
        max_position_embeddings=2_048,
        rope_theta=10_000.0,
        rope_interleave=True,
        hidden_act="silu",
        initializer_range=0.02,
        attention_dropout=0.0,
        attention_bias=False,
        n_shared_experts=1,
        routed_scaling_factor=2.5,
        norm_topk_prob=True,
    )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Directory to save the generated model",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=DTYPE_MAP.keys(),
        help="Storage dtype for the checkpoint",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Torch RNG seed for reproducibility",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    torch.manual_seed(args.seed)

    output_dir = args.output
    output_dir.mkdir(parents=True, exist_ok=True)

    config = build_config()
    model = DeepseekV3ForCausalLM(config)

    dtype = DTYPE_MAP[args.dtype]
    model.to(dtype=dtype)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")

    model.save_pretrained(output_dir, safe_serialization=True)
    config.save_pretrained(output_dir)
    print(f"Saved model and config to {output_dir.resolve()}")


if __name__ == "__main__":
    main()
scripts/benchmarks/deepseek_v3_group_gemm_table.py (new file, 190 lines)

@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""

from __future__ import annotations

import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import torch

CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
    repo_root = candidate / "axolotl"
    if repo_root.exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
        break
else:
    raise SystemExit("Unable to locate axolotl repository root for imports")

from axolotl.kernels.moe import (
    cg_grouped_gemm_forward,
    cg_grouped_gemm_forward_dynamic,
)


@dataclass
class Scenario:
    num_groups: int
    m: int
    n: int
    k: int


SCENARIOS: tuple[Scenario, ...] = (
    Scenario(num_groups=4, m=8192, n=4096, k=7168),
    Scenario(num_groups=4, m=8192, n=7168, k=2048),
    Scenario(num_groups=8, m=4096, n=4096, k=7168),
    Scenario(num_groups=8, m=4096, n=7168, k=2048),
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--device", default="cuda", choices=["cuda"], help="Execution device"
    )
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["bf16", "fp16", "fp32"],
        help="Computation dtype",
    )
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--group-size",
        type=int,
        default=128,
        help="GROUP_SIZE_M expected by the kernel",
    )
    return parser.parse_args()


def pick_dtype(name: str) -> torch.dtype:
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[name]


def make_indices(
    num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
    indices = torch.arange(num_groups, device=device, dtype=torch.int32)
    return indices.repeat_interleave(group_size)


def timed_call(fn, *args, warmup: int, iters: int) -> float:
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0 / iters


def run_scenario(
    scenario: Scenario,
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> dict:
    if scenario.m % scenario.num_groups != 0:
        raise ValueError(
            f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
        )
    group_size = scenario.m // scenario.num_groups
    if group_size % group_size_m != 0:
        raise ValueError(
            f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
        )

    inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
    weights = torch.randn(
        scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
    )
    indices = make_indices(scenario.num_groups, group_size, device)

    def persistent():
        return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)

    def baseline():
        return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)

    persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
    baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)

    return {
        "scenario": scenario,
        "persistent_ms": persistent_ms,
        "baseline_ms": baseline_ms,
        "speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
    }


def main() -> None:  # pragma: no cover - utility script
    args = parse_args()
    torch.manual_seed(args.seed)

    if args.device != "cuda" or not torch.cuda.is_available():
        raise SystemExit("CUDA device required for this benchmark")

    dtype = pick_dtype(args.dtype)
    device = torch.device(args.device)

    print(
        f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
    )
    print(
        f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
    )
    for result in run_all(
        SCENARIOS,
        dtype=dtype,
        device=device,
        warmup=args.warmup,
        iters=args.iters,
        group_size_m=args.group_size,
    ):
        scen = result["scenario"]
        print(
            f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
            f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
        )


def run_all(
    scenarios: Iterable[Scenario],
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> Iterable[dict]:
    for scenario in scenarios:
        yield run_scenario(
            scenario,
            dtype=dtype,
            device=device,
            warmup=warmup,
            iters=iters,
            group_size_m=group_size_m,
        )


if __name__ == "__main__":
    main()
301
scripts/benchmarks/deepseek_v3_moe.py
Normal file
301
scripts/benchmarks/deepseek_v3_moe.py
Normal file
@@ -0,0 +1,301 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""

from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path
from types import MethodType

import torch

try:
    from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
        DeepseekV3Config,
    )
    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc:  # pragma: no cover - utility script
    raise SystemExit(
        "Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
    ) from exc

CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
    repo_root = candidate / "axolotl"
    if repo_root.exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
        break
else:  # pragma: no cover - defensive guard
    raise SystemExit("Unable to locate axolotl repository root for imports")

from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe  # noqa: E402

ACCURACY_TOLERANCE = 5e-3

DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
    parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
    parser.add_argument(
        "--moe-intermediate-size",
        type=int,
        default=8192,
        help="MoE intermediate projection size",
    )
    parser.add_argument(
        "--n-experts",
        type=int,
        default=256,
        help="Number of routed experts",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=8,
        help="Number of experts per token",
    )
    parser.add_argument(
        "--groups",
        type=int,
        default=8,
        help="Router groups (must divide n-experts)",
    )
    parser.add_argument(
        "--dtype",
        choices=DTYPE_MAP.keys(),
        default="bf16",
        help="Computation dtype",
    )
    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--uniform-routing",
        action="store_true",
        help="Override router to distribute tokens evenly across experts",
    )
    parser.add_argument(
        "--group-size",
        type=int,
        default=128,
        help="GROUP_SIZE_M used by the Triton kernel",
    )
    parser.add_argument(
        "--backend",
        choices=["cg", "mg"],
        default="mg",
        help="MoE kernel backend to benchmark",
    )
    return parser.parse_args()


def resolve_device(requested: str) -> torch.device:
    if requested == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
    config = DeepseekV3Config(
        hidden_size=args.hidden_size,
        intermediate_size=args.moe_intermediate_size,
        moe_intermediate_size=args.moe_intermediate_size,
        n_routed_experts=args.n_experts,
        num_experts_per_tok=args.top_k,
        n_group=args.groups,
        topk_group=max(1, min(args.groups, args.top_k)),
        n_shared_experts=1,
    )
    module = DeepseekV3MoE(config)
    module.eval()
    return module


@torch.no_grad()
def timed_forward(
    module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
    for _ in range(warmup):
        module(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return (elapsed / iters) * 1000.0


def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
    torch.manual_seed(args.seed)

    device = resolve_device(args.device)
    dtype = DTYPE_MAP[args.dtype]

    if args.n_experts % args.groups != 0:
        raise SystemExit("n-experts must be divisible by groups")
    if args.top_k > args.n_experts:
        raise SystemExit("top-k cannot exceed number of experts")

    if device.type == "cuda" and not torch.cuda.is_available():
        raise SystemExit("CUDA requested but not available")

    baseline_module = build_module(args)
    original_moe = getattr(
        DeepseekV3MoE,
        "_axolotl_triton_original_moe",
        DeepseekV3MoE.moe,
    )
    baseline_module.moe = MethodType(original_moe, baseline_module)
    state_dict = baseline_module.state_dict()

    patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
    patched_module = build_module(args)
    patched_module.load_state_dict(state_dict)

    baseline_module.to(device=device, dtype=dtype)
    patched_module.to(device=device, dtype=dtype)

    tokens = args.batch * args.seq_len
    routed_tokens = tokens * args.top_k
    avg_tokens_per_expert = routed_tokens / args.n_experts

    inputs = torch.randn(
        args.batch,
        args.seq_len,
        args.hidden_size,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        flat_inputs = inputs.view(-1, args.hidden_size)
        if args.uniform_routing:
            total_assignments = flat_inputs.size(0) * args.top_k
            base = total_assignments // args.n_experts
            remainder = total_assignments % args.n_experts
            counts = torch.full(
                (args.n_experts,),
                base,
                dtype=torch.int64,
                device=device,
            )
            if remainder:
                counts[:remainder] += 1
            assignments = torch.repeat_interleave(
                torch.arange(args.n_experts, device=device), counts
            )
            assignments = assignments[torch.randperm(assignments.size(0))]
            topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
        else:
            topk_idx, _ = patched_module.gate(flat_inputs)

    tokens_per_expert = torch.bincount(
        topk_idx.reshape(-1), minlength=args.n_experts
    )
    min_tokens = int(tokens_per_expert.min().item())
    max_tokens = int(tokens_per_expert.max().item())

    if args.uniform_routing:
        weights = torch.full(
            topk_idx.shape,
            1.0 / args.top_k,
            device=device,
            dtype=torch.float32,
        )

        def _uniform_gate(self, hidden_states):
            flat = hidden_states.view(-1, hidden_states.shape[-1])
            token_count = flat.shape[0]
            return topk_idx[:token_count], weights[:token_count]

        patched_module.gate.forward = _uniform_gate.__get__(
            patched_module.gate, patched_module.gate.__class__
        )
        baseline_module.gate.forward = _uniform_gate.__get__(
            baseline_module.gate, baseline_module.gate.__class__
        )

    with torch.no_grad():
        ref_output = baseline_module(inputs)
        patched_output = patched_module(inputs)
        max_diff = (ref_output - patched_output).abs().max().item()

    baseline_vram = patched_vram = None
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
    baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
    if device.type == "cuda":
        baseline_vram = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
    patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
    if device.type == "cuda":
        patched_vram = torch.cuda.max_memory_allocated(device)

    speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")

    return {
        "device": device,
        "backend": args.backend,
        "dtype": dtype,
        "baseline_ms": baseline_ms,
        "patched_ms": patched_ms,
        "speedup": speedup,
        "max_diff": max_diff,
        "routed_tokens": routed_tokens,
        "avg_tokens": avg_tokens_per_expert,
        "min_tokens": min_tokens,
        "max_tokens": max_tokens,
        "baseline_vram": baseline_vram,
        "patched_vram": patched_vram,
        "accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
    }


def main() -> None:  # pragma: no cover - CLI entrypoint
    args = parse_args()
    result = benchmark_deepseek_v3(args)

    print(
        f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
    )
    print(
        f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
    )
    print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
    if result["baseline_vram"] is not None:
        print(
            f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
        )
    print(
        f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
    )
    print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
    if not result["accuracy_ok"]:
        raise RuntimeError(
            f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
        )


if __name__ == "__main__":
    main()
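
For reference, a typical invocation of the benchmark above might look like the following (a sketch only; all flags are defined in `parse_args` above, and timings will vary by GPU):

```bash
python scripts/benchmarks/deepseek_v3_moe.py \
    --batch 2 --seq-len 1024 --n-experts 64 --top-k 8 \
    --backend mg --uniform-routing
```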
275 scripts/benchmarks/deepseek_v3_moe_sweep.py (new file)
@@ -0,0 +1,275 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""

from __future__ import annotations

import argparse
import csv
import logging
import sys
from pathlib import Path
from types import SimpleNamespace

CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
    repo_root = candidate / "axolotl"
    if repo_root.exists():
        if str(repo_root) not in sys.path:
            sys.path.insert(0, str(repo_root))
        break
else:  # pragma: no cover - defensive guard
    raise SystemExit("Unable to locate axolotl repository root for imports")

from scripts.benchmarks.deepseek_v3_moe import (  # noqa: E402
    ACCURACY_TOLERANCE,
    DTYPE_MAP,
    benchmark_deepseek_v3,
)

LOG = logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--dtype",
        choices=DTYPE_MAP.keys(),
        default="bf16",
        help="Computation dtype for all benchmarks",
    )
    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--group-size",
        type=int,
        help="Override GROUP_SIZE_M for every configuration",
    )
    parser.add_argument(
        "--backends",
        default="mg",
        help="Comma separated list of backends to benchmark (subset of cg,mg)",
    )
    parser.add_argument(
        "--no-uniform-routing",
        action="store_true",
        help="Disable uniform routing for every configuration",
    )
    parser.add_argument(
        "--include-mixtral-long",
        action="store_true",
        help="Add an 8×8192 Mixtral-style run to the sweep",
    )
    parser.add_argument(
        "--output",
        type=Path,
        help="Optional CSV file to store results",
    )
    return parser.parse_args()


def make_namespace(
    base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
    combined = dict(base)
    combined.update(
        {
            "dtype": args.dtype,
            "device": args.device,
            "backend": backend,
            "warmup": args.warmup,
            "iters": args.iters,
            "seed": args.seed,
            "uniform_routing": not args.no_uniform_routing,
        }
    )
    if args.group_size is not None:
        combined["group_size"] = args.group_size
    return SimpleNamespace(**combined)


ARCHETYPES = (
    (
        "mixtral",
        {
            "hidden_size": 4096,
            "moe_intermediate_size": 14336,
            "n_experts": 8,
            "top_k": 2,
            "groups": 1,
            "group_size": 128,
        },
        [(4, 2048), (8, 4096)],
    ),
    (
        "qwen",
        {
            "hidden_size": 6144,
            "moe_intermediate_size": 24576,
            "n_experts": 16,
            "top_k": 4,
            "groups": 8,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
    (
        "deepseek_v3",
        {
            "hidden_size": 12288,
            "moe_intermediate_size": 49152,
            "n_experts": 128,
            "top_k": 8,
            "groups": 16,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
)

MIXTRAL_LONG_SHAPES = [(8, 8192)]


def main() -> None:  # pragma: no cover - utility script
    args = parse_args()

    grid = []
    for label, base_cfg, shapes in ARCHETYPES:
        for batch, seq_len in shapes:
            cfg = {
                "label": label,
                "batch": batch,
                "seq_len": seq_len,
                **base_cfg,
            }
            if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
                continue
            grid.append(cfg)

    if args.include_mixtral_long:
        base_cfg = ARCHETYPES[0][1]
        for batch, seq_len in MIXTRAL_LONG_SHAPES:
            grid.append(
                {
                    "label": "mixtral_long",
                    "batch": batch,
                    "seq_len": seq_len,
                    **base_cfg,
                }
            )

    if not grid:
        raise SystemExit("No valid parameter combinations produced")

    header = (
        "model",
        "batch",
        "seq_len",
        "hidden_size",
        "moe_intermediate",
        "n_experts",
        "top_k",
        "groups",
        "backend",
        "baseline_ms",
        "patched_ms",
        "speedup",
        "baseline_vram_mib",
        "patched_vram_mib",
        "min_tokens",
        "max_tokens",
        "max_diff",
        "accuracy_ok",
    )
    rows = []

    raw_backends = [
        token.strip() for token in args.backends.split(",") if token.strip()
    ]
    if not raw_backends:
        raw_backends = ["mg"]
    valid_backends = []
    for backend in raw_backends:
        if backend not in {"cg", "mg"}:
            raise SystemExit(f"Unsupported backend '{backend}' requested")
        if backend not in valid_backends:
            valid_backends.append(backend)

    uniform_flag = not args.no_uniform_routing
    print(
        f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
    )
    print(
        f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
        f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
    )

    for cfg in grid:
        for backend in valid_backends:
            ns = make_namespace(cfg, args, backend)
            result = benchmark_deepseek_v3(ns)
            baseline_vram_mib = (
                result["baseline_vram"] / (1024**2)
                if result["baseline_vram"] is not None
                else float("nan")
            )
            patched_vram_mib = (
                result["patched_vram"] / (1024**2)
                if result["patched_vram"] is not None
                else float("nan")
            )
            rows.append(
                (
                    cfg["label"],
                    cfg["batch"],
                    cfg["seq_len"],
                    cfg["hidden_size"],
                    cfg["moe_intermediate_size"],
                    cfg["n_experts"],
                    cfg["top_k"],
                    cfg["groups"],
                    backend,
                    result["baseline_ms"],
                    result["patched_ms"],
                    result["speedup"],
                    baseline_vram_mib,
                    patched_vram_mib,
                    result["min_tokens"],
                    result["max_tokens"],
                    result["max_diff"],
                    result["accuracy_ok"],
                )
            )
            status = "OK" if result["accuracy_ok"] else "FAIL"
            print(
                f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
                f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
                f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
            )
            if not result["accuracy_ok"]:
                LOG.warning(
                    "Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
                    cfg["label"],
                    backend,
                    result["max_diff"],
                    ACCURACY_TOLERANCE,
                )

    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        with args.output.open("w", newline="") as fp:
            writer = csv.writer(fp)
            writer.writerow(header)
            writer.writerows(rows)
        print(f"Results written to {args.output}")


if __name__ == "__main__":
    main()
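
A typical sweep invocation (a sketch; flags as defined in `parse_args` above, with the CSV path chosen arbitrarily):

```bash
python scripts/benchmarks/deepseek_v3_moe_sweep.py \
    --backends cg,mg --iters 10 --output /tmp/moe_sweep.csv
```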
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
 )
26 setup.py
@@ -26,6 +26,7 @@ def parse_requirements(extras_require_map):
                 _install_requires.append(line)
     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
+        autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
         if "Darwin" in platform.system():
             # skip packages not compatible with OSX
             skip_packages = [
@@ -33,6 +34,7 @@ def parse_requirements(extras_require_map):
                 "triton",
                 "mamba-ssm",
                 "xformers",
+                "autoawq",
                 "liger-kernel",
             ]
             _install_requires = [
@@ -49,7 +51,7 @@ def parse_requirements(extras_require_map):
     try:
         torch_version = version("torch")
     except PackageNotFoundError:
-        torch_version = "2.8.0"  # default to torch 2.8.0
+        torch_version = "2.6.0"  # default to torch 2.6
     _install_requires.append(f"torch=={torch_version}")

     version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -62,15 +64,8 @@ def parse_requirements(extras_require_map):
     else:
         raise ValueError("Invalid version format")

-    if (major, minor) >= (2, 9):
-        extras_require_map.pop("fbgemm-gpu")
-        extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
-        extras_require_map["vllm"] = ["vllm==0.11.1"]
-        _install_requires.pop(_install_requires.index(xformers_version))
-    elif (major, minor) >= (2, 8):
-        extras_require_map.pop("fbgemm-gpu")
-        extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
-        extras_require_map["vllm"] = ["vllm==0.11.0"]
+    if (major, minor) >= (2, 8):
+        pass
     elif (major, minor) >= (2, 7):
         _install_requires.pop(_install_requires.index(xformers_version))
         if patch == 0:
@@ -79,7 +74,7 @@ def parse_requirements(extras_require_map):
             extras_require_map.pop("vllm")
         else:
             _install_requires.append("xformers==0.0.31")
-            extras_require_map["vllm"] = ["vllm==0.10.1"]
+            extras_require_map["vllm"] = ["vllm>=0.10.0"]
     elif (major, minor) >= (2, 6):
         _install_requires.pop(_install_requires.index(xformers_version))
         _install_requires.append("xformers==0.0.29.post3")
@@ -92,6 +87,7 @@ def parse_requirements(extras_require_map):
             _install_requires.append("xformers==0.0.28.post2")
         else:
             _install_requires.append("xformers>=0.0.28.post3")
+        _install_requires.pop(_install_requires.index(autoawq_version))
         extras_require_map.pop("vllm")
     elif (major, minor) >= (2, 4):
         extras_require_map.pop("vllm")
@@ -165,13 +161,7 @@ extras_require = {
     "llmcompressor": [
         "llmcompressor==0.5.1",
     ],
-    "fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
-    "opentelemetry": [
-        "opentelemetry-api",
-        "opentelemetry-sdk",
-        "opentelemetry-exporter-prometheus",
-        "prometheus-client",
-    ],
+    "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
 }
 install_requires, dependency_links, extras_require_build = parse_requirements(
     extras_require
@@ -85,7 +85,9 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
     unpatch_llama4 = patch_llama4_linearized_modeling()
     from transformers import Llama4ForConditionalGeneration

-    model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
+    model_ = Llama4ForConditionalGeneration.from_pretrained(
+        model, torch_dtype=torch.bfloat16
+    )
     processor = AutoProcessor.from_pretrained(model)
     processor.save_pretrained(output)

@@ -69,7 +69,7 @@ def do_quantize(
     config = AutoConfig.from_pretrained(model_path)
     torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", dtype=torch_dtype
+        model_path, device_map="auto", torch_dtype=torch_dtype
     )

     LOG.info(
@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
     resolve_dtype(cfg)

     # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
-    if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
+    if cfg.deepspeed:
         cfg.deepspeed = cfg.deepspeed.to_dict()

     # initialize accelerator before model instantiation

@@ -12,9 +12,6 @@ MOE_ARCH_BLOCK = {
     "mixtral": "MixtralSparseMoeBlock",
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
     "qwen3_moe": "Qwen3MoeSparseMoeBlock",
-    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
     "deepseek_v2": "DeepseekV2MoE",
     "deepseek_v3": "DeepseekV3MoE",
-    "gpt_oss": "GptOssDecoderLayer",
-    "lfm2_moe": "Lfm2MoeSparseMoeBlock",
 }
@@ -29,11 +29,7 @@ from transformers.trainer_pt_utils import AcceleratorConfig

 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
-from axolotl.utils import (
-    is_comet_available,
-    is_mlflow_available,
-    is_opentelemetry_available,
-)
+from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     GCCallback,
     SaveAxolotlConfigtoWandBCallback,

@@ -138,12 +134,6 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.append(
                 SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
             )
-        if self.cfg.use_otel_metrics and is_opentelemetry_available():
-            from axolotl.utils.callbacks.opentelemetry import (
-                OpenTelemetryMetricsCallback,
-            )
-
-            callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
         if self.cfg.save_first_step:
             callbacks.append(SaveModelOnFirstStepCallback())

@@ -501,7 +491,6 @@ class TrainerBuilderBase(abc.ABC):
             "dion_momentum",
             "dion_rank_fraction",
             "dion_rank_multiple_of",
-            "dataset_num_proc",
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)

@@ -525,6 +514,9 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
         training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

+        if self.cfg.dataset_processes:
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+
         # max_length is not used in CausalTrainer
         if self.cfg.reward_model or self.cfg.rl:
             training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -12,7 +12,7 @@ from transformers import (
     EarlyStoppingCallback,
     Trainer,
 )
-from trl.trainer.reward_trainer import DataCollatorForPreference
+from trl.trainer.utils import RewardDataCollatorWithPadding

 from axolotl.core.builders.base import TrainerBuilderBase
 from axolotl.core.trainers import (

@@ -28,6 +28,7 @@ from axolotl.processing_strategies import get_processing_strategy
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     LossWatchDogCallback,
+    SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     colab_inference_post_train_callback,

@@ -62,6 +63,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.relora:
             callbacks.append(ReLoRACallback(self.cfg))

+        if (
+            hasattr(self.model, "use_bettertransformer")
+            and self.model.use_bettertransformer is True
+        ):
+            callbacks.append(SaveBetterTransformerModelCallback())
+
         # TODO: check if can move to base class
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))

@@ -453,7 +460,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             BatchSamplerDataCollatorForSeq2Seq,
             DataCollatorForSeq2Seq,
             DataCollatorWithFlattening,
-            DataCollatorForPreference,
+            RewardDataCollatorWithPadding,
         ]
     ]
     collator_args = [self.tokenizer]

@@ -470,10 +477,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             if kwargs and isinstance(kwargs, dict):
                 kwargs.update(collator_cls_and_kwargs[1])
         elif self.cfg.reward_model:
-            collator = DataCollatorForPreference
-            tokenizer = collator_args.pop(0)
-            kwargs["pad_token_id"] = tokenizer.pad_token_id
-            kwargs.pop("padding")
+            collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
             # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
             # supported multipack models, or non-flash-attention llama
@@ -225,6 +225,17 @@ class AxolotlTrainer(

         data_collator = self.data_collator if is_training else self.eval_data_collator

+        if dataset.column_names and "length" in dataset.column_names:
+            dataset = dataset.remove_columns(["length"])
+        if (
+            dataset.column_names
+            and "position_ids" in dataset.column_names
+            and "attention_mask" in dataset.column_names
+            and self.args.sample_packing
+            and self.args.sample_packing_drop_attention_mask
+        ):
+            dataset = dataset.remove_columns(["attention_mask"])
+
         if isinstance(dataset, datasets.Dataset):
             if is_training:
                 if not self.args.sample_packing or self.args.pretraining:

@@ -283,18 +294,6 @@ class AxolotlTrainer(
         ):
             self.accelerator.even_batches = False

-        if dataset.column_names and "length" in dataset.column_names:
-            dataset = dataset.remove_columns(["length"])
-
-        if (
-            dataset.column_names
-            and "position_ids" in dataset.column_names
-            and "attention_mask" in dataset.column_names
-            and self.args.sample_packing
-            and self.args.sample_packing_drop_attention_mask
-        ):
-            dataset = dataset.remove_columns(["attention_mask"])
-
         dataloader = DataLoader(dataset, **dataloader_params)

         # Accelerator.free_memory() will destroy the references, so

@@ -561,6 +560,13 @@ class AxolotlTrainer(

         super().create_accelerator_and_postprocess()

+        if self.is_fsdp_enabled:
+            if (
+                "limit_all_gathers" in self.args.fsdp_config
+                and self.args.fsdp_config["limit_all_gathers"]
+            ):
+                self.accelerator.state.fsdp_plugin.limit_all_gathers = True
+
     def additional_accelerator_args(
         self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
     ) -> dict[str, Any]:
@@ -52,7 +52,6 @@ class GRPOStrategy:
         if trl.vllm_mode:
             grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
             if trl.vllm_mode == "colocate":
-                grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode  # type: ignore[attr-defined]
                 grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
                     vllm_cfg.gpu_memory_utilization
                 )
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

 - If you are installing from pip

 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
 ```

 ## Usage

@@ -31,7 +31,6 @@ plugins:

 ## Supported Models

-- apertus
 - arcee
 - cohere
 - cohere2

@@ -45,22 +44,14 @@ plugins:
 - glm
 - glm4
 - glm4_moe
-- glm4v
-- glm4v_moe
-- gpt_oss
 - granite
 - granitemoe
 - granitemoeshared
-- granitemoehybrid
-- hunyuan_v1_dense
-- hunyuan_v1_moe
 - lfm2
-- lfm2_moe
-- lfm2_vl
 - llama
 - llama4
 - llama4_text
 - llava
 - mistral
 - mistral3
 - mixtral

@@ -74,8 +65,6 @@ plugins:
 - qwen2_5_vl
 - qwen3
 - qwen3_moe
-- qwen3_vl
-- qwen3_vl_moe
 - qwen3_next
 - smollm3
 - seed_oss
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
 )
@@ -7,7 +7,7 @@ import torch

 from axolotl.utils.logging import get_logger

-from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
+from .utils import create_bidirectional_attention_mask

 LOG = get_logger(__name__)

@@ -360,7 +360,7 @@ def _diffusion_step(

     # Forward pass
     outputs = model(input_ids=sequence, attention_mask=attention_mask)
-    logits = shift_logits_to_input_positions(outputs.logits)
+    logits = outputs.logits

     # Only sample at currently masked positions
     if current_mask.any():

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

 from .callbacks import DiffusionGenerationCallback
-from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
+from .utils import create_bidirectional_attention_mask

 LOG = get_logger(__name__)

@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
             input_ids=noisy_batch.long(),
             attention_mask=bidirectional_mask,
         )
-        logits = shift_logits_to_input_positions(outputs.logits)
+        logits = outputs.logits

         if masked_indices.sum() > 0:
             valid_indices = torch.where(masked_indices)

@@ -157,10 +157,3 @@ def create_bidirectional_attention_mask(

     # Add head dimension: [batch_size, 1, seq_len, seq_len]
     return bidirectional_mask.unsqueeze(1)
-
-
-def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
-    """Align next-token logits with their input token positions for diffusion."""
-    if logits.size(1) <= 1:
-        return logits
-    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(

     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
-    # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
+    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss

-    loss = self._loss_function(
+    loss = self.loss_function(
         self.lm_head.weight,
         hidden_states,
         target_token_ids,

@@ -29,8 +29,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_accepts_loss_kwargs = True
-
-        loss_fn = LigerFusedLinearKLTopKLogprobLoss(
+        self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
             self.args.kd_ce_alpha,  # hard label loss
             self.args.kd_alpha,  # kd loss
             self.args.kd_temperature,

@@ -38,14 +37,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
             compute_ce_loss=bool(self.args.kd_ce_alpha),
             normalize_topk=self.args.kd_normalize_topk,
         )
-        target = self.model
-
-        # Unwrap PEFT wrapper
-        if hasattr(target, "get_base_model"):
-            target = target.get_base_model()
-
-        # Set on the actual model instance
-        target._loss_function = loss_fn

     def _set_signature_columns_if_needed(self):
         super()._set_signature_columns_if_needed()
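
A hedged sketch of why setting `_loss_function` pairs with reading `self.loss_function` above: in recent transformers releases, `PreTrainedModel` exposes `loss_function` as a property that prefers an explicitly set `_loss_function` attribute (this is an assumption about the installed transformers version; the model checkpoint and `my_kd_loss` below are hypothetical placeholders):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

def my_kd_loss(*args, **kwargs):  # hypothetical stand-in for the Liger KD loss
    raise NotImplementedError

# The trainer injects the loss by attribute; the patched forward then reads it
# back through the `loss_function` property.
model._loss_function = my_kd_loss
print(model.loss_function)  # expected: my_kd_loss, if the property behaves as described
```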
21 src/axolotl/kernels/moe/__init__.py (new file)
@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""
|
||||
|
||||
from .indices import generate_permute_indices
|
||||
from .tt_cg_gemm import (
|
||||
ContiguousGroupedGEMM,
|
||||
ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
"generate_permute_indices",
|
||||
"mg_grouped_gemm",
|
||||
]
|
||||
5 src/axolotl/kernels/moe/indices/__init__.py (new file)
@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
|
||||
|
||||
from .indices import generate_permute_indices
|
||||
|
||||
__all__ = ["generate_permute_indices"]
|
||||
144 src/axolotl/kernels/moe/indices/indices.py (new file)
@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
__all__ = ["generate_permute_indices"]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fill_indices_kernel(
|
||||
tokens_per_expert_group_ptr,
|
||||
start_index_values_ptr,
|
||||
write_offsets_ptr,
|
||||
output_ptr,
|
||||
experts_per_rank: tl.constexpr,
|
||||
num_ranks: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_programs = tl.num_programs(axis=0)
|
||||
|
||||
for expert_id in range(pid, experts_per_rank, num_programs):
|
||||
write_offset = tl.load(write_offsets_ptr + expert_id)
|
||||
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
|
||||
start_index = tl.load(start_index_values_ptr + idx)
|
||||
length = tl.load(tokens_per_expert_group_ptr + idx)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for chunk_start in range(0, length, BLOCK_SIZE):
|
||||
chunk_offsets = chunk_start + offsets
|
||||
mask = chunk_offsets < length
|
||||
values = start_index + chunk_offsets
|
||||
dest_indices = write_offset + chunk_offsets
|
||||
tl.store(output_ptr + dest_indices, values, mask=mask)
|
||||
|
||||
write_offset += length
|
||||
|
||||
|
||||
def fill_indices_wrapper(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
block_size: int = 128,
|
||||
max_blocks: int = 1024,
|
||||
):
|
||||
permuted_indices = torch.full(
|
||||
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
|
||||
)
|
||||
num_blocks = min(experts_per_rank, max_blocks)
|
||||
grid = (num_blocks,)
|
||||
_fill_indices_kernel[grid](
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
permuted_indices,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
BLOCK_SIZE=block_size,
|
||||
)
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def fill_indices_cpu(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
):
|
||||
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
|
||||
for expert_id in range(experts_per_rank):
|
||||
write_start = write_offsets[expert_id].item()
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
start_index = start_index_values[idx].item()
|
||||
length = tokens_per_expert_group[idx].item()
|
||||
if length > 0:
|
||||
end_idx = min(write_start + length, max_len)
|
||||
permuted_indices[write_start:end_idx] = torch.arange(
|
||||
start_index,
|
||||
start_index + (end_idx - write_start),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
write_start += length
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def generate_permute_indices(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
alignment: int,
|
||||
use_cpu: bool = False,
|
||||
):
|
||||
start_index_values = (
|
||||
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
|
||||
)
|
||||
|
||||
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
|
||||
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
|
||||
|
||||
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
|
||||
torch.int32
|
||||
)
|
||||
|
||||
m_offsets = torch.cumsum(m_sizes, 0)
|
||||
write_offsets = m_offsets - m_sizes
|
||||
|
||||
if use_cpu:
|
||||
permuted_indices = fill_indices_cpu(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
else:
|
||||
permuted_indices = fill_indices_wrapper(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
|
||||
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
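
A minimal sketch of how `generate_permute_indices` behaves on a toy input (editorial example, not part of the vendored file; the `use_cpu=True` path makes it runnable without a GPU):

```python
import torch

from axolotl.kernels.moe.indices import generate_permute_indices

# One rank, two experts: expert 0 received 3 tokens, expert 1 received 1.
tokens_per_expert_group = torch.tensor([3, 1], dtype=torch.int64)

permuted, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank=2,
    num_ranks=1,
    max_len=8,
    alignment=4,
    use_cpu=True,
)
# Each expert's block is padded up to `alignment`; -1 marks padding slots.
print(permuted)   # tensor([ 0,  1,  2, -1,  3, -1, -1, -1], dtype=torch.int32)
print(m_sizes)    # tensor([4, 4], dtype=torch.int32)
print(m_offsets)  # tensor([4, 8], dtype=torch.int32)
```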
17 src/axolotl/kernels/moe/tt_cg_gemm/__init__.py (new file)
@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
|
||||
|
||||
from .cg_backward import ContiguousGroupedGEMM
|
||||
from .cg_forward import (
|
||||
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
290 src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py (new file)
@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .cg_forward import cg_grouped_gemm_forward
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dx(
|
||||
grad_output_ptr,
|
||||
b_ptr,
|
||||
grad_input_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
):
|
||||
"""Compute gradients with respect to inputs."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
|
||||
tile_m = pid // num_k_tiles
|
||||
tile_k = pid % num_k_tiles
|
||||
|
||||
m_start = tile_m * BLOCK_SIZE_M
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_k = offs_k < K
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for n in range(0, N, BLOCK_SIZE_N):
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
|
||||
mask_n = offs_n < N
|
||||
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
mask_w = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
|
||||
|
||||
grad_input += tl.dot(go, w)
|
||||
|
||||
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_gi = mask_m[:, None] & mask_k[None, :]
|
||||
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dw(
|
||||
grad_output_ptr,
|
||||
inputs_ptr,
|
||||
grad_weights_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Simplified kernel for expert weight gradients."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
|
||||
if expert_id < NUM_EXPERTS:
|
||||
n_tiles = K // BLOCK_SIZE_K
|
||||
tile_n = position_id // n_tiles
|
||||
tile_k = position_id % n_tiles
|
||||
|
||||
n_start = tile_n * BLOCK_SIZE_N
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if n_start < N and k_start < K:
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_n = offs_n < N
|
||||
mask_k = offs_k < K
|
||||
|
||||
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
|
||||
group_start = group_idx * GROUP_SIZE_M
|
||||
group_expert = tl.load(indices_ptr + group_start)
|
||||
|
||||
if group_expert == expert_id:
|
||||
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
|
||||
m_start = group_start + m_offset
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
|
||||
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
|
||||
|
||||
go_ptrs = (
|
||||
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
)
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_in = mask_m[:, None] & mask_k[None, :]
|
||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
|
||||
|
||||
go_t = tl.trans(go)
|
||||
grad_weights += tl.dot(go_t, inp)
|
||||
|
||||
grad_w_ptrs = (
|
||||
grad_weights_ptr
|
||||
+ expert_id * N * K
|
||||
+ offs_n[:, None] * K
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
mask_gw = mask_n[:, None] & mask_k[None, :]
|
||||
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_weights(
|
||||
grad_output: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for expert weights."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
_, K = inputs.shape
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
grad_weights = torch.zeros(
|
||||
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
block_size_n = min(128, N)
|
||||
block_size_k = min(32, K)
|
||||
block_size_m = min(32, group_size_m)
|
||||
|
||||
n_tiles = triton.cdiv(N, block_size_n)
|
||||
k_tiles = triton.cdiv(K, block_size_k)
|
||||
grid = (num_experts * n_tiles * k_tiles,)
|
||||
|
||||
_kernel_cg_backward_dw[grid](
|
||||
grad_output,
|
||||
inputs,
|
||||
grad_weights,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
)
|
||||
|
||||
return grad_weights
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_inputs(
|
||||
grad_output: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for inputs."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
num_experts, _, K = expert_weights.shape
|
||||
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
grad_inputs = torch.zeros(
|
||||
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
|
||||
)
|
||||
|
||||
_kernel_cg_backward_dx[grid](
|
||||
grad_output,
|
||||
expert_weights,
|
||||
grad_inputs,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
)
|
||||
|
||||
return grad_inputs
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
|
||||
"""Autograd function with full backward support."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
|
||||
ctx.save_for_backward(inputs, expert_weights, expert_indices)
|
||||
ctx.group_size_m = group_size_m
|
||||
|
||||
return cg_grouped_gemm_forward(
|
||||
inputs=inputs,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
inputs, expert_weights, expert_indices = ctx.saved_tensors
|
||||
group_size_m = ctx.group_size_m
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
num_experts = expert_weights.shape[0]
|
||||
|
||||
grad_inputs = cg_grouped_gemm_backward_inputs(
|
||||
grad_output=grad_output,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_weights = cg_grouped_gemm_backward_weights(
|
||||
grad_output=grad_output,
|
||||
inputs=inputs,
|
||||
expert_indices=expert_indices,
|
||||
num_experts=num_experts,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_indices = None
|
||||
grad_group_size_m = None
|
||||
|
||||
return grad_inputs, grad_weights, grad_indices, grad_group_size_m
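
A short usage sketch for the autograd wrapper above (editorial example, assuming a CUDA device and inputs already permuted so that each `group_size_m`-sized block of rows belongs to a single expert):

```python
import torch

from axolotl.kernels.moe.tt_cg_gemm import ContiguousGroupedGEMM

device = torch.device("cuda")
group_size_m, num_experts, K, N = 128, 4, 64, 64

# Four contiguous groups of 128 rows, one expert per group.
inputs = torch.randn(
    4 * group_size_m, K, device=device, dtype=torch.bfloat16, requires_grad=True
)
weights = torch.randn(
    num_experts, N, K, device=device, dtype=torch.bfloat16, requires_grad=True
)
indices = torch.repeat_interleave(
    torch.arange(num_experts, device=device, dtype=torch.int32), group_size_m
)

out = ContiguousGroupedGEMM.apply(inputs, weights, indices, group_size_m)
out.sum().backward()  # exercises both the dX and dW kernels above
```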
311 src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py (new file)
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Vendored forward Triton contiguous grouped GEMM kernels."""

import torch
import triton
import triton.language as tl

from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune

GROUP_SIZE_M = 128


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * super_group_m
    group_size_m = min(num_pid_m - first_pid_m, super_group_m)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
    SUPER_GROUP_M: tl.constexpr = 32,
):
    """
    Contiguous Grouped GEMM kernel forward (persistent variant).
    """

    c_type = c_ptr.dtype.element_ty

    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = SUPER_GROUP_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        tile_m_idx, tile_n_idx = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
        )

        m_start = tile_m_idx * BLOCK_SIZE_M
        n_start = tile_n_idx * BLOCK_SIZE_N

        if m_start < M_TOTAL:
            offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
            offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

                mask_m = offs_m < M_TOTAL
                mask_n = offs_n < N
                mask_k = offs_k < K

                mask_a = mask_m[:, None] & mask_k[None, :]
                mask_b = mask_n[:, None] & mask_k[None, :]

                group_idx = m_start // GROUP_SIZE_M
                expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

                a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
                a = tl.load(a_ptrs, mask=mask_a, other=0.0)

                b_ptrs = (
                    b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
                )
                b = tl.load(b_ptrs, mask=mask_b, other=0.0)

                accumulator += tl.dot(a, b.T)

            tile_id_c += NUM_SMS
            tile_m_idx, tile_n_idx = _compute_pid(
                tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
            )

            offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N
            mask_c = mask_m[:, None] & mask_n[None, :]

            c = accumulator.to(tl.float32)
            c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
            tl.store(c_ptrs, c.to(c_type), mask=mask_c)


@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """
    Contiguous Grouped GEMM kernel forward for aligned inputs.
    """

    pid = tl.program_id(0)

    c_type = c_ptr.dtype.element_ty

    num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)

    tile_m = pid // num_n_tiles
    tile_n = pid % num_n_tiles

    m_start = tile_m * BLOCK_SIZE_M
    n_start = tile_n * BLOCK_SIZE_N

    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start

        mask_m = offs_m < M_TOTAL
        mask_n = offs_n < N

        group_idx = m_start // GROUP_SIZE_M
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

        acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)

        for k in range(0, K, BLOCK_SIZE_K):
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k
            mask_k = offs_k < K

            mask_a = mask_m[:, None] & mask_k[None, :]
            mask_b = mask_n[:, None] & mask_k[None, :]

            a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)

            b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

            acc += tl.dot(a, b.T)

        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask_c = mask_m[:, None] & mask_n[None, :]
        tl.store(c_ptrs, acc.to(c_type), mask=mask_c)


def cg_grouped_gemm_forward(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE."""

    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    grid = (NUM_SMS, 1, 1)
    _kernel_cg_persistent_forward[grid](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=NUM_SMS,
    )

    return output


def cg_grouped_gemm_forward_dynamic(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE with autotuned launch."""

    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    grid = lambda meta: (
        triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
        * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

    _kernel_cg_forward_aligned[grid](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )

    return output


class ContiguousGroupedGEMM(torch.autograd.Function):
    """Autograd function for contiguous grouped GEMM forward pass only."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        return cg_grouped_gemm_forward(
            inputs=inputs,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=group_size_m,
        )

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover - not implemented
        raise NotImplementedError("Backward pass not implemented")


def cg_grouped_gemm(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Convenience wrapper for the forward-only autograd function."""

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    return ContiguousGroupedGEMM.apply(
        inputs, expert_weights, expert_indices, group_size_m
    )
|
||||
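A minimal usage sketch for the wrappers above (not part of the diff; the device, dtypes, and sizes are illustrative assumptions, and rows must already be sorted so each contiguous block of group_size_m rows belongs to a single expert):

import torch

from axolotl.kernels.moe import ContiguousGroupedGEMM  # export confirmed later in this diff

num_experts, group_size_m, K, N = 2, 128, 64, 32
M_total = num_experts * group_size_m
inputs = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
expert_weights = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)
# one expert id per row, constant within each block of group_size_m rows
expert_indices = torch.arange(
    num_experts, device="cuda", dtype=torch.int32
).repeat_interleave(group_size_m)

out = ContiguousGroupedGEMM.apply(inputs, expert_weights, expert_indices, group_size_m)
assert out.shape == (M_total, N)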
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def pytorch_reference(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Simple PyTorch implementation for verification."""
|
||||
|
||||
M_total, K = inputs.shape
|
||||
num_experts, N, _ = expert_weights.shape
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
|
||||
|
||||
for i in range(0, M_total, group_size_m):
|
||||
end_idx = min(i + group_size_m, M_total)
|
||||
expert_idx = expert_indices[i].item()
|
||||
expert_weight = expert_weights[expert_idx]
|
||||
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
|
||||
|
||||
return output
|
||||
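The reference makes it easy to sanity-check the Triton kernel; a hedged verification sketch (module paths and tolerances are assumptions, and the Triton forward returns bfloat16, so compare in float32):

import torch

from axolotl.kernels.moe.tt_cg_gemm.cg_reference import pytorch_reference
# cg_grouped_gemm_forward lives in the sibling kernel module added above; exact path assumed

inputs = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)
expert_weights = torch.randn(2, 32, 64, device="cuda", dtype=torch.bfloat16)
expert_indices = torch.arange(2, device="cuda", dtype=torch.int32).repeat_interleave(128)

expected = pytorch_reference(inputs, expert_weights, expert_indices, group_size_m=128)
actual = cg_grouped_gemm_forward(inputs, expert_weights, expert_indices, group_size_m=128)
torch.testing.assert_close(actual.float(), expected.float(), rtol=1e-2, atol=1e-2)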
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
"""Helper utilities for CUDA specific Triton features."""
|
||||
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels."""
|
||||
|
||||
class KernelParamWrapper:
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
if "nv_tma_desc_type" not in dir(tl):
|
||||
raise RuntimeError(
|
||||
"TMA grid constant descriptors not supported in your Triton version"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(
|
||||
self, name: str
|
||||
) -> "TmaDescriptorHelper.KernelParamWrapper":
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
HOPPER_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
STANDARD_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||


def early_config_prune(configs, args, **kwargs):
    """Drop configurations whose BLOCK_SIZE_K exceeds the problem's K dimension."""
    k = kwargs.get("K", 0)
    valid_configs = [
        config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
    ]
    if not valid_configs and configs:
        # Fall back to the smallest-BLOCK_SIZE_K config so the autotuner always
        # has at least one candidate.
        return [
            min(
                configs,
                key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
            )
        ]

    return valid_configs
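These config lists and the prune hook are meant to be handed to triton.autotune; a hedged wiring sketch (the kernel name, key names, and body are assumptions, not part of this diff):

import triton
import triton.language as tl

@triton.autotune(
    configs=HOPPER_CONFIGS if CudaUtils.verify_tma() else STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_example(
    a_ptr, b_ptr, c_ptr,
    M_TOTAL: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # tiled GEMM body elided; BLOCK_SIZE_* are chosen by the autotuner
    pass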
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M

__all__ = [
    "grouped_gemm_forward",
    "ALIGN_SIZE_M",
]
761
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
@@ -0,0 +1,761 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
import logging
from typing import Tuple

import torch
import triton
import triton.language as tl

from .tma_autotuning import (
    _NV_CONFIGS,
    CudaUtils,
    early_config_prune,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


_allocator_registered = False


def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
    return torch.empty(size, device="cuda", dtype=torch.int8)


def _ensure_triton_allocator() -> None:
    global _allocator_registered
    if not _allocator_registered:
        triton.set_allocator(_torch_allocator)
        _allocator_registered = True


# ============== Start Triton Kernels ===============


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_ptr,
    b_ptr,
    c_ptr,
    m_sizes,
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """Flat index style forward kernel for Hopper using tensor descriptors."""
    tbidx = tl.program_id(0)

    c_dtype = c_ptr.dtype.element_ty
    n_size = N // G

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M_TOTAL, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )

    M_end = tl.full([], 0, dtype=tl.int32)
    processed_tiles = 0

    for g in range(G):
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while (
                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
                rows_remaining = tl.maximum(rows_remaining, 0)
                row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining

                cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
                col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (g * n_size + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    k_remaining = K - k_offset
                    k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining

                    a = a_desc.load([m_offset, k_offset])
                    a_mask = row_mask[:, None] & k_mask[None, :]
                    a = tl.where(a_mask, a, tl.zeros_like(a))

                    b = b_desc.load([global_n_offset, k_offset])
                    b_mask = col_mask[:, None] & k_mask[None, :]
                    b = tl.where(b_mask, b, tl.zeros_like(b))

                    accumulator += tl.dot(a, b.T)

                local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
                row_store_mask = local_row_offsets < m_size
                global_row = (M_start + local_row_offsets).to(tl.int32)

                local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
                    0, BLOCK_SIZE_N
                )
                col_store_mask = local_col_offsets < n_size

                store_mask = row_store_mask[:, None] & col_store_mask[None, :]

                if USE_EPILOGUE_SUBTILING:
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)

                    col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
                    col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
                    ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
                    tl.store(
                        ptr0,
                        acc0.to(c_dtype),
                        mask=row_store_mask[:, None] & col_mask0[None, :],
                    )

                    col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
                    col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
                    ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
                    tl.store(
                        ptr1,
                        acc1.to(c_dtype),
                        mask=row_store_mask[:, None] & col_mask1[None, :],
                    )
                else:
                    ptr = (
                        c_ptr
                        + global_row[:, None] * n_size
                        + local_col_offsets[None, :]
                    )
                    tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)

                tbidx += NUM_SMS

            processed_tiles += group_num_tiles


"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""


# ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_ptr,
    w_ptr,
    grad_input_ptr,
    m_sizes,
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """Compute grad_input = grad_output @ w using tensor descriptors."""
    tbidx = tl.program_id(0)

    c_dtype = grad_input_ptr.dtype.element_ty

    grad_output_desc = tl.make_tensor_descriptor(
        grad_output_ptr,
        shape=[M_TOTAL, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    w_desc = tl.make_tensor_descriptor(
        w_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )

    M_end = tl.full([], 0, dtype=tl.int32)
    processed_tiles = 0

    for g in range(G):
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            while (
                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
                rows_remaining = tl.maximum(rows_remaining, 0)
                row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining

                k_offset = tile_k_index * BLOCK_SIZE_K
                k_remaining_total = K - k_offset
                k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                for n_offset in range(0, N, BLOCK_SIZE_N):
                    n_remaining = N - n_offset
                    n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining

                    grad_y = grad_output_desc.load([m_offset, n_offset])
                    grad_y_mask = row_mask[:, None] & n_mask[None, :]
                    grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))

                    w_tile = w_desc.load([n_offset, k_offset])
                    w_mask = n_mask[:, None] & k_mask[None, :]
                    w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))

                    accumulator += tl.dot(grad_y, w_tile)

                local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
                    0, BLOCK_SIZE_M
                )
                row_store_mask = local_row_offsets < m_size
                global_row = (M_start + local_row_offsets).to(tl.int32)

                col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
                col_store_mask = col_offsets < K

                store_mask = row_store_mask[:, None] & col_store_mask[None, :]

                ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
                tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)

                tbidx += NUM_SMS

            processed_tiles += group_num_tiles


@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_ptr,
    grad_output_ptr,
    grad_weight_ptr,
    m_sizes,
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
) -> None:
    """Compute grad_weight = grad_output.T @ x using tensor descriptors."""
    tbidx = tl.program_id(0)

    c_dtype = grad_weight_ptr.dtype.element_ty

    x_desc = tl.make_tensor_descriptor(
        x_ptr,
        shape=[M_TOTAL, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    grad_output_desc = tl.make_tensor_descriptor(
        grad_output_ptr,
        shape=[M_TOTAL, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_tiles = num_n_tiles * num_k_tiles

    for tile_idx in range(tbidx, total_tiles, NUM_SMS):
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        n_offset = tile_n_idx * BLOCK_SIZE_N
        n_remaining = N - n_offset
        n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining

        k_offset = tile_k_idx * BLOCK_SIZE_K
        k_remaining = K - k_offset
        k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining

        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        M_end = tl.full([], 0, dtype=tl.int32)
        for g in range(G):
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            if m_size > 0:
                for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
                    rows_remaining = m_size - m_offset_local
                    rows_remaining = tl.maximum(rows_remaining, 0)
                    row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining

                    m_offset = (M_start + m_offset_local).to(tl.int32)

                    x_block = x_desc.load([m_offset, k_offset])
                    x_mask = row_mask[:, None] & k_mask[None, :]
                    x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))

                    grad_block = grad_output_desc.load([m_offset, n_offset])
                    grad_mask = row_mask[:, None] & n_mask[None, :]
                    grad_block = tl.where(
                        grad_mask, grad_block, tl.zeros_like(grad_block)
                    )

                    contribution = tl.dot(
                        grad_block.to(tl.float32).T,
                        x_block.to(tl.float32),
                    )
                    accumulator += contribution

        row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
        row_store_mask = row_offsets < N

        col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
        col_store_mask = col_offsets < K

        store_mask = row_store_mask[:, None] & col_store_mask[None, :]

        ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
        tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)

# ======== End Triton kernels ========

# ======== Triton wrapper functions ========

# ----- main forward pass wrapper -----


def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """Grouped GEMM forward using Hopper TMA kernels."""
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
    if using_fp8:
        raise NotImplementedError(
            "FP8 path not implemented with the new Triton API yet"
        )

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, K = x.shape
    N = w.shape[0]
    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_EPILOGUE_SUBTILING = False

    def grid(_meta):
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)
    _kernel_mg_forward_hopper[grid](
        x,
        w,
        y,
        m_sizes,
        M_total,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )
    return y


# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)
    """
    logging.info("Starting unified grouped_gemm_backward")

    # do this once, seems expensive
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support
    if use_tma and not CudaUtils.verify_tma():
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using flat linear implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")

        # Use TMA-optimized implementation
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    # Compute grad_w using flat linear style implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w


# ----- dx backward pass wrapper -----


def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """Compute grad_x using the Hopper grouped GEMM kernel."""
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    grad_output = grad_output.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    M_total, N = grad_output.shape
    N_w, K = w.shape
    if N != N_w:
        raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")

    if m_sizes.sum().item() != M_total:
        raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")

    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms

    def grid(_meta):
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)
    _kernel_mg_dx_tma[grid](
        grad_output,
        w,
        grad_x,
        m_sizes,
        M_total,
        m_sizes.shape[0],
        M_BUCKET,
        N,
        K,
        NUM_SMS,
    )
    return grad_x


# ======== dw wrapper function ==========


def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """Compute grad_w using the Hopper grouped GEMM kernel."""
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")

    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    M_total, K = x.shape
    M_grad, N = grad_output.shape
    if M_total != M_grad:
        raise ValueError("x and grad_output must have matching batch dimension")
    if m_sizes.sum().item() != M_total:
        raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")

    grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    def grid(_meta):
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)
    _kernel_mg_dw_tma[grid](
        x,
        grad_output,
        grad_w,
        m_sizes,
        M_total,
        m_sizes.shape[0],
        M_BUCKET,
        N,
        K,
        NUM_SMS,
    )
    return grad_w


# ======== End Backwards Wrapper Functions =============

# ======== PyTorch wrapper functions ========


class GroupedGemmMg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.
    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Whether to use FP8 quantization

        Returns:
            Output tensor, shape [M_total, N // G] (w stacks the G per-group
            weight blocks along its first dimension)
        """

        # Use regular forward without quantization
        output = grouped_gemm_forward(
            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
        )

        # Save inputs and parameters for backward pass (once is enough)
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: Gradient with respect to m_sizes (not differentiable)
            - None: Gradient with respect to use_tma (not differentiable)
            - None: Gradient with respect to tma_size (not differentiable)
            - None: Gradient with respect to using_fp8 (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # Return gradients for all inputs (None for non-differentiable parameters)
        return grad_x, grad_w, None, None, None, None


def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.
    Supports both standard precision and FP8 quantized operations.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Whether to use FP8 quantization

    Returns:
        Output tensor, shape [M_total, N // G]
    """
    return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
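A hedged usage sketch for mg_grouped_gemm (the ragged group sizes, dims, and Hopper/TMA requirement are illustrative assumptions; w stacks each group's [N // G, K] weight block along dim 0):

import torch

G, K, n_per_group = 4, 64, 32
m_sizes = torch.tensor([128, 0, 384, 128], device="cuda", dtype=torch.int32)
M_total = int(m_sizes.sum())  # 640; groups may be ragged, including empty ones
x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G * n_per_group, K, device="cuda", dtype=torch.bfloat16)

y = mg_grouped_gemm(x, w, m_sizes)  # requires a TMA-capable (Hopper+) GPU
assert y.shape == (M_total, n_per_group)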
232
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe

import os
import sys
from typing import Dict

import torch
import triton
from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


# ===== Supporting utils, CUDA and TMA =====


class CudaUtils:
    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device."""
        return (
            CudaUtils.is_cuda()
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 9
        )

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device."""
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count


class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Args:
        tma_size: Size of the TMA descriptor in bytes
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )

        self.tma_size = tma_size
        # Mirror the cg helper above: bind the 1D and 2D fillers separately
        # rather than pointing both at the same driver utility.
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])


# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128

_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [
        ALIGN_SIZE_M,
    ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        G, M, N, K = (
            named_args["G"],
            named_args["M_BUCKET"],
            named_args["N"],
            named_args["K"],
        )

        # 1. make sure we have enough smem
        max_shared_memory = driver.active.utils.get_device_properties(device)[
            "max_shared_mem"
        ]

        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        num_sm = driver.active.utils.get_device_properties(device)[
            "multiprocessor_count"
        ]
        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs


# ======== End Autotuning utilities ========
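The shared-memory bound in rule 1 is plain arithmetic, so an oversized config is easy to check by hand; a small worked example (bf16, and the ~228 KB-per-SM H100 figure is an assumption here):

# required = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize = 128, 256, 128, 4, 2  # dtsize=2 for bf16
required = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
print(required)  # 393216 bytes (384 KiB) -> pruned when max_shared_mem is ~228 KB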
@@ -515,6 +515,9 @@ class ModelLoader:
        if self.cfg.model_quantization_config_kwargs:
            mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
            self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
        else:
            self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
            self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit

        if self.cfg.gptq:
            if not hasattr(self.model_config, "quantization_config"):
@@ -549,7 +552,9 @@ class ModelLoader:
            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                **self.model_config.quantization_config
            )
        elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
        elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
            "load_in_4bit", False
        ):
            bnb_config = {
                "load_in_4bit": True,
                "llm_int8_threshold": 6.0,
@@ -575,7 +580,9 @@ class ModelLoader:
            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                **bnb_config,
            )
        elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
        elif self.cfg.adapter == "lora" and self.model_kwargs.get(
            "load_in_8bit", False
        ):
            bnb_config = {
                "load_in_8bit": True,
            }
@@ -589,6 +596,11 @@ class ModelLoader:
                **bnb_config,
            )

        # no longer needed per https://github.com/huggingface/transformers/pull/26610
        if "quantization_config" in self.model_kwargs or self.cfg.gptq:
            self.model_kwargs.pop("load_in_8bit", None)
            self.model_kwargs.pop("load_in_4bit", None)

    def _set_attention_config(self):
        """Sample packing uses custom FA2 patch"""
        if self.cfg.attn_implementation:

@@ -190,6 +190,15 @@ class PatchManager:

            apply_mistral_tokenizer_image_patch()

        if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
            from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe

            patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
        elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
            LOG.info(
                "Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
            )

    def _apply_fp8_patches(self):
        """Apply patches for FP8 support."""
        if self.cfg.fp8:

@@ -4,7 +4,6 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio

import copy
import functools
import os
import sys

import torch
@@ -278,11 +277,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:

    mesh = getattr(accelerator.state, "device_mesh", None)

    # Disable memory pinning if requested
    offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
    if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
        fsdp2_plugin.cpu_offload.pin_memory = False

    fsdp2_kwargs = {
        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
        "offload_policy": fsdp2_plugin.cpu_offload,
@@ -347,6 +341,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    )

    if fsdp2_plugin.cpu_ram_efficient_loading:
        offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
        fsdp2_load_full_state_dict(
            accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
        )

401
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
@@ -0,0 +1,401 @@
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
from axolotl.kernels.moe.indices import generate_permute_indices
|
||||
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
_GROUP_SIZE_M = 128
|
||||
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
    if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
        return False
    major, _ = torch.cuda.get_device_capability(hidden_states.device)
    if major < 9:
        LOG.debug(
            "Skipping Triton MoE kernels: requires compute capability >= 9.0, found major %s",
            major,
        )
        return False
    return True


def _ensure_combined_expert_weights(
    module, dtype: torch.dtype, device: torch.device, backend: str
) -> None:
    if not hasattr(module, "_axolotl_original_specs"):
        module._axolotl_original_specs = {}
    if not hasattr(module, "_axolotl_mg_shapes"):
        module._axolotl_mg_shapes = {}

    prev_backend = getattr(module, "_axolotl_combined_backend", None)
    if getattr(module, "_axolotl_combined_weights", False):
        if prev_backend != backend:
            _restore_expert_weights(module)
        else:
            for name in _COMBINED_SUBMODULES:
                param_name = f"{name}_weight"
                param = module.get_parameter(param_name)
                if param.device != device or param.dtype != dtype:
                    module._parameters[param_name] = torch.nn.Parameter(
                        param.to(device=device, dtype=dtype).contiguous()
                    )
            module._axolotl_combined_dtype = dtype
            module._axolotl_combined_device = device
            module._axolotl_combined_backend = backend
            return

    module._axolotl_mg_shapes = {}
    for name in _COMBINED_SUBMODULES:
        weights = []
        orig_device = None
        orig_dtype = None
        orig_shape = None
        for expert in module.experts:
            lin = expert.get_submodule(name)
            weight_param = lin._parameters.get("weight")
            if weight_param is None:
                raise RuntimeError("Expected expert linear layers to have weights")
            if orig_device is None:
                orig_device = weight_param.device
                orig_dtype = weight_param.dtype
                orig_shape = tuple(weight_param.shape)
            weights.append(weight_param.detach().to(device=device, dtype=dtype))
            if "weight" in lin._parameters:
                del lin._parameters["weight"]
            if "bias" in lin._parameters:
                del lin._parameters["bias"]
        if backend == "cg":
            combined_weight = torch.stack(weights, dim=0).contiguous()
        else:
            combined_weight = torch.cat(weights, dim=0).contiguous()
        module._axolotl_mg_shapes[name] = orig_shape
        module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
        module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)

    module._axolotl_combined_weights = True
    module._axolotl_combined_dtype = dtype
    module._axolotl_combined_device = device
    module._axolotl_combined_backend = backend


def _restore_expert_weights(module) -> None:
    if not getattr(module, "_axolotl_combined_weights", False):
        return

    for name in _COMBINED_SUBMODULES:
        param_name = f"{name}_weight"
        combined = module._parameters.pop(param_name)
        orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
            name, (combined.device, combined.dtype, None)
        )
        rows_per = orig_shape[0] if orig_shape else None
        for idx, expert in enumerate(module.experts):
            lin = expert.get_submodule(name)
            if combined.dim() == 3:
                slice_tensor = combined[idx]
            elif rows_per is not None:
                start = idx * rows_per
                end = start + rows_per
                slice_tensor = combined[start:end]
            else:
                raise RuntimeError(
                    "Unable to recover expert weight shape during restore"
                )
            lin._parameters["weight"] = torch.nn.Parameter(
                slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
            )

    module._axolotl_combined_weights = False
    module._axolotl_combined_dtype = None
    module._axolotl_combined_device = None
    module._axolotl_combined_backend = None
    module._axolotl_original_specs = {}
    module._axolotl_mg_shapes = {}


def _run_cg_grouped_gemm(
    module,
    grouped_hidden: torch.Tensor,
    m_sizes: torch.Tensor,
    num_experts: int,
    group_size_m: int,
    hidden_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")

    expert_index_tensor = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        m_sizes.to(torch.int64),
    )

    gate_weights = module.get_parameter("gate_proj_weight")
    if gate_weights.dim() == 2:
        out_dim = gate_weights.shape[0] // num_experts
        gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])

    up_weights = module.get_parameter("up_proj_weight")
    if up_weights.dim() == 2:
        out_dim = up_weights.shape[0] // num_experts
        up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])

    down_weights = module.get_parameter("down_proj_weight")
    if down_weights.dim() == 2:
        out_dim = down_weights.shape[0] // num_experts
        down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])

    gate_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        gate_weights,
        expert_index_tensor,
        group_size_m,
    )
    up_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        up_weights,
        expert_index_tensor,
        group_size_m,
    )
    return (
        gate_out.to(hidden_dtype),
        up_out.to(hidden_dtype),
        down_weights,
        expert_index_tensor,
    )


def _moe_triton_forward(
    module,
    hidden_states: torch.Tensor,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    group_size_m: int,
    backend: str,
    fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    if not _is_triton_eligible(hidden_states):
        return fallback(hidden_states, topk_indices, topk_weights)

    device = hidden_states.device
    hidden_dtype = hidden_states.dtype
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_indices.size(-1)

    expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
    expert_assignments = topk_indices.reshape(-1)
    if expanded_hidden.numel() == 0:
        # torch.Tensor has no `new_zeros_like`; zeros_like keeps shape/dtype/device.
        return torch.zeros_like(hidden_states)

    sort_perm = torch.argsort(expert_assignments)
    sorted_hidden = expanded_hidden.index_select(0, sort_perm)
    sorted_assignments = expert_assignments.index_select(0, sort_perm)

    num_experts = len(module.experts)
    counts = torch.bincount(sorted_assignments, minlength=num_experts)
    total_actual = int(counts.sum().item())
    if total_actual == 0:
        return torch.zeros_like(hidden_states)

if not getattr(module, "_axolotl_triton_logged", False):
|
||||
min_tokens = int(counts.min().item())
|
||||
max_tokens = int(counts.max().item())
|
||||
LOG.info(
|
||||
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
|
||||
min_tokens,
|
||||
max_tokens,
|
||||
total_actual / max(1, num_experts),
|
||||
group_size_m,
|
||||
)
|
||||
module._axolotl_triton_logged = True
|
||||
|
||||
counts_int = counts.to(torch.int32)
|
||||
aligned_counts = (
|
||||
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
|
||||
) * group_size_m
    max_len = int(aligned_counts.sum().item())

    permuted_indices, m_sizes, _ = generate_permute_indices(
        counts_int.to(device),
        experts_per_rank=num_experts,
        num_ranks=1,
        max_len=max_len,
        alignment=group_size_m,
        use_cpu=not hidden_states.is_cuda,
    )

    permuted_indices = permuted_indices.to(device)
    m_sizes = m_sizes.to(device)

    permuted_indices_long = permuted_indices.to(torch.int64)
    valid_mask = permuted_indices_long >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    source_indices = permuted_indices_long[valid_mask]
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

    grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
    if valid_positions.numel() > 0:
        grouped_hidden.index_copy_(
            0,
            valid_positions,
            sorted_hidden.index_select(0, source_indices),
        )
    if valid_positions.numel() < max_len:
        grouped_hidden.index_fill_(0, padded_positions, 0)

    m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)

    if backend == "mg":
        _ensure_combined_expert_weights(module, hidden_dtype, device, backend)
        gate_out = mg_grouped_gemm(
            grouped_hidden,
            module.get_parameter("gate_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
        up_out = mg_grouped_gemm(
            grouped_hidden,
            module.get_parameter("up_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
    else:
        gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
            module,
            grouped_hidden,
            m_sizes,
            num_experts,
            group_size_m,
            hidden_dtype,
            device,
        )

    act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
    if valid_positions.numel() > 0:
        gate_valid = gate_out.index_select(0, valid_positions)
        up_valid = up_out.index_select(0, valid_positions)
        hidden_concat = act_fn(gate_valid) * up_valid
    else:
        hidden_concat = torch.empty(
            (0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
        )

    intermediate_dim = hidden_concat.shape[-1]
    hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
    if valid_positions.numel() > 0:
        hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
    if valid_positions.numel() < max_len:
        hidden_grouped.index_fill_(0, padded_positions, 0)

    if backend == "mg":
        down_out = mg_grouped_gemm(
            hidden_grouped,
            module.get_parameter("down_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
    else:
        down_out = ContiguousGroupedGEMM.apply(
            hidden_grouped,
            down_weights,
            expert_index_tensor,
            group_size_m,
        ).to(hidden_dtype)

    if valid_positions.numel() > 0:
        down_valid = down_out.index_select(0, valid_positions)
    else:
        down_valid = torch.empty(
            (0, down_out.shape[-1]), device=device, dtype=hidden_dtype
        )

    sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
    if down_valid.numel() > 0:
        sorted_outputs.index_copy_(0, source_indices, down_valid)

    expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
    expanded_output.index_copy_(0, sort_perm, sorted_outputs)
    expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)

    weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
    return weighted.sum(dim=1)


def patch_deepseek_v3_moe(
    group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
    """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""

    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

    if backend not in {"cg", "mg"}:
        raise ValueError(f"Unsupported MoE kernel backend: {backend}")

    # Record the unpatched implementation so callers can access a true baseline even
    # after the Triton patch has been applied (e.g. repeated microbenchmarks).
    if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
        DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe

    if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
        return

    original_moe = DeepseekV3MoE._axolotl_triton_original_moe
    DeepseekV3MoE._axolotl_triton_backend = backend
    DeepseekV3MoE._axolotl_group_size_m = group_size_m

    def patched_moe(self, hidden_states, topk_indices, topk_weights):
        backend_sel = getattr(self, "_axolotl_triton_backend", backend)
        group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
        if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
            LOG.debug(
                "Adjusting group_size_m=%s to %s for CG backend",
                group_size_sel,
                _GROUP_SIZE_M,
            )
            group_size_sel = _GROUP_SIZE_M
        try:
            return _moe_triton_forward(
                self,
                hidden_states,
                topk_indices,
                topk_weights,
                group_size_sel,
                backend_sel,
                original_moe,
            )
        except Exception as err:  # surface Triton failures explicitly
            _restore_expert_weights(self)
            LOG.error("DeepseekV3MoE Triton path failed: %s", err)
            raise

    DeepseekV3MoE.moe = patched_moe
    DeepseekV3MoE._axolotl_triton_patch = True
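
# --- Usage sketch (not part of the diff). The import path and model id below are
# assumptions for illustration only:
#
#   from axolotl.monkeypatch.moe_kernels import patch_deepseek_v3_moe  # assumed path
#
#   patch_deepseek_v3_moe(backend="mg")  # "cg" selects the contiguous kernel
#   model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")
#   # DeepseekV3MoE.moe now routes through the Triton grouped-GEMM path and falls
#   # back to the recorded original implementation when a call is ineligible.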

@@ -134,11 +134,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:

        return Qwen2Attention

    if model_type == "qwen3_vl":
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention

        return Qwen3VLTextAttention

    if model_type == "mllama":
        from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention

@@ -45,8 +45,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "gpt_oss",
    "arcee",
    "seed_oss",
    "lfm2",
    "lfm2_moe",
]

@@ -13,7 +13,9 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
PATCHED_GUARD = (
    'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
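
# These module-level strings support source-level patching; a sketch of the likely
# mechanism (assumed, not shown in this diff):
#
#   import inspect
#   src = inspect.getsource(original_fn)  # original_fn: the guarded function (assumed)
#   patched_src = src.replace(GUARD_PATTERN, PATCHED_GUARD)  # then recompile/exec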


def patch_prepare_context_parallel_inputs() -> None:

@@ -6,10 +6,8 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image
from transformers.models.smolvlm import SmolVLMProcessor
from transformers.models.voxtral import VoxtralProcessor

from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger

@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
        ]

        return {
            "chosen_input_ids": chosen_tokenized["input_ids"],
            "input_ids_chosen": chosen_tokenized["input_ids"],
            "attention_mask_chosen": chosen_tokenized["attention_mask"],
            "labels_chosen": 1.0,
            "rejected_input_ids": rejected_tokenized["input_ids"],
            "input_ids_rejected": rejected_tokenized["input_ids"],
            "attention_mask_rejected": rejected_tokenized["attention_mask"],
            "labels_rejected": 0.0,
        }

@@ -120,123 +120,3 @@ def default(cfg, dataset_idx=0, **kwargs):
        return result

    return transform_fn, {"remove_columns": [field_messages]}


def argilla_chat(cfg, dataset_idx=0, **kwargs):
    """
    DPO chat template strategy for argilla-style datasets.

    For argilla-style datasets where chosen/rejected contain full conversations
    instead of single response messages. Extracts the conversation history from
    the chosen field and formats both chosen/rejected responses using the
    configured chat template.

    Args:
        cfg: Configuration object containing chat_template and dataset settings
        dataset_idx: Index of the dataset in the config (default: 0)
        **kwargs: Additional keyword arguments (unused)

    Returns:
        tuple: (transform_fn, dataset_kwargs) where:
            - transform_fn: Function to transform dataset samples
            - dataset_kwargs: Dict with 'remove_columns' specifying columns to drop

    Dataset format:
        {
            "chosen": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ],
            "rejected": [
                {"role": "user", "content": "..."},
                {"role": "assistant", "content": "..."}
            ]
        }
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    role_map = {}
    for target, sources in role_map_inv.items():
        for source in sources:
            role_map[source] = target

    def transform_fn(sample, tokenizer=None):
        chat_template_string = get_chat_template(
            user_choice=chat_template_choice,
            jinja_template=chat_template_jinja,
            tokenizer=tokenizer,
        )

        chosen_raw = sample[field_chosen]
        rejected_raw = sample[field_rejected]

        # Extract messages (all but last) and responses (last message)
        chosen_messages = [
            {
                "role": role_map[m[message_property_mappings["role"]]],
                "content": m[message_property_mappings["content"]],
            }
            for m in chosen_raw[:-1]
        ]
        chosen_response = {
            "role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
            "content": chosen_raw[-1][message_property_mappings["content"]],
        }

        rejected_response = {
            "role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
            "content": rejected_raw[-1][message_property_mappings["content"]],
        }

        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

        result = {}
        result["prompt"] = tokenizer.apply_chat_template(
            chosen_messages,
            add_generation_prompt=True,
            chat_template=chat_template_string,
            tokenize=False,
        )

        result["chosen"] = tokenizer.apply_chat_template(
            [dummy_user_message, chosen_response],
            add_generation_prompt=False,
            chat_template=chat_template_string,
            tokenize=False,
        )
        chosen_strip_index = result["chosen"].find(chosen_response["content"])
        result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()

        result["rejected"] = tokenizer.apply_chat_template(
            [dummy_user_message, rejected_response],
            add_generation_prompt=False,
            chat_template=chat_template_string,
            tokenize=False,
        )
        rejected_strip_index = result["rejected"].find(rejected_response["content"])
        result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()

        return result

    return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
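
# Illustrative transform (hypothetical sample, simplified template output):
#   sample = {"chosen":   [{"role": "user", "content": "Hi"},
#                          {"role": "assistant", "content": "Hello!"}],
#             "rejected": [{"role": "user", "content": "Hi"},
#                          {"role": "assistant", "content": "Yo."}]}
#   -> result["prompt"]:   the rendered user turn plus the generation prompt
#      result["chosen"]:   "Hello!"  (template wrapper stripped via find()/rstrip())
#      result["rejected"]: "Yo."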

@@ -40,6 +40,11 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer

try:
    from optimum.bettertransformer import BetterTransformer
except ImportError:
    BetterTransformer = None

if typing.TYPE_CHECKING:
    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

@@ -136,6 +141,8 @@ def setup_signal_handler(
    def terminate_handler(_, __, model_weakref):
        if model_weakref() is not None:
            _model = model_weakref()
            if cfg.flash_optimum and BetterTransformer:
                _model = BetterTransformer.reverse(_model)
            _model.save_pretrained(
                cfg.output_dir, safe_serialization=safe_serialization
            )
@@ -314,6 +321,9 @@ def save_trained_model(
        except FileNotFoundError:
            pass
    elif cfg.local_rank == 0:
        if cfg.flash_optimum and BetterTransformer:
            model = BetterTransformer.reverse(model)

        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
            trainer.model.save_pretrained(
                cfg.output_dir, safe_serialization=safe_serialization
@@ -525,17 +535,6 @@ def setup_model_and_trainer(
    plugin_manager = PluginManager.get_instance()
    plugin_manager.post_trainer_create(cfg, trainer)

    if cfg.use_ray:
        try:
            import ray.train.huggingface.transformers

            trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
        except ImportError:
            LOG.warning(
                "The Ray integration with Hugging Face Transformers is not available. "
                "To use Ray, install the 'ray[train]' package."
            )

    return (
        trainer,
        model,

@@ -17,13 +17,6 @@ def is_comet_available():
    return importlib.util.find_spec("comet_ml") is not None


def is_opentelemetry_available():
    return (
        importlib.util.find_spec("opentelemetry") is not None
        and importlib.util.find_spec("prometheus_client") is not None
    )


def get_pytorch_version() -> tuple[int, int, int]:
    """
    Get PyTorch version as a tuple of (major, minor, patch).

@@ -16,8 +16,8 @@ import pandas as pd
import torch
import torch.distributed as dist
import wandb
import yaml
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
    GenerationConfig,
@@ -28,6 +28,8 @@ from transformers import (
    TrainingArguments,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    IntervalStrategy,
    SaveStrategy,
)
from trl.models import unwrap_model_for_generation
@@ -54,6 +56,40 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__)


class SaveBetterTransformerModelCallback(TrainerCallback):
    """Callback to save the BetterTransformer wrapped model"""

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        # Save
        if (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and state.global_step % args.save_steps == 0
        ):
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)
            # FIXME - need to cleanup old checkpoints

            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
            control.should_save = False
        return control
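
# Sketch of wiring the callback up (assumed Trainer construction, for illustration):
#   trainer = Trainer(model=bt_model, args=training_args,
#                     callbacks=[SaveBetterTransformerModelCallback()])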


class LossWatchDogCallback(TrainerCallback):
    """Callback to track loss and stop training if loss is too high"""

@@ -760,37 +796,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
        except (FileNotFoundError, ConnectionError) as err:
            LOG.warning(f"Error while saving Axolotl config to WandB: {err}")

        try:
            with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f) or {}

            chat_tpl = cfg.get("chat_template_jinja")
            if chat_tpl:
                with NamedTemporaryFile(
                    mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
                ) as temp_ct_file:
                    if (
                        isinstance(chat_tpl, str)
                        and os.path.exists(chat_tpl)
                        and os.path.isfile(chat_tpl)
                    ):
                        copyfile(chat_tpl, temp_ct_file.name)
                    else:
                        temp_ct_file.write(str(chat_tpl))
                        temp_ct_file.flush()

                    artifact = wandb.Artifact(
                        f"chat-template-{wandb.run.id}", type="jinja-template"
                    )
                    artifact.add_file(temp_ct_file.name)
                    wandb.log_artifact(artifact)
                    wandb.save(temp_ct_file.name)
                    LOG.info(
                        "The chat_template_jinja has been saved to the WandB run under files."
                    )
        except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
            LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")

        if args.deepspeed:
            try:
                # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/ policy = 'now', so let OS delete it later.

@@ -1,238 +0,0 @@
"""OpenTelemetry metrics callback for Axolotl training"""

import threading
from typing import Dict, Optional

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

try:
    from opentelemetry import metrics
    from opentelemetry.exporter.prometheus import PrometheusMetricReader
    from opentelemetry.metrics import set_meter_provider
    from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
    from prometheus_client import start_http_server

    OPENTELEMETRY_AVAILABLE = True
except ImportError:
    LOG.warning("OpenTelemetry not available. pip install [opentelemetry]")
    OPENTELEMETRY_AVAILABLE = False


class OpenTelemetryMetricsCallback(TrainerCallback):
    """
    TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.

    This callback automatically tracks key training metrics including:
    - Training loss
    - Evaluation loss
    - Learning rate
    - Epoch progress
    - Global step count
    - Gradient norm

    Metrics are exposed via HTTP endpoint for Prometheus scraping.
    """

    def __init__(self, cfg):
        if not OPENTELEMETRY_AVAILABLE:
            LOG.warning("OpenTelemetry not available, metrics will not be collected")
            self.metrics_enabled = False
            return

        self.cfg = cfg
        self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
        self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
        self.metrics_enabled = True
        self.server_started = False
        self.metrics_lock = threading.Lock()

        try:
            # Create Prometheus metrics reader
            prometheus_reader = PrometheusMetricReader()

            # Create meter provider with Prometheus exporter
            provider = SDKMeterProvider(metric_readers=[prometheus_reader])
            set_meter_provider(provider)

            # Get meter for creating metrics
            self.meter = metrics.get_meter("axolotl.training")

            # Create metrics
            self._create_metrics()

        except Exception as e:
            LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
            self.metrics_enabled = False

    def _create_metrics(self):
        """Create all metrics that will be tracked"""
        self.train_loss_gauge = self.meter.create_gauge(
            name="axolotl_train_loss",
            description="Current training loss",
            unit="1",
        )

        self.eval_loss_gauge = self.meter.create_gauge(
            name="axolotl_eval_loss",
            description="Current evaluation loss",
            unit="1",
        )

        self.learning_rate_gauge = self.meter.create_gauge(
            name="axolotl_learning_rate",
            description="Current learning rate",
            unit="1",
        )

        self.epoch_gauge = self.meter.create_gauge(
            name="axolotl_epoch",
            description="Current training epoch",
            unit="1",
        )

        self.global_step_counter = self.meter.create_counter(
            name="axolotl_global_steps",
            description="Total training steps completed",
            unit="1",
        )

        self.grad_norm_gauge = self.meter.create_gauge(
            name="axolotl_gradient_norm",
            description="Gradient norm",
            unit="1",
        )

        self.memory_usage_gauge = self.meter.create_gauge(
            name="axolotl_memory_usage",
            description="Current memory usage in MB",
            unit="MB",
        )

    def _start_metrics_server(self):
        """Start the HTTP server for metrics exposure"""
        if self.server_started:
            return

        try:
            start_http_server(self.metrics_port, addr=self.metrics_host)
            self.server_started = True
            LOG.info(
                f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
            )

        except Exception as e:
            LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the beginning of training"""
        if not self.metrics_enabled:
            return

        self._start_metrics_server()
        LOG.info("OpenTelemetry metrics collection started")

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called when logging occurs"""
        if not self.metrics_enabled or not logs:
            return

        if "loss" in logs:
            self.train_loss_gauge.set(logs["loss"])

        if "eval_loss" in logs:
            self.eval_loss_gauge.set(logs["eval_loss"])

        if "learning_rate" in logs:
            self.learning_rate_gauge.set(logs["learning_rate"])

        if "epoch" in logs:
            self.epoch_gauge.set(logs["epoch"])

        if "grad_norm" in logs:
            self.grad_norm_gauge.set(logs["grad_norm"])
        if "memory_usage" in logs:
            self.memory_usage_gauge.set(logs["memory_usage"])

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of each training step"""
        if not self.metrics_enabled:
            return

        # Update step counter and epoch
        self.global_step_counter.add(1)
        if state.epoch is not None:
            self.epoch_gauge.set(state.epoch)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called after evaluation"""
        if not self.metrics_enabled or not metrics:
            return

        if "eval_loss" in metrics:
            self.eval_loss_gauge.set(metrics["eval_loss"])

        # Record any other eval metrics as gauges
        for key, value in metrics.items():
            if key.startswith("eval_") and isinstance(value, (int, float)):
                # Create gauge for this metric if it doesn't exist
                gauge_name = f"axolotl_{key}"
                try:
                    gauge = self.meter.create_gauge(
                        name=gauge_name,
                        description=f"Evaluation metric: {key}",
                        unit="1",
                    )
                    gauge.set(value)
                except Exception as e:
                    LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of training"""
        if not self.metrics_enabled:
            return

        LOG.info("Training completed. OpenTelemetry metrics collection finished.")
        LOG.info(
            f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
        )
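
# Usage sketch (assumed config attributes; endpoint shown with the defaults above):
#   callback = OpenTelemetryMetricsCallback(cfg)  # cfg.otel_metrics_port defaults to 8000
#   trainer.add_callback(callback)
#   # then, e.g.:  curl http://localhost:8000/metrics | grep axolotl_train_loss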
@@ -113,7 +113,7 @@ def _map_dataset(

    dataset = dataset.map(
        ds_transform_fn,
        num_proc=cfg.dataset_num_proc,
        num_proc=cfg.dataset_processes,
        load_from_cache_file=not cfg.is_preprocess,
        desc="Mapping RL Dataset",
        **map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
        prior_len = len(split_datasets[i])
        split_datasets[i] = split_datasets[i].filter(
            drop_long,
            num_proc=cfg.dataset_num_proc,
            num_proc=cfg.dataset_processes,
            load_from_cache_file=not cfg.is_preprocess,
            desc="Dropping Long Sequences",
        )

@@ -239,11 +239,6 @@ def _load_from_local_path(
        return load_dataset(dataset_config.path, **load_dataset_kwargs)
    elif local_path.is_file():
        dataset_type = get_dataset_type(dataset_config)

        # For single file datasets, HF always creates only a "train" split
        if dataset_type in ("json", "csv", "text"):
            load_dataset_kwargs["split"] = "train"

        return load_dataset(
            dataset_type,
            data_files=dataset_config.path,
@@ -414,7 +409,7 @@ def save_preprocessed_dataset(
) -> None:
    """Save preprocessed dataset to disk and optionally push to the HF Hub."""
    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
    num_workers = cfg.dataset_num_proc or get_default_process_count()
    num_workers = cfg.dataset_processes or get_default_process_count()
    if isinstance(dataset, IterableDataset):
        ds_from_iter = Dataset.from_generator(
            functools.partial(_generate_from_iterable_dataset, dataset),

@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(

    filter_map_kwargs = {}
    if not isinstance(dataset, IterableDataset):
        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
        filter_map_kwargs["num_proc"] = cfg.dataset_processes
        filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

    drop_long_kwargs = {}

@@ -80,7 +80,7 @@ def get_dataset_wrapper(
    """
    # Common parameters for dataset wrapping
    dataset_kwargs: dict[str, Any] = {
        "process_count": cfg.dataset_num_proc,
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

@@ -4,8 +4,6 @@ import os


def get_default_process_count():
    if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
        return int(axolotl_dataset_num_proc)
    if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
        return int(axolotl_dataset_processes)
    if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):

@@ -3,46 +3,66 @@ utils to get GPU info for the current environment
"""

import os
import subprocess  # nosec B404
from importlib.metadata import version

import torch
from accelerate.utils.environment import (
    check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
    get_gpu_info,
)
from packaging.version import Version, parse

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def check_cuda_p2p_ib_support():
    if not accelerate_check_cuda_p2p_ib_support():
        return False
    if not check_cuda_p2p_support():
    if not check_runpod_p2p_support():
        return False
    unsupported_devices = {"RTX 6000 Ada", "L40S"}
    try:
        device_names, device_count = get_gpu_info()
        if 1 < device_count < 8:
            if any(
                unsupported_device in device_name
                for device_name in device_names
                for unsupported_device in unsupported_devices
            ):
                return False
    except Exception:  # nosec B110
        pass
    return True


def check_cuda_p2p_support() -> bool:
def check_runpod_p2p_support() -> bool:
    if "RUNPOD_GPU_COUNT" not in os.environ:
        return True
    try:
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
    except ValueError:
        return True

    if world_size > 1:
        node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8"))
        local_other_rank = (local_rank // node_world_size) * node_world_size
        local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0
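        # Illustrative: with node_world_size=8, local_rank 0 checks peer access
        # against rank 1, and every other local rank checks against rank 0.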
    if gpu_count >= 2:
        # run `nvidia-smi topo -p2p n` and inspect the GPU0 row
        try:
            can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank)
        except AssertionError as exc:
            # some sort of logic error in indexing processes, assume p2p is fine for now
            LOG.warning(exc)
            result = subprocess.run(  # nosec B603 B607
                ["nvidia-smi", "topo", "-p2p", "n"],
                check=True,
                capture_output=True,
                text=True,
                timeout=5,
            )
        except (
            subprocess.CalledProcessError,
            FileNotFoundError,
            subprocess.TimeoutExpired,
        ):
            return True  # fail-open if detection fails
        output_lines = result.stdout.strip().split("\n")
        # filter rows that start with "GPU0" (avoid header row)
        gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
        if not gpu0_rows:
            return True
        return can_p2p

        # consider P2P supported if any OK is present in the GPU0 row
        return "OK" in gpu0_rows[-1]
    return True


@@ -148,7 +148,7 @@ def load_sharded_model(
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False,
        dtype=torch.float32,
        torch_dtype=torch.float32,
        _attn_implementation=model_config._attn_implementation,
        trust_remote_code=cfg.trust_remote_code,
    )
@@ -158,7 +158,7 @@ def load_sharded_model(
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            dtype=torch_dtype,
            torch_dtype=torch_dtype,
            trust_remote_code=cfg.trust_remote_code,
        )
    return model

@@ -5,7 +5,6 @@ into fixed-capacity batches to optimize memory usage and training throughput.

import gc
import math
import os
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -292,10 +291,7 @@ class MultipackBatchSampler(BatchSampler):
        self.total_token_slots = 0

        # The number of times to calculate batches to determine minimum packed dataset length
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.num_count_samples = (
            1 if world_size >= num_count_samples else num_count_samples
        )
        self.num_count_samples = num_count_samples

        if self.sequential and not isinstance(sampler, SequentialSampler):
            LOG.warning(

@@ -24,13 +24,11 @@ from axolotl.utils.schemas.datasets import (
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
    CometConfig,
    GradioConfig,
    LISAConfig,
    MLFlowConfig,
    OpenTelemetryConfig,
    RayConfig,
    WandbConfig,
)
@@ -61,7 +59,6 @@ class AxolotlInputConfig(
    WandbConfig,
    MLFlowConfig,
    CometConfig,
    OpenTelemetryConfig,
    LISAConfig,
    GradioConfig,
    RayConfig,
@@ -116,6 +113,19 @@ class AxolotlInputConfig(
        },
    )

    moe_kernels: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
        },
    )
    moe_kernel_backend: Literal["cg", "mg"] | None = Field(
        default="mg",
        json_schema_extra={
            "description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
        },
    )
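    # Example YAML enabling these options (illustrative values):
    #   moe_kernels: true
    #   moe_kernel_backend: mg   # or "cg" for the contiguous grouped GEMM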

    trainer_cls: str | None = Field(
        default=None,
        json_schema_extra={
@@ -236,7 +246,6 @@ class AxolotlInputConfig(
    )
    dataset_processes: int | None = Field(
        default=None,
        deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
        json_schema_extra={
            "description": (
                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -244,16 +253,6 @@ class AxolotlInputConfig(
            )
        },
    )
    dataset_num_proc: int | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
            )
        },
    )

    dataset_exact_deduplication: bool | None = Field(
        default=None,
        json_schema_extra={
@@ -681,7 +680,8 @@ class AxolotlInputConfig(
        json_schema_extra={"description": "FSDP configuration"},
        deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
    )
    fsdp_config: FSDPConfig | None = Field(
    # TODO @SalmanMohammadi strongly type this as its own schema
    fsdp_config: dict[str, Any] | None = Field(
        default=None, json_schema_extra={"description": "FSDP configuration options"}
    )
    fsdp_version: int | None = Field(
@@ -1327,22 +1327,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

    @model_validator(mode="before")
    @classmethod
    def default_dataset_num_proc(cls, data):
        if data.get("dataset_processes") is not None:
            if data.get("dataset_num_proc") is None:
                data["dataset_num_proc"] = data["dataset_processes"]
                LOG.warning(
                    "dataset_processes is deprecated and will be removed in a future version. "
                    "Please use dataset_num_proc instead."
                )
            else:
                LOG.warning(
                    "Both dataset_processes and dataset_num_proc are set. "
                    "Using dataset_num_proc and ignoring dataset_processes."
                )
            del data["dataset_processes"]
        elif data.get("dataset_num_proc") is None:
            data["dataset_num_proc"] = get_default_process_count()
    def default_dataset_processes(cls, data):
        if data.get("dataset_processes") is None:
            data["dataset_processes"] = get_default_process_count()

        return data

    @model_validator(mode="before")
Some files were not shown because too many files have changed in this diff.