Compare commits
47 Commits
moekernels...liger-063
| SHA1 |
|---|
| 9ee7ce5c85 |
| a41ca4d06f |
| 9d4d39e939 |
| bb33fda44d |
| 4dc018992d |
| 243620394a |
| 3750fdcf79 |
| 613bcf90e5 |
| 383f220cfd |
| 8bb871b5cf |
| 87565ecc05 |
| 93ba57396f |
| aa1240acd8 |
| 4cdfdfebb5 |
| 6e2f5ccf9f |
| 8c7f63cf97 |
| cd856b45b1 |
| 143dea4753 |
| bc2ffb8204 |
| 153edcfe79 |
| 08b8fa62cc |
| 3a5c97e6e5 |
| 37f78c8592 |
| ab63b92c38 |
| 6f8ce024d1 |
| d0e9c3c1c5 |
| 4c3488cc9f |
| 130637a3fa |
| 377c510e95 |
| 409cfb8a87 |
| ce74c20109 |
| a6bfbe3400 |
| f4376748f3 |
| 740d5a1d31 |
| 850c1a5f8d |
| 7fa8ac40cd |
| f9748c4dc5 |
| 33975ce4bc |
| e8b962d47f |
| 856ff12171 |
| 6bc959342b |
| b3b92687c4 |
| 55d1be2ae6 |
| 08d831c3d5 |
| 7be8740c5c |
| c51d6b06c3 |
| 09959fac70 |
.github/workflows/base.yml (35 changed lines, vendored)
@@ -25,20 +25,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
@@ -67,6 +53,13 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -122,13 +115,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
@@ -150,6 +136,13 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
.github/workflows/main.yml (15 changed lines, vendored)
@@ -15,11 +15,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -88,11 +83,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -162,11 +152,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
|
||||
.github/workflows/multi-gpu-e2e.yml (7 changed lines, vendored)
@@ -26,13 +26,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
|
||||
.github/workflows/nightlies.yml (20 changed lines, vendored)
@@ -12,16 +12,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -65,16 +65,16 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
.github/workflows/tests-nightly.yml (10 changed lines, vendored)
@@ -26,7 +26,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -102,14 +102,14 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
.github/workflows/tests.yml (18 changed lines, vendored)
@@ -55,7 +55,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -81,12 +81,12 @@ jobs:
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
@@ -130,7 +130,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -156,13 +156,13 @@ jobs:
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
@@ -286,12 +286,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.12
|
||||
rev: v0.14.0
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.1
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
@@ -73,7 +73,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.7.1
|
||||
|
||||
### Google Colab
|
||||
|
||||
|
||||
@@ -267,6 +267,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
@@ -285,7 +286,6 @@ website:
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/moe_backends.md
|
||||
- docs/nd_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
|
||||
@@ -32,6 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN uv pip install torchvision
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_PROCESSES="8"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
@@ -65,8 +65,13 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
sp_env = os.environ.copy()
|
||||
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
|
||||
exit(exit_code)
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
|
||||
@@ -13,7 +13,7 @@ datasets:
|
||||
val_set_size: 0
|
||||
output_dir: temp_debug/axolotl_outputs/model
|
||||
dataset_prepared_path: temp_debug/axolotl_outputs/data
|
||||
dataset_processes: 1
|
||||
dataset_num_proc: 1
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
|
||||
@@ -47,6 +47,8 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
||||
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
@@ -30,7 +30,13 @@ RUN uv venv --no-project --relocatable axolotl-venv
|
||||
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with the same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments:` as a JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
|
||||
@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
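For illustration, a minimal dataset entry using the `completion` format might look like the sketch below; the dataset path is a placeholder for your own corpus:

```yaml
datasets:
  - path: your-org/your-pretraining-corpus  # placeholder dataset id
    type: completion                        # pre-tokenizes the entire dataset up front
```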
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
|
||||
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
|
||||
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
|
||||
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
|
||||
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
|
||||
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
|
||||
|
||||
```yaml
|
||||
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
||||
// The flags below simplify debugging by overriding the axolotl config
|
||||
// with the debugging tips above. Modify as needed.
|
||||
"--dataset_processes=1", // limits data preprocessing to one process
|
||||
"--dataset_num_proc=1", // limits data preprocessing to one process
|
||||
"--max_steps=1", // limits training to just one step
|
||||
"--batch_size=1", // minimizes batch size
|
||||
"--micro_batch_size=1", // minimizes batch size
|
||||
|
||||
docs/faq.qmd (12 changed lines)
@@ -63,6 +63,14 @@ description: Frequently asked questions
|
||||
|
||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
||||
|
||||
**Q: Can we mix text and text+image datasets for VLM training?**
|
||||
|
||||
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
|
||||
|
||||
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
|
||||
|
||||
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
@@ -140,3 +148,7 @@ description: Frequently asked questions
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
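For reference, the corresponding YAML line is simply:

```yaml
offload_activations: legacy  # use the naive offloading implementation
```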
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "FDSP + QLoRA"
|
||||
title: "FSDP + QLoRA"
|
||||
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
|
||||
format:
|
||||
html:
|
||||
@@ -23,6 +23,12 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Enabling Swap for FSDP2
|
||||
|
||||
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
|
||||
|
||||
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
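As a rough sketch (assuming an FSDP2-style config with `fsdp_version: 2`; the surrounding keys depend on your setup), the two options mentioned above sit together like this:

```yaml
fsdp_version: 2
fsdp_config:
  offload_params: true           # offload sharded parameters to CPU
  cpu_offload_pin_memory: false  # unpinned memory may spill into swap
```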
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
|
||||
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
|
||||
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
|
||||
to improve speed and reduce memory usage during the forward and backward passes of
|
||||
these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
@@ -131,6 +132,5 @@ computation path.
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Additional operator fusions
|
||||
|
||||
@@ -27,3 +27,9 @@ learning_rate: 2e-5
|
||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
|
||||
self attention `q_proj` module.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We only support varying `lr` for now. If you're interested in adding support for other fields (e.g. `weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
|
||||
|
||||
:::
|
||||
|
||||
@@ -1,18 +0,0 @@
MoE Backends in Axolotl

Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):

- Set `moe_backend: auto|torch_grouped|naive`

Behavior
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- naive: keeps the reference per-expert loop.

Notes
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.

Example
moe_backend: torch_grouped
accelerate launch -m axolotl.cli.train path/to/config.yaml
@@ -88,6 +88,7 @@ fsdp_sync_module_states | **REMOVED**
|
||||
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
|
||||
fsdp_state_dict_type | state_dict_type
|
||||
fsdp_use_orig_params | **REMOVED**
|
||||
fsdp_activation_checkpointing | activation_checkpointing
|
||||
|
||||
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
|
||||
if you were using the following FSDP1 config:
|
||||
|
||||
@@ -13,6 +13,7 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
@@ -41,7 +42,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
# (optional) if doing lora, only finetune the Language model,
|
||||
# leave the vision model and vision tower frozen
|
||||
@@ -56,10 +56,14 @@ image_resize_algorithm: bilinear
|
||||
|
||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||
|
||||
::: {.callout-warning}
|
||||
::: {.callout-tip}
|
||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
|
||||
:::
|
||||
|
||||
### Mllama {#sec-mllama}
|
||||
|
||||
```yaml
|
||||
@@ -94,10 +98,22 @@ chat_template: llava
|
||||
|
||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
@@ -156,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3-VL {#sec-qwen3-vl}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
docs/optimizations.qmd (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
---
|
||||
title: Optimizations Guide
|
||||
description: A guide to the performance and memory optimizations available in Axolotl.
|
||||
---
|
||||
|
||||
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
|
||||
|
||||
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
|
||||
|
||||
## Speed Optimizations
|
||||
|
||||
These optimizations focus on increasing training throughput and reducing total training time.
|
||||
|
||||
### Sample Packing
|
||||
|
||||
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
|
||||
|
||||
- **Config:** `sample_packing: true`
|
||||
- **Learn more:** [Sample Packing](multipack.qmd)
|
||||
|
||||
### Attention Implementations
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
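As a minimal illustration of combining packing with a single attention backend (values are illustrative):

```yaml
sequence_len: 4096
sample_packing: true
flash_attention: true  # enable exactly one attention implementation
```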
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
|
||||
|
||||
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
|
||||
|
||||
## Memory Optimizations
|
||||
|
||||
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
|
||||
|
||||
### Parameter Efficient Finetuning (LoRA & QLoRA)
|
||||
|
||||
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
|
||||
|
||||
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
|
||||
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
|
||||
|
||||
### Gradient Checkpointing & Activation Offloading
|
||||
|
||||
These techniques save VRAM by changing how activations are handled.
|
||||
|
||||
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
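A minimal sketch of enabling the plugin, matching the configs used elsewhere in this changeset:

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```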
|
||||
|
||||
### Liger Kernels
|
||||
|
||||
Provides efficient Triton kernels to improve training speed and reduce memory usage.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
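A minimal sketch of enabling the plugin; the individual kernel toggles shown are illustrative and covered in the integration guide:

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true      # illustrative per-kernel toggles
liger_rms_norm: true
```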
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
|
||||
### RoPE Scaling
|
||||
|
||||
Extends a model's context window by interpolating its Rotary Position Embeddings.
|
||||
|
||||
- **Config:** Pass the `rope_scaling` settings under `overrides_of_model_config:`. To learn how to set RoPE for a given model, check that model's configuration.
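A hedged sketch of what that might look like; the exact `rope_scaling` fields and values depend on the model family, so check the model's own `config.json`:

```yaml
overrides_of_model_config:
  rope_scaling:
    type: linear  # some newer models use `rope_type` instead
    factor: 2.0   # illustrative scaling factor
```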
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
|
||||
|
||||
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
|
||||
|
||||
### Arctic Long Sequence Training (ALST)
|
||||
|
||||
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
|
||||
|
||||
- TiledMLP to reduce memory usage in MLP layers.
|
||||
- Tiled Loss functions (like [CCE](#cut-cross-entropy-cce) or [Liger](#liger-kernels)).
|
||||
- Activation Offloading to CPU.
|
||||
|
||||
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
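Putting those pieces together, a minimal sketch mirroring the ALST usage documented elsewhere in this changeset (the parallelism degree is illustrative):

```yaml
tiled_mlp: true
context_parallel_size: 2  # illustrative; see the sequence parallelism docs
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin  # tiled loss via CCE
```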
|
||||
|
||||
## Large Models (Distributed Training)
|
||||
|
||||
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
|
||||
|
||||
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
|
||||
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
|
||||
|
||||
### N-D Parallelism (Beta)
|
||||
|
||||
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
|
||||
|
||||
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Techniques to reduce the precision of model weights for memory savings.
|
||||
|
||||
### 4-bit Training (QLoRA)
|
||||
|
||||
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Parameter Efficient Finetuning](#parameter-efficient-finetuning-lora-qlora) for details.
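For reference, the two keys involved, as used by the QLoRA example configs in this changeset:

```yaml
adapter: qlora
load_in_4bit: true
```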
|
||||
|
||||
### FP8 Training
|
||||
|
||||
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
|
||||
|
||||
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
|
||||
|
||||
### Quantization Aware Training (QAT)
|
||||
|
||||
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
|
||||
|
||||
- **Learn more:** [QAT Documentation](qat.qmd)
|
||||
|
||||
### GPTQ
|
||||
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
@@ -30,6 +30,7 @@ qat:
|
||||
```
|
||||
|
||||
We support the following quantization schemas:
|
||||
|
||||
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
|
||||
- `Int8DynamicActivationInt4Weight`
|
||||
- `Float8DynamicActivationFloat8Weight`
|
||||
|
||||
@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chat_template.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### chat_template.default
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -6,6 +6,8 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
|
||||
|
||||
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
@@ -31,6 +33,14 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
|
||||
```
|
||||
|
||||
**LFM2-MoE**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
|
||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
@@ -45,14 +55,13 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
|
||||
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
|
||||
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
base_model: LiquidAI/LFM2-350M
|
||||
|
||||
chunked_cross_entropy: true
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
|
||||
examples/LiquidAI/lfm2-8b-a1b-lora.yaml (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
base_model: LiquidAI/LFM2-8B-A1B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -3,6 +3,9 @@ trust_remote_code: true
|
||||
model_type: AutoModelForImageTextToText
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
|
||||
@@ -7,3 +7,24 @@ techniques. It is a combination of:
|
||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
||||
|
||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
tiled_mlp: true
|
||||
|
||||
# See Sequence Parallelism docs
|
||||
# https://docs.axolotl.ai/docs/sequence_parallelism.html
|
||||
context_parallel_size: int
|
||||
|
||||
plugins:
|
||||
# See Cut Cross Entropy docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# or Liger Kernel docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
# ...
|
||||
|
||||
```
|
||||
|
||||
examples/apertus/README.md (new file, 110 lines)
@@ -0,0 +1,110 @@
|
||||
# Finetune Swiss-AI's Apertus with Axolotl
|
||||
|
||||
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of open-source models trained by Swiss AI.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Install from `main`, as Apertus support is currently only in nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. (Optional, highly recommended) Install XIELU CUDA
|
||||
|
||||
```bash
|
||||
## Recommended for reduced VRAM and faster speeds
|
||||
|
||||
# Point to CUDA toolkit directory
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/apertus/apertus-8b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.7 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
|
||||
- You can instead do full-parameter fine-tuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### XIELU Installation Issues
|
||||
|
||||
#### `ModuleNotFoundError: No module named 'torch'`
|
||||
|
||||
Please check these one by one:
|
||||
- Running in correct environment
|
||||
- Env has PyTorch installed
|
||||
- CUDA toolkit is at `CUDA_HOME`
|
||||
|
||||
If those didn't help, please try the below solutions:
|
||||
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/nickjbrowning/XIELU
|
||||
cd xielu
|
||||
git checkout 59d6031
|
||||
|
||||
cd xielu
|
||||
nano CMakeLists.txt # or vi depending on your preference
|
||||
```
|
||||
|
||||
```diff
|
||||
execute_process(
|
||||
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
|
||||
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
|
||||
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
examples/apertus/apertus-8b-qlora.yaml (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
base_model: swiss-ai/Apertus-8B-Instruct-2509
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -19,6 +19,9 @@ cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
|
||||
pip3 install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + vision + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -66,6 +66,7 @@ fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -29,7 +29,7 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -12,15 +12,6 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
examples/llama-3/opentelemetry-qlora.yml (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/opentelemetry-example
|
||||
|
||||
adapter: qlora
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
# OpenTelemetry Configuration
|
||||
use_otel_metrics: true
|
||||
otel_metrics_host: "localhost"
|
||||
otel_metrics_port: 8000
|
||||
|
||||
# Disable WandB
|
||||
use_wandb: false
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: false
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -46,7 +46,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -45,7 +45,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Thinking
|
||||
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
Example format:
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
### Vision
|
||||
|
||||
Example config: `./magistral-small-think-qlora.yaml`.
|
||||
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
|
||||
|
||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
Limitations:
|
||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
||||
|
||||
### TIPS
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
@@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
|
||||
examples/magistral/think/README.md (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
# Magistral Small Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 19.1 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
60
examples/magistral/vision/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Magistral Small Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 17 GiB VRAM.
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this is inherent to the model at the moment. If anyone would like to submit a fix, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- `max_tokens: 131072` for inference
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
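In config terms, the vision example pins the following options (these lines appear in the accompanying `magistral-small-vision-24B-qlora.yml`):

```yaml
# needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```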
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception: passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
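Since only `path`, `url`, and `base64` references are accepted, the image entry in the user turn can take any of these shapes (a sketch; the key names come from the note above, so double-check against the multimodal docs linked earlier):

```json
[
  { "type": "image", "path": "path/to/image.jpg" },
  { "type": "image", "url": "https://example.com/image.jpg" },
  { "type": "image", "base64": "<base64-encoded image bytes>" }
]
```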
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -0,0 +1,64 @@
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
51
examples/mistral/mistral-small/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Mistral Small 3.1/3.2 Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Mistral Small 3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
|
||||
```
|
||||
|
||||
This config uses about 29.4 GiB VRAM.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception: passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
@@ -36,7 +39,7 @@ wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
@@ -48,8 +51,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
@@ -1,53 +0,0 @@
|
||||
base_model: Qwen/Qwen1.5-MoE-A2.7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
# Keep VRAM low
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/qwen2-moe-qlora-10gb
|
||||
|
||||
# Train small to fit 10GB
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.03
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
special_tokens:
|
||||
@@ -12,15 +12,6 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -45,8 +45,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
64
examples/qwen3-next/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Finetune Qwen3-Next with Axolotl
|
||||
|
||||
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next generation of foundation models, optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with a 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Qwen3-Next support is only available on nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have PyTorch installed (PyTorch 2.6.0 minimum)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config (see the sketch after this list). See the [Multi-GPU](#optimization-guides) section below.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
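For reference, here is a minimal sketch of the full-finetune variant mentioned above, based on `examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml` with the QLoRA-specific keys dropped (everything else in that config stays the same):

```yaml
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct

# adapter: qlora      # removed for full finetuning
# load_in_4bit: true  # removed for full finetuning
load_in_8bit: false

sequence_len: 2048
sample_packing: true
# lora_r / lora_alpha / lora_dropout / lora_target_modules no longer apply
```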
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
- linear_attn.out_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3'
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -5,20 +5,19 @@ bitsandbytes==0.47.0
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.6.1
|
||||
liger-kernel==0.6.3
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft>=0.17.0
|
||||
transformers==4.56.1
|
||||
peft>=0.17.1
|
||||
tokenizers>=0.21.1
|
||||
transformers==4.57.1
|
||||
accelerate==1.10.1
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.23.0
|
||||
trl==0.23.1
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trackio
|
||||
@@ -70,4 +69,4 @@ schedulefree==1.4.1
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
mistral-common==1.8.5
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
except Exception: # pragma: no cover
|
||||
tg = None
|
||||
|
||||
|
||||
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for _ in range(warmup):
|
||||
run()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
run()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def load_hf_block(
|
||||
hidden: int,
|
||||
inter: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
transformers_src = project_root / "transformers" / "src"
|
||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||
sys.path.append(str(transformers_src))
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=hidden,
|
||||
moe_intermediate_size=inter,
|
||||
shared_expert_intermediate_size=inter,
|
||||
num_experts=experts,
|
||||
num_experts_per_tok=top_k,
|
||||
norm_topk_prob=True,
|
||||
qkv_bias=True,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped.load_state_dict(block.state_dict())
|
||||
return block, block_grouped
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
p.add_argument("--seq", type=int, default=1024)
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--experts", type=int, default=32)
|
||||
p.add_argument("--top_k", type=int, default=4)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
p.add_argument("--profile", action="store_true")
|
||||
p.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="Torch.compile both paths before benchmarking",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
block_naive, block_grouped = load_hf_block(
|
||||
args.hidden,
|
||||
args.inter,
|
||||
args.experts,
|
||||
args.top_k,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
tokens = args.bsz * args.seq
|
||||
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
|
||||
print(
|
||||
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
|
||||
f"experts={args.experts} top_k={args.top_k}"
|
||||
)
|
||||
|
||||
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||
|
||||
# Optional torch.compile
|
||||
run_grouped_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc: # pragma: no cover
|
||||
print(f"torch.compile naive failed ({exc}); using eager")
|
||||
else:
|
||||
|
||||
def grouped_forward(inp, *, block=block_grouped):
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
inp, block.gate, block.experts, block.top_k
|
||||
)
|
||||
return y
|
||||
|
||||
try:
|
||||
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||
except Exception as exc: # pragma: no cover
|
||||
print(f"torch.compile grouped failed ({exc}); using eager")
|
||||
run_grouped_impl = None
|
||||
|
||||
def run_naive(block=block_naive, data=x):
|
||||
y, _ = block(data)
|
||||
return y
|
||||
|
||||
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
|
||||
if impl is not None:
|
||||
return impl(data)
|
||||
if tg is None or not tg.available():
|
||||
return torch.empty(0)
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
|
||||
return y if y is not None else torch.empty(0)
|
||||
|
||||
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
y_ref = run_naive()
|
||||
|
||||
if tg is None or not tg.available():
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
return
|
||||
|
||||
y_grouped = run_grouped()
|
||||
if y_grouped.numel() == 0:
|
||||
print("torch_grouped\tN/A (op not callable)")
|
||||
return
|
||||
|
||||
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
|
||||
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_grouped
|
||||
print(
|
||||
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
|
||||
diff = (y_ref.float() - y_grouped.float()).abs()
|
||||
print(
|
||||
"torch_grouped_check: "
|
||||
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
|
||||
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
|
||||
)
|
||||
|
||||
if args.profile:
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
|
||||
) as prof:
|
||||
run_naive()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
|
||||
) as prof:
|
||||
run_grouped()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,311 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
except Exception: # pragma: no cover
|
||||
tg = None
|
||||
|
||||
|
||||
def _parse_list(arg: str) -> List[int]:
|
||||
return [int(v) for v in arg.split(",") if v]
|
||||
|
||||
|
||||
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
|
||||
for _ in range(warmup):
|
||||
run()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times: List[float] = []
|
||||
for _ in range(iters):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
run()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _load_block(
|
||||
hidden: int,
|
||||
inter: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
transformers_src = project_root / "transformers" / "src"
|
||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||
sys.path.append(str(transformers_src))
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=hidden,
|
||||
moe_intermediate_size=inter,
|
||||
shared_expert_intermediate_size=inter,
|
||||
num_experts=experts,
|
||||
num_experts_per_tok=top_k,
|
||||
norm_topk_prob=True,
|
||||
qkv_bias=True,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped.load_state_dict(block.state_dict())
|
||||
return block, block_grouped
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
bsz: int
|
||||
seq: int
|
||||
hidden: int
|
||||
inter: int
|
||||
experts: int
|
||||
top_k: int
|
||||
dtype: str
|
||||
naive_ms: float
|
||||
grouped_ms: float
|
||||
speedup: float
|
||||
naive_tflops: float
|
||||
grouped_tflops: float
|
||||
max_abs: float
|
||||
mean_abs: float
|
||||
rel_l2: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Grouped MoE sweep")
|
||||
p.add_argument("--batch-sizes", default="4,8,16")
|
||||
p.add_argument("--seq-lens", default="512,1024,2048")
|
||||
p.add_argument("--hidden", default="2048,4096")
|
||||
p.add_argument("--inter", default="5632,8192,14336")
|
||||
p.add_argument("--experts", default="8,16,32")
|
||||
p.add_argument("--top-k", default="1,2,4")
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=25)
|
||||
p.add_argument("--warmup", type=int, default=5)
|
||||
p.add_argument("--csv", type=Path, default=None)
|
||||
p.add_argument("--compile", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
if tg is None or not tg.available():
|
||||
print("torch_grouped unavailable; sweep aborted")
|
||||
return
|
||||
|
||||
bs_list = _parse_list(args.batch_sizes)
|
||||
seq_list = _parse_list(args.seq_lens)
|
||||
hidden_list = _parse_list(args.hidden)
|
||||
inter_list = _parse_list(args.inter)
|
||||
expert_list = _parse_list(args.experts)
|
||||
topk_list = _parse_list(args.top_k)
|
||||
|
||||
results: List[Result] = []
|
||||
|
||||
print(
|
||||
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
|
||||
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
|
||||
)
|
||||
|
||||
for bsz in bs_list:
|
||||
for seq in seq_list:
|
||||
tokens = bsz * seq
|
||||
for hidden in hidden_list:
|
||||
for inter in inter_list:
|
||||
for experts in expert_list:
|
||||
for top_k in topk_list:
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
block_naive, block_grouped = _load_block(
|
||||
hidden,
|
||||
inter,
|
||||
experts,
|
||||
top_k,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
x = torch.randn(
|
||||
bsz, seq, hidden, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
compiled_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"torch.compile naive failed ({exc}); using eager"
|
||||
)
|
||||
else:
|
||||
|
||||
def grouped_forward(inp, *, block=block_grouped):
|
||||
block.experts._ax_parent_block_ref = (
|
||||
weakref.ref(block)
|
||||
) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
inp,
|
||||
block.gate,
|
||||
block.experts,
|
||||
block.top_k,
|
||||
)
|
||||
return y
|
||||
|
||||
try:
|
||||
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"torch.compile grouped failed ({exc}); using eager"
|
||||
)
|
||||
compiled_impl = None
|
||||
|
||||
def run_naive(block=block_naive, data=x):
|
||||
y, _ = block(data)
|
||||
return y
|
||||
|
||||
def run_grouped(
|
||||
block=block_grouped, data=x, impl=compiled_impl
|
||||
):
|
||||
if impl is not None:
|
||||
return impl(data)
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
data,
|
||||
block.gate,
|
||||
block.experts,
|
||||
block.top_k,
|
||||
)
|
||||
return y
|
||||
|
||||
naive_ms = _bench(
|
||||
run_naive,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
device=device,
|
||||
)
|
||||
y_naive = run_naive()
|
||||
|
||||
grouped_ms = _bench(
|
||||
run_grouped,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
device=device,
|
||||
)
|
||||
y_grouped = run_grouped()
|
||||
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
res = Result(
|
||||
bsz,
|
||||
seq,
|
||||
hidden,
|
||||
inter,
|
||||
experts,
|
||||
top_k,
|
||||
args.dtype,
|
||||
naive_ms,
|
||||
grouped_ms,
|
||||
naive_ms / grouped_ms,
|
||||
_estimate_flops(tokens, hidden, inter, top_k)
|
||||
/ ((naive_ms / 1000.0) * 1e12),
|
||||
_estimate_flops(tokens, hidden, inter, top_k)
|
||||
/ ((grouped_ms / 1000.0) * 1e12),
|
||||
diff.max().item(),
|
||||
diff.mean().item(),
|
||||
(
|
||||
(
|
||||
diff.pow(2).sum()
|
||||
/ (y_naive.float().pow(2).sum() + 1e-12)
|
||||
)
|
||||
.sqrt()
|
||||
.item()
|
||||
),
|
||||
)
|
||||
results.append(res)
|
||||
print(
|
||||
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
|
||||
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
|
||||
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
|
||||
)
|
||||
|
||||
if args.csv:
|
||||
fieldnames = [
|
||||
"bsz",
|
||||
"seq",
|
||||
"hidden",
|
||||
"inter",
|
||||
"experts",
|
||||
"top_k",
|
||||
"dtype",
|
||||
"naive_ms",
|
||||
"grouped_ms",
|
||||
"speedup",
|
||||
"naive_tflops",
|
||||
"grouped_tflops",
|
||||
"max_abs",
|
||||
"mean_abs",
|
||||
"rel_l2",
|
||||
]
|
||||
with args.csv.open("w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for r in results:
|
||||
writer.writerow(
|
||||
{
|
||||
"bsz": r.bsz,
|
||||
"seq": r.seq,
|
||||
"hidden": r.hidden,
|
||||
"inter": r.inter,
|
||||
"experts": r.experts,
|
||||
"top_k": r.top_k,
|
||||
"dtype": r.dtype,
|
||||
"naive_ms": f"{r.naive_ms:.4f}",
|
||||
"grouped_ms": f"{r.grouped_ms:.4f}",
|
||||
"speedup": f"{r.speedup:.4f}",
|
||||
"naive_tflops": f"{r.naive_tflops:.4f}",
|
||||
"grouped_tflops": f"{r.grouped_tflops:.4f}",
|
||||
"max_abs": f"{r.max_abs:.6e}",
|
||||
"mean_abs": f"{r.mean_abs:.6e}",
|
||||
"rel_l2": f"{r.rel_l2:.6e}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import weakref
|
||||
|
||||
main()
|
||||
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# Ensure torchtitan is importable when running from the axolotl tree
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
|
||||
if str(_TITAN_PATH) not in sys.path:
|
||||
sys.path.append(str(_TITAN_PATH))
|
||||
|
||||
from torchtitan.models.moe import MoE, MoEArgs
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
p.add_argument("--seq", type=int, default=1024)
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--experts", type=int, default=8)
|
||||
p.add_argument("--top_k", type=int, default=2)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
p.add_argument("--init-std", type=float, default=0.02)
|
||||
p.add_argument(
|
||||
"--score-before",
|
||||
action="store_true",
|
||||
help="Apply routing scores before expert computation (default: after)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--score-func",
|
||||
choices=["softmax", "sigmoid"],
|
||||
default="softmax",
|
||||
)
|
||||
p.add_argument(
|
||||
"--route-norm",
|
||||
action="store_true",
|
||||
help="Enable Torchtitan router normalization when using sigmoid scores.",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _map_dtype(arg: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[arg]
|
||||
|
||||
|
||||
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
# Two up projections + one down projection per expert/token combination.
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _prepare_module(
|
||||
moe: MoE,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> MoE:
|
||||
moe = moe.to(device=device)
|
||||
for param in moe.parameters():
|
||||
param.data = param.data.to(dtype)
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
|
||||
buffers = dict(moe.named_buffers())
|
||||
for name, buf in buffers.items():
|
||||
if name == "tokens_per_expert":
|
||||
moe._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
elif name == "expert_bias" and buf is not None:
|
||||
moe._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
moe._buffers[name] = buf.to(device=device, dtype=dtype)
|
||||
moe.eval()
|
||||
return moe
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
|
||||
return module(x)
|
||||
|
||||
|
||||
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
fn()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
dtype = _map_dtype(args.dtype)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
moe_args_grouped = MoEArgs(
|
||||
num_experts=args.experts,
|
||||
num_shared_experts=0,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
top_k=args.top_k,
|
||||
use_grouped_mm=True,
|
||||
score_before_experts=args.score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
|
||||
moe_grouped.init_weights(args.init_std, buffer_device=device)
|
||||
|
||||
moe_args_naive = MoEArgs(
|
||||
num_experts=args.experts,
|
||||
num_shared_experts=0,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
top_k=args.top_k,
|
||||
use_grouped_mm=False,
|
||||
score_before_experts=args.score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
|
||||
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
|
||||
|
||||
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
|
||||
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
|
||||
|
||||
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||
|
||||
tokens = args.bsz * args.seq
|
||||
print(
|
||||
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
|
||||
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
|
||||
)
|
||||
|
||||
def run_naive():
|
||||
return _forward_fn(moe_naive, x)
|
||||
|
||||
def run_grouped():
|
||||
return _forward_fn(moe_grouped, x)
|
||||
|
||||
if hasattr(moe_naive, "tokens_per_expert"):
|
||||
moe_naive.tokens_per_expert.zero_()
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
|
||||
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
|
||||
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
y_naive = run_naive()
|
||||
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
|
||||
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
|
||||
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
|
||||
print(
|
||||
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
|
||||
y_grouped = run_grouped()
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
print(
|
||||
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,328 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
import torch
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
|
||||
if str(_TITAN_PATH) not in sys.path:
|
||||
sys.path.append(str(_TITAN_PATH))
|
||||
|
||||
from torchtitan.models.moe import MoE, MoEArgs
|
||||
|
||||
|
||||
def _parse_int_list(value: str) -> List[int]:
|
||||
return [int(v) for v in value.split(",") if v]
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
|
||||
p.add_argument(
|
||||
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
|
||||
)
|
||||
p.add_argument(
|
||||
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
|
||||
)
|
||||
p.add_argument(
|
||||
"--experts", default="8,16,32,64", help="Comma separated expert counts"
|
||||
)
|
||||
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=25)
|
||||
p.add_argument("--warmup", type=int, default=5)
|
||||
p.add_argument("--init-std", type=float, default=0.02)
|
||||
p.add_argument("--score-before", action="store_true")
|
||||
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
|
||||
p.add_argument("--route-norm", action="store_true")
|
||||
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _map_dtype(arg: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[arg]
|
||||
|
||||
|
||||
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
|
||||
module = module.to(device=device)
|
||||
for param in module.parameters():
|
||||
param.data = param.data.to(dtype)
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
for name, buf in module.named_buffers():
|
||||
if name == "tokens_per_expert":
|
||||
module._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
elif name == "expert_bias" and buf is not None:
|
||||
module._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
module._buffers[name] = buf.to(device=device, dtype=dtype)
|
||||
module.eval()
|
||||
return module
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
|
||||
return module(x)
|
||||
|
||||
|
||||
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
|
||||
for _ in range(warmup):
|
||||
callable_()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
timings: List[float] = []
|
||||
for _ in range(iters):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
callable_()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
timings.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(timings) / len(timings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepResult:
|
||||
bsz: int
|
||||
seq: int
|
||||
experts: int
|
||||
top_k: int
|
||||
dtype: str
|
||||
naive_ms: float
|
||||
grouped_ms: float
|
||||
speedup: float
|
||||
naive_tflops: float
|
||||
grouped_tflops: float
|
||||
max_abs: float
|
||||
mean_abs: float
|
||||
rel_l2: float
|
||||
|
||||
|
||||
def _run_case(
|
||||
*,
|
||||
bsz: int,
|
||||
seq: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
inter: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
iters: int,
|
||||
warmup: int,
|
||||
init_std: float,
|
||||
score_before: bool,
|
||||
score_func: str,
|
||||
route_norm: bool,
|
||||
) -> SweepResult:
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
moe_args_grouped = MoEArgs(
|
||||
num_experts=experts,
|
||||
num_shared_experts=0,
|
||||
score_func=score_func,
|
||||
route_norm=route_norm,
|
||||
top_k=top_k,
|
||||
use_grouped_mm=True,
|
||||
score_before_experts=score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
|
||||
moe_grouped.init_weights(init_std, buffer_device=device)
|
||||
|
||||
moe_args_naive = MoEArgs(
|
||||
num_experts=experts,
|
||||
num_shared_experts=0,
|
||||
score_func=score_func,
|
||||
route_norm=route_norm,
|
||||
top_k=top_k,
|
||||
use_grouped_mm=False,
|
||||
score_before_experts=score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
|
||||
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
|
||||
|
||||
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
|
||||
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
|
||||
|
||||
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
|
||||
|
||||
def run_naive():
|
||||
if hasattr(moe_naive, "tokens_per_expert"):
|
||||
moe_naive.tokens_per_expert.zero_()
|
||||
return _forward(moe_naive, x)
|
||||
|
||||
def run_grouped():
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
return _forward(moe_grouped, x)
|
||||
|
||||
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
|
||||
y_naive = run_naive()
|
||||
|
||||
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
|
||||
y_grouped = run_grouped()
|
||||
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
|
||||
tokens = bsz * seq
|
||||
flops = _estimate_flops(tokens, hidden, inter, top_k)
|
||||
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
|
||||
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
|
||||
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
|
||||
|
||||
return SweepResult(
|
||||
bsz=bsz,
|
||||
seq=seq,
|
||||
experts=experts,
|
||||
top_k=top_k,
|
||||
dtype=str(dtype),
|
||||
naive_ms=naive_ms,
|
||||
grouped_ms=grouped_ms,
|
||||
speedup=speedup,
|
||||
naive_tflops=naive_tflops,
|
||||
grouped_tflops=grouped_tflops,
|
||||
max_abs=max_abs,
|
||||
mean_abs=mean_abs,
|
||||
rel_l2=rel_l2,
|
||||
)
|
||||
|
||||
|
||||
def _print_header(
|
||||
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
|
||||
) -> None:
|
||||
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
|
||||
print(
|
||||
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
|
||||
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
|
||||
)
|
||||
|
||||
|
||||
def _print_result(res: SweepResult) -> None:
|
||||
print(
|
||||
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
|
||||
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
|
||||
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
|
||||
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
|
||||
)
|
||||
|
||||
|
||||
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
|
||||
fieldnames = [
|
||||
"batch_size",
|
||||
"seq_len",
|
||||
"experts",
|
||||
"top_k",
|
||||
"dtype",
|
||||
"naive_ms",
|
||||
"grouped_ms",
|
||||
"speedup",
|
||||
"naive_tflops",
|
||||
"grouped_tflops",
|
||||
"max_abs",
|
||||
"mean_abs",
|
||||
"rel_l2",
|
||||
]
|
||||
with path.open("w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for r in results:
|
||||
writer.writerow(
|
||||
{
|
||||
"batch_size": r.bsz,
|
||||
"seq_len": r.seq,
|
||||
"experts": r.experts,
|
||||
"top_k": r.top_k,
|
||||
"dtype": r.dtype,
|
||||
"naive_ms": f"{r.naive_ms:.4f}",
|
||||
"grouped_ms": f"{r.grouped_ms:.4f}",
|
||||
"speedup": f"{r.speedup:.4f}",
|
||||
"naive_tflops": f"{r.naive_tflops:.4f}",
|
||||
"grouped_tflops": f"{r.grouped_tflops:.4f}",
|
||||
"max_abs": f"{r.max_abs:.6e}",
|
||||
"mean_abs": f"{r.mean_abs:.6e}",
|
||||
"rel_l2": f"{r.rel_l2:.6e}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
dtype = _map_dtype(args.dtype)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
batch_sizes = _parse_int_list(args.batch_sizes)
|
||||
seq_lens = _parse_int_list(args.seq_lens)
|
||||
experts_list = _parse_int_list(args.experts)
|
||||
top_ks = _parse_int_list(args.top_ks)
|
||||
|
||||
results: List[SweepResult] = []
|
||||
_print_header(args.hidden, args.inter, dtype, device)
|
||||
|
||||
for bsz in batch_sizes:
|
||||
for seq in seq_lens:
|
||||
for experts in experts_list:
|
||||
for top_k in top_ks:
|
||||
try:
|
||||
res = _run_case(
|
||||
bsz=bsz,
|
||||
seq=seq,
|
||||
experts=experts,
|
||||
top_k=top_k,
|
||||
hidden=args.hidden,
|
||||
inter=args.inter,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
init_std=args.init_std,
|
||||
score_before=args.score_before,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
)
|
||||
except RuntimeError as err:
|
||||
print(
|
||||
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
results.append(res)
|
||||
_print_result(res)
|
||||
|
||||
if args.csv and results:
|
||||
_write_csv(args.csv, results)
|
||||
print(f"Wrote {len(results)} rows to {args.csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
|
||||
)
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.extend(
|
||||
[
|
||||
str(ROOT / "transformers" / "src"),
|
||||
str(ROOT / "src"),
|
||||
]
|
||||
)
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=4096,
|
||||
moe_intermediate_size=14336,
|
||||
shared_expert_intermediate_size=14336,
|
||||
num_experts=32,
|
||||
num_experts_per_tok=4,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
|
||||
experts = block.experts
|
||||
experts._ax_parent_block = block
|
||||
|
||||
impls = _iter_expert_impls(experts)
|
||||
print(f"impl count: {len(impls)}")
|
||||
for idx, impl in enumerate(impls[:8]):
|
||||
has_gate = hasattr(impl, "gate_proj")
|
||||
has_up = hasattr(impl, "up_proj")
|
||||
print(
|
||||
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
|
||||
)
|
||||
if has_gate:
|
||||
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
|
||||
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
|
||||
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Probe PyTorch for grouped GEMM operator names and namespaces.
|
||||
Run: python scripts/probe_torch_grouped_ops.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
import torch
|
||||
except Exception as e:
|
||||
print("Failed to import torch:", e)
|
||||
sys.exit(1)
|
||||
|
||||
print("torch version:", torch.__version__)
|
||||
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
|
||||
print("ops namespaces:", namespaces)
|
||||
|
||||
found_any = False
|
||||
for ns in namespaces:
|
||||
obj = getattr(torch.ops, ns, None)
|
||||
ops = []
|
||||
if obj is not None:
|
||||
try:
|
||||
ops = dir(obj)
|
||||
except Exception as e:
|
||||
print(f"warning: failed to list ops for namespace {ns}: {e}")
|
||||
cands = [
|
||||
o
|
||||
for o in ops
|
||||
if ("group" in o.lower())
|
||||
or ("mm_grouped" in o.lower())
|
||||
or ("matmul_grouped" in o.lower())
|
||||
or ("grouped" in o.lower())
|
||||
]
|
||||
if cands:
|
||||
found_any = True
|
||||
print(f"namespace {ns} candidates:", cands)
|
||||
|
||||
if not found_any:
|
||||
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
12
setup.py
@@ -26,7 +26,6 @@ def parse_requirements(extras_require_map):
|
||||
_install_requires.append(line)
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
@@ -34,7 +33,6 @@ def parse_requirements(extras_require_map):
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
@@ -51,7 +49,7 @@ def parse_requirements(extras_require_map):
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.6.0" # default to torch 2.6
|
||||
torch_version = "2.8.0" # default to torch 2.8.0
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
@@ -87,7 +85,6 @@ def parse_requirements(extras_require_map):
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 4):
|
||||
extras_require_map.pop("vllm")
|
||||
@@ -124,7 +121,6 @@ extras_require = {
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.5",
|
||||
@@ -163,6 +159,12 @@ extras_require = {
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
||||
"opentelemetry": [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
],
|
||||
}
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
|
||||
@@ -85,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
    unpatch_llama4 = patch_llama4_linearized_modeling()
    from transformers import Llama4ForConditionalGeneration

    model_ = Llama4ForConditionalGeneration.from_pretrained(
        model, torch_dtype=torch.bfloat16
    )
    model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
    processor = AutoProcessor.from_pretrained(model)
    processor.save_pretrained(output)


@@ -69,7 +69,7 @@ def do_quantize(
    config = AutoConfig.from_pretrained(model_path)
    torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch_dtype
        model_path, device_map="auto", dtype=torch_dtype
    )

    LOG.info(
@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
    resolve_dtype(cfg)

    # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
    if cfg.deepspeed:
    if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
        cfg.deepspeed = cfg.deepspeed.to_dict()

    # initialize accelerator before model instantiation
@@ -12,6 +12,9 @@ MOE_ARCH_BLOCK = {
    "mixtral": "MixtralSparseMoeBlock",
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
    "deepseek_v2": "DeepseekV2MoE",
    "deepseek_v3": "DeepseekV3MoE",
    "gpt_oss": "GptOssDecoderLayer",
    "lfm2_moe": "Lfm2MoeSparseMoeBlock",
}
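As a rough sketch of how a mapping like this tends to be consumed (the helper below is illustrative and not part of this diff), the model's Hugging Face `model_type` string keys the lookup:

```python
# Excerpt of the mapping above; the full dict lives in the patched module.
MOE_ARCH_BLOCK = {
    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
    "lfm2_moe": "Lfm2MoeSparseMoeBlock",
}


def get_moe_block_name(model_type: str) -> str:
    """Hypothetical lookup helper: map a HF model_type to its MoE block class name."""
    try:
        return MOE_ARCH_BLOCK[model_type]
    except KeyError as err:
        raise ValueError(f"No known MoE block for model_type={model_type!r}") from err


print(get_moe_block_name("qwen3_vl_moe"))  # Qwen3VLMoeTextSparseMoeBlock
```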
@@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig

from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import (
    is_comet_available,
    is_mlflow_available,
    is_opentelemetry_available,
)
from axolotl.utils.callbacks import (
    GCCallback,
    SaveAxolotlConfigtoWandBCallback,
@@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC):
                callbacks.append(
                    SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
                )
        if self.cfg.use_otel_metrics and is_opentelemetry_available():
            from axolotl.utils.callbacks.opentelemetry import (
                OpenTelemetryMetricsCallback,
            )

            callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
        if self.cfg.save_first_step:
            callbacks.append(SaveModelOnFirstStepCallback())

@@ -491,6 +501,7 @@ class TrainerBuilderBase(abc.ABC):
            "dion_momentum",
            "dion_rank_fraction",
            "dion_rank_multiple_of",
            "dataset_num_proc",
        ]:
            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -514,9 +525,6 @@ class TrainerBuilderBase(abc.ABC):
        training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
        training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

        if self.cfg.dataset_processes:
            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

        # max_length is not used in CausalTrainer
        if self.cfg.reward_model or self.cfg.rl:
            training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -28,7 +28,6 @@ from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
    LossWatchDogCallback,
    SaveBetterTransformerModelCallback,
    bench_eval_callback_factory,
    causal_lm_bench_eval_callback_factory,
    colab_inference_post_train_callback,
@@ -63,12 +62,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.relora:
            callbacks.append(ReLoRACallback(self.cfg))

        if (
            hasattr(self.model, "use_bettertransformer")
            and self.model.use_bettertransformer is True
        ):
            callbacks.append(SaveBetterTransformerModelCallback())

        # TODO: check if can move to base class
        if self.cfg.loss_watchdog_threshold is not None:
            callbacks.append(LossWatchDogCallback(self.cfg))
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        if self.cfg.use_wandb:
            training_args_kwargs["run_name"] = self.cfg.wandb_name

        if self.cfg.max_prompt_len:
            training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
        else:
            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len

        training_args_cls = None
        blocklist_args_kwargs = []
        if self.cfg.rl is RLType.SIMPO:
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            if self.cfg.cpo_alpha is not None:
                training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

            # Handle when max_prompt_length == max_length from defaults
            # CPOTrainer requires strictly less than
            if (
                training_args_kwargs["max_prompt_length"]
                == training_args_kwargs["max_length"]
            ):
                training_args_kwargs["max_prompt_length"] -= 1

        elif self.cfg.rl is RLType.ORPO:
            training_args_cls = AxolotlORPOConfig
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        elif self.cfg.rl is RLType.KTO:
            training_args_cls = AxolotlKTOConfig
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                self.cfg.kto_undesirable_weight or 1.0
            )

            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        elif self.cfg.rl is RLType.GRPO:
            training_args_cls = GRPOStrategy.get_training_args_class()
            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
@@ -8,7 +8,7 @@ from typing import Any, Mapping

def chat_message_transform_builder(
    train_on_inputs=False,
    conversations_field: str = "conversations",
    conversations_field: str = "messages",
    message_field_role: str | list[str] | None = None,  # commonly "role"
    message_field_content: str | list[str] | None = None,  # commonly "content"
    message_field_training: str | list[str] | None = None,  # commonly "weight"
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
            If True, the transform will train on the inputs. If False, the transform will train on the targets.
            Defaults to False.
        conversations_field (str, optional):
            The field name of the conversations. Defaults to "conversations".
            The field name of the conversations. Defaults to "messages".
        message_field_role (str | list[str], optional):
            The field name of the role. Defaults to "role".
            The field name of the role.
        message_field_content (str | list[str], optional):
            The field name of the message content. Defaults to "content".
            The field name of the message content.
        message_field_training (str | list[str], optional):
            The field name of the train/weight. Defaults to "weight".
            The field name of the train/weight.

    Returns:
        Callable:
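To make the default field names concrete, a dataset row in the shape this transform expects would look roughly like the sketch below; the exact schema depends on the dataset, so treat the values as illustrative:

```python
# Illustrative row using the defaults above: "messages" as the conversations field,
# with "role"/"content" per turn and an optional "weight" marking trainable turns.
example_row = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4", "weight": 1},
    ]
}
```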
@@ -225,17 +225,6 @@ class AxolotlTrainer(

        data_collator = self.data_collator if is_training else self.eval_data_collator

        if dataset.column_names and "length" in dataset.column_names:
            dataset = dataset.remove_columns(["length"])
        if (
            dataset.column_names
            and "position_ids" in dataset.column_names
            and "attention_mask" in dataset.column_names
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            dataset = dataset.remove_columns(["attention_mask"])

        if isinstance(dataset, datasets.Dataset):
            if is_training:
                if not self.args.sample_packing or self.args.pretraining:
@@ -294,6 +283,18 @@ class AxolotlTrainer(
        ):
            self.accelerator.even_batches = False

        if dataset.column_names and "length" in dataset.column_names:
            dataset = dataset.remove_columns(["length"])

        if (
            dataset.column_names
            and "position_ids" in dataset.column_names
            and "attention_mask" in dataset.column_names
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            dataset = dataset.remove_columns(["attention_mask"])

        dataloader = DataLoader(dataset, **dataloader_params)

        # Accelerator.free_memory() will destroy the references, so
@@ -560,13 +561,6 @@ class AxolotlTrainer(

        super().create_accelerator_and_postprocess()

        if self.is_fsdp_enabled:
            if (
                "limit_all_gathers" in self.args.fsdp_config
                and self.args.fsdp_config["limit_all_gathers"]
            ):
                self.accelerator.state.fsdp_plugin.limit_all_gathers = True

    def additional_accelerator_args(
        self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
    ) -> dict[str, Any]:
@@ -27,7 +27,6 @@ class DPOStrategy:
        training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
        training_args_kwargs["max_completion_length"] = None
        training_args_kwargs["max_length"] = cfg.sequence_len
        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
        training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
        if cfg.dpo_use_weighting is not None:
            training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting

@@ -52,6 +52,7 @@ class GRPOStrategy:
        if trl.vllm_mode:
            grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
            if trl.vllm_mode == "colocate":
                grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode  # type: ignore[attr-defined]
                grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
                    vllm_cfg.gpu_memory_utilization
                )
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

- If you are installing from pip

```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
```

## Usage
@@ -31,6 +31,7 @@ plugins:

## Supported Models

- apertus
- arcee
- cohere
- cohere2
@@ -44,14 +45,22 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4v
- glm4v_moe
- gpt_oss
- granite
- granitemoe
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- lfm2
- lfm2_moe
- lfm2_vl
- llama
- llama4
- llama4_text
- llava
- mistral
- mistral3
- mixtral
@@ -65,6 +74,9 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_vl
- qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss
- voxtral
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

_CCE_INSTALL_MESSAGE = (
    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
)
@@ -7,7 +7,7 @@ import torch

from axolotl.utils.logging import get_logger

from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -360,7 +360,7 @@ def _diffusion_step(

    # Forward pass
    outputs = model(input_ids=sequence, attention_mask=attention_mask)
    logits = outputs.logits
    logits = shift_logits_to_input_positions(outputs.logits)

    # Only sample at currently masked positions
    if current_mask.any():

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
            input_ids=noisy_batch.long(),
            attention_mask=bidirectional_mask,
        )
        logits = outputs.logits
        logits = shift_logits_to_input_positions(outputs.logits)

        if masked_indices.sum() > 0:
            valid_indices = torch.where(masked_indices)

@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(

    # Add head dimension: [batch_size, 1, seq_len, seq_len]
    return bidirectional_mask.unsqueeze(1)


def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
    """Align next-token logits with their input token positions for diffusion."""
    if logits.size(1) <= 1:
        return logits
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
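A quick sanity check of the helper added above: each position i >= 1 ends up holding the logits produced at position i - 1, so next-token logits line up with the input tokens they predict, and position 0 keeps its own logits as filler. A standalone copy of the function is repeated here purely for the demonstration:

```python
import torch


# Mirrors the helper introduced in the diff above, repeated so the example runs standalone.
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
    if logits.size(1) <= 1:
        return logits
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)


# Toy example: batch of 1, seq_len 4, vocab size 1; "logits" are just position markers.
logits = torch.arange(4.0).reshape(1, 4, 1)  # positions [0, 1, 2, 3]
shifted = shift_logits_to_input_positions(logits)
print(shifted.flatten().tolist())  # [0.0, 0.0, 1.0, 2.0]
```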
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
    # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss

    loss = self.loss_function(
    loss = self._loss_function(
        self.lm_head.weight,
        hidden_states,
        target_token_ids,
@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_accepts_loss_kwargs = True
        self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(

        loss_fn = LigerFusedLinearKLTopKLogprobLoss(
            self.args.kd_ce_alpha,  # hard label loss
            self.args.kd_alpha,  # kd loss
            self.args.kd_temperature,
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
            compute_ce_loss=bool(self.args.kd_ce_alpha),
            normalize_topk=self.args.kd_normalize_topk,
        )
        target = self.model

        # Unwrap PEFT wrapper
        if hasattr(target, "get_base_model"):
            target = target.get_base_model()

        # Set on the actual model instance
        target._loss_function = loss_fn

    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
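The unwrap-then-attach step above matters because a PEFT-style wrapper stores newly assigned attributes on itself, so a forward pass that reads `_loss_function` from the base model would never see a value set on the wrapper. A toy sketch of that behavior (not PEFT's real classes):

```python
# Toy stand-ins for a PEFT-style wrapper and its base model; illustrative only.
class BaseModel:
    def loss_fn(self):
        # The patched forward reads the attribute from *this* object.
        return getattr(self, "_loss_function", None)


class Wrapper:
    def __init__(self, base):
        self.base = base

    def get_base_model(self):
        return self.base


base = BaseModel()
wrapped = Wrapper(base)

wrapped._loss_function = "kd_loss"  # lands on the wrapper, not the base model
print(base.loss_fn())               # None

wrapped.get_base_model()._loss_function = "kd_loss"
print(base.loss_fn())               # "kd_loss"
```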
@@ -1,3 +0,0 @@
from .backends import MOEBackend, get_moe_backend_name

__all__ = ["get_moe_backend_name", "MOEBackend"]
Some files were not shown because too many files have changed in this diff.