Compare commits

..

12 Commits

Author SHA1 Message Date
Dan Saunders
f3c8a25b30 Merge branch 'main' into codecov-pulls-only 2025-06-18 16:00:37 -04:00
Carsten Kragelund Jørgensen
eb3a57eb17 Ignore generation/endgeneration tags when analyzing Jinja chat template (#2787)
* ignore generation/endgeneration tags

Axolotl calculates the mask for assistant turns on its own, so these tags are not needed; however, the analyzer currently does not recognize them at all and throws an error.

* feat: add phi4 tokenizer test and unblock gemma2

* fix: improve template

* chore: refactor

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-18 15:59:07 -04:00
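A minimal sketch of that approach (hypothetical, not the actual code from #2787): strip the `{% generation %}`/`{% endgeneration %}` markers from the chat template before handing it to the Jinja analyzer, since Axolotl derives the assistant mask itself.

```python
import re

# Hypothetical helper: the tag names come from the commit message above;
# everything else here is an illustrative assumption.
GENERATION_TAG_RE = re.compile(r"\{%-?\s*(?:generation|endgeneration)\s*-?%\}")

def strip_generation_tags(chat_template: str) -> str:
    """Drop {% generation %} / {% endgeneration %} markers so a Jinja
    analyzer that does not know these keywords will not raise."""
    return GENERATION_TAG_RE.sub("", chat_template)

template = (
    "{% for m in messages %}"
    "{% if m['role'] == 'assistant' %}"
    "{% generation %}{{ m['content'] }}{% endgeneration %}"
    "{% else %}{{ m['content'] }}{% endif %}"
    "{% endfor %}"
)
print(strip_generation_tags(template))
```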
Wing Lian
34da391391 Set dev version (#2807) [skip ci] 2025-06-18 15:49:05 -04:00
NanoCode012
0bb9077553 Fix: logging on py310 (#2802)
* feat: encourage py311

* fix: logging import on py310

* fix: do upper and simplify handling
2025-06-18 15:46:27 -04:00
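One plausible reading of this fix (an assumption, not verified against #2802): `logging.getLevelNamesMapping()` only exists on Python 3.11+, so resolving a level name case-insensitively needs a fallback on 3.10. A hedged sketch:

```python
import logging

def resolve_log_level(name: str) -> int:
    """Map a level name like "debug" or "INFO" to its numeric value on
    both Python 3.10 and 3.11+ (hypothetical helper, not the PR's code)."""
    name = name.upper()
    if hasattr(logging, "getLevelNamesMapping"):  # added in Python 3.11
        return logging.getLevelNamesMapping().get(name, logging.INFO)
    # Python 3.10 fallback: standard level names are module attributes.
    return getattr(logging, name, logging.INFO)

logging.basicConfig(level=resolve_log_level("debug"))
```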
Wing Lian
a85efffbef bump transformers==4.52.4 (#2800) [skip ci]
* bump transformers==4.52.4

* don't use hf offline for qwen tokenizer

* increase timeout

* don't use methodtype

* increase timeout

* better assertion logging

* upgrade deepspeed version too
2025-06-18 15:46:14 -04:00
Dan Saunders
06a648263b Config doc autogen: follow-up fix docs build (#2806)
* config reference doc autogen

* improvements

* cleanup; still ugly but working

* reformat

* remove autogen config ref from git

* factor out validations

* rewrite

* rewrite

* cleanup

* progress

* progress

* progress

* lint and minifying somewhat

* remove unneeded

* coderabbit

* coderabbit

* update preview-docs workflow triggers

* installing with deps

* coderabbit

* update refs

* overwrote file accidentally

* docs install deps
2025-06-18 15:42:54 -04:00
Dan Saunders
9d5bfc127e Config doc autogen (#2718)
* config reference doc autogen

* improvements

* cleanup; still ugly but working

* reformat

* remove autogen config ref from git

* factor out validations

* rewrite

* rewrite

* cleanup

* progress

* progress

* progress

* lint and minifying somewhat

* remove unneeded

* coderabbit

* coderabbit

* update preview-docs workflow triggers

* installing with deps

* coderabbit

* update refs

* overwrote file accidentally
2025-06-18 15:36:53 -04:00
Dan Saunders
016eb8055f accidental file 2025-06-17 13:58:02 -04:00
Dan Saunders
639ddeff6a return codecov artifact from modal image 2025-06-17 13:33:02 -04:00
Dan Saunders
753e4e3dec updates 2025-06-17 10:45:32 -04:00
Dan Saunders
2538c3b761 update to run only if succeeded 2025-06-17 10:45:32 -04:00
Dan Saunders
aa3639b7ad run codecov action at end of CI; only_pulls: true 2025-06-17 10:45:32 -04:00
52 changed files with 3321 additions and 2378 deletions

View File

@@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
python3 -m pip install -e .
- name: Build autodoc
run: quartodoc build
- name: Publish to GitHub Pages (and render)

View File

@@ -8,7 +8,9 @@ on:
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
- '_quarto.yml'
- docs/scripts/generate_config_docs.py
- src/axolotl/utils/schemas/**.py
permissions:
checks: write
@@ -38,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
python3 -m pip install -e .
- name: Build autodoc
run: quartodoc build

View File

@@ -106,13 +106,12 @@ jobs:
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
- name: Upload coverage artifacts
uses: actions/upload-artifact@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
name: coverage-${{ matrix.pytorch_version }}-${{ github.run_id }}
path: ./coverage.xml
retention-days: 1
- name: cleanup pip cache
run: |
@@ -234,6 +233,14 @@ jobs:
run: |
modal run cicd.e2e_tests
- name: Upload coverage artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: coverage-e2e-1st-${{ github.run_id }}
path: ./e2e-coverage.xml
retention-days: 1
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
@@ -297,6 +304,14 @@ jobs:
run: |
modal run cicd.e2e_tests
- name: Upload coverage artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: coverage-e2e-${{ matrix.cuda }}-${{ matrix.pytorch }}-${{ github.run_id }}
path: ./e2e-coverage.xml
retention-days: 1
docker-e2e-cleanup:
runs-on: [self-hosted, modal]
timeout-minutes: 90
@@ -336,3 +351,26 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.cleanup
upload-coverage:
name: Upload Coverage to Codecov
runs-on: ubuntu-latest
needs: [pytest, docker-e2e-tests, docker-e2e-tests-1st]
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'
steps:
- name: Download coverage reports
uses: actions/download-artifact@v4
with:
path: coverage-reports
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: coverage-reports
fail_ci_if_error: false
verbose: true
name: codecov-umbrella
override_commit: ${{ github.event.pull_request.head.sha || github.sha }}
override_pr: ${{ github.event.pull_request.number }}

View File

@@ -328,7 +328,7 @@ The following optimizers are supported:
- Use `gradient_checkpointing: true` to reduce memory usage
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html).
### Errors:

README.md
View File

@@ -1,177 +1,124 @@
<div align="center">
<a href="https://github.com/axolotl-ai-cloud/axolotl">
<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/docs/logo.png" alt="Axolotl Logo" width="250" style="margin-bottom: 20px;"/>
</a>
<h1><span style="color: #4CAF50;">Axolotl: Fine-tune LLMs with Unprecedented Ease & Power!</span> 🚀</h1>
<p style="font-size: 1.1em; color: #555;">Your ultimate toolkit for efficient, scalable, and versatile large language model fine-tuning.</p>
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p>
<a href="https://discord.gg/HhrNrHJPRb" target="_blank">
<img src="https://img.shields.io/discord/1070542385153273887?label=Discord&logo=discord&logoColor=white&color=7289DA" alt="Discord Community" style="margin: 5px;">
</a>
<a href="https://docs.axolotl.ai/" target="_blank">
<img src="https://img.shields.io/badge/Documentation-blue?style=flat&logo=readthedocs&logoColor=white" alt="Official Documentation" style="margin: 5px;">
</a>
<a href="https://pypi.org/project/axolotl/" target="_blank">
<img src="https://img.shields.io/pypi/v/axolotl?label=PyPI&logo=pypi&logoColor=white&color=blue" alt="PyPI Package" style="margin: 5px;">
</a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases" target="_blank">
<img src="https://img.shields.io/github/downloads/axolotl-ai-cloud/axolotl/total?label=Downloads&color=green" alt="GitHub Downloads" style="margin: 5px;">
</a>
</p>
<br>
</div>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
---
<div style="background-color: #f0f8ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #d0e8ff;">
<h2 style="color: #0056b3; text-align: center; margin-top: 0;">🎉 Latest Innovations & Updates!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/06:</span> Magistral with mistral-common tokenizer support!</strong> Dive into <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral" style="color: #007bff; text-decoration: none;">examples</a> to train your own Magistral models.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/05:</span> Quantization Aware Training (QAT) support!</strong> Explore the <a href="https://docs.axolotl.ai/docs/qat.html" style="color: #007bff; text-decoration: none;">docs</a> to learn more.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/04:</span> Llama 4 support!</strong> See <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4" style="color: #007bff; text-decoration: none;">examples</a> to train Llama 4 with Axolotl's linearized version!</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> Sequence Parallelism (SP) support!</strong> Scale your context length. Read the <a href="https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://docs.axolotl.ai/docs/sequence_parallelism.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> (Beta) Fine-tuning Multimodal models!</strong> Check out the <a href="https://docs.axolotl.ai/docs/multimodal.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> LoRA optimizations!</strong> Reduce memory and improve speed. Jump into the <a href="https://docs.axolotl.ai/docs/lora_optims.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> GRPO support!</strong> Dive into our <a href="https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://github.com/axolotl-ai-cloud/grpo_code" style="color: #007bff; text-decoration: none;">GRPO example</a>.</li>
<li style="margin-bottom: 0px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/01:</span> Reward Modelling / Process Reward Modelling fine-tuning!</strong> See <a href="https://docs.axolotl.ai/docs/reward_modelling.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
</ul>
</div>
## 🎉 Latest Updates
<h2 style="color: #FF5733;"><span style="margin-right: 10px;">✨</span> Axolotl Overview: Your LLM Fine-tuning Powerhouse!</h2>
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
<div style="background-color: #fffacd; padding: 20px; border-radius: 10px; margin-bottom: 30px; border: 1px solid #ffd700;">
<p style="font-size: 1.1em; color: #333; text-align: center;">Axolotl is a powerful, flexible, and user-friendly tool designed to supercharge your post-training workflows for a wide range of cutting-edge AI models.</p>
</div>
## ✨ Overview
<div style="display: flex; flex-wrap: wrap; justify-content: space-around; gap: 20px; margin-bottom: 40px;">
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🤖</span> Broad Model Compatibility</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Train a vast array of models including LLaMA, Mistral, Mixtral, Pythia, and many more.</li>
<li>Fully compatible with HuggingFace transformers causal language models, ensuring wide adoption.</li>
</ul>
</div>
Axolotl is a tool designed to streamline post-training for various AI models.
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🔧</span> Diverse Training Methodologies</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Full fine-tuning, LoRA, QLoRA, GPTQ, QAT.</li>
<li>Preference Tuning: DPO, IPO, KTO, ORPO.</li>
<li>Advanced RL: GRPO.</li>
<li>Multimodal and Reward Modelling (RM) / Process Reward Modelling (PRM).</li>
</ul>
</div>
Features:
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚙️</span> Streamlined Configuration</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Utilize a single, intuitive YAML file across dataset preprocess, training, evaluation, quantization, and inference.</li>
</ul>
</div>
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚡</span> Cutting-Edge Performance Optimizations</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #007bff;">Multipacking</a>, <a href="https://github.com/Dao-AILab/flash-attention" style="color: #007bff;">Flash Attention</a>, <a href="https://github.com/facebookresearch/xformers" style="color: #007bff;">Xformers</a>, <a href="https://pytorch.org/blog/flexattention/" style="color: #007bff;">Flex Attention</a>, <a href="https://github.com/linkedin/Liger-Kernel" style="color: #007bff;">Liger Kernel</a>, <a href="https://github.com/apple/ml-cross-entropy/tree/main" style="color: #007bff;">Cut Cross Entropy</a>.</li>
<li>Sequence Parallelism (SP), LoRA optimizations.</li>
<li>Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!</li>
</ul>
</div>
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">📂</span> Flexible Data Handling</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Load datasets from local paths, HuggingFace Hub, and major cloud providers (S3, Azure, GCP, OCI).</li>
</ul>
</div>
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">☁️</span> Cloud-Ready & Deployable</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Official <a href="https://hub.docker.com/u/axolotlai" style="color: #007bff;">Docker images</a> and <a href="https://pypi.org/project/axolotl/" style="color: #007bff;">PyPI packages</a> for seamless integration on cloud platforms and local hardware.</li>
</ul>
</div>
</div>
## 🚀 Quick Start
<h2 style="color: #007bff;"><span style="margin-right: 10px;">🚀</span> Quick Start: Get Fine-tuning in Minutes!</h2>
**Requirements**:
<div style="background-color: #e6f7ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #cceeff;">
<h3 style="color: #0056b3; margin-top: 0;">Requirements:</h3>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ NVIDIA GPU</span> (Ampere or newer for `bf16` and Flash Attention) or AMD GPU</li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ Python 3.11</span></li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ PyTorch ≥2.5.1</span></li>
</ul>
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.5.1
<h3 style="color: #0056b3;">Installation:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;">pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
### Installation
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL</code></pre>
<p style="font-size: 0.9em; color: #555;">Other installation approaches are described <a href="https://docs.axolotl.ai/docs/installation.html" style="color: #007bff; text-decoration: none;">here</a>.</p>
axolotl fetch deepspeed_configs # OPTIONAL
```
<h3 style="color: #0056b3;">Your First Fine-tune:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;"># Fetch axolotl examples
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune
```bash
# Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml</code></pre>
<p style="text-align: center; font-size: 1.1em; font-weight: bold; margin-top: 20px;">
That's it! Check out our <a href="https://docs.axolotl.ai/docs/getting-started.html" style="background-color: #28a745; color: white; padding: 12px 25px; border-radius: 8px; text-decoration: none; display: inline-block; transition: background-color 0.3s ease;"> Getting Started Guide ➜</a> for a more detailed walkthrough.
</p>
</div>
axolotl train examples/llama-3/lora-1b.yml
```
<h2 style="color: #8A2BE2;"><span style="margin-right: 10px;">📚</span> Comprehensive Documentation: Unlock Axolotl's Full Potential</h2>
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
<div style="background-color: #f7f0ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #e0caff;">
<p style="text-align: center; font-size: 1.1em; color: #333;">Dive deep into Axolotl's capabilities with our extensive documentation:</p>
<ul style="list-style-type: none; padding-left: 0; text-align: center;">
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/installation.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Installation Options</a> - Detailed setup instructions for different environments</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/config.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Configuration Guide</a> - Full configuration options and examples</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset_loading.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Loading</a> - Loading datasets from various sources</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset-formats/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Guide</a> - Supported formats and how to use them</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-gpu.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-GPU Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-node.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-Node Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multipacking</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/api/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> API Reference</a> - Auto-generated code documentation</li>
<li style="margin-bottom: 0px;"><a href="https://docs.axolotl.ai/docs/faq.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;">❓ FAQ</a> - Frequently asked questions</li>
</ul>
</div>
<h2 style="color: #FF8C00;"><span style="margin-right: 10px;">🤝</span> Need Help? We're Here for You!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #7289DA;"></span> Join our vibrant <a href="https://discord.gg/HhrNrHJPRb" style="color: #7289DA; text-decoration: none; font-weight: bold;">Discord community</a> for real-time support and discussions.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Explore our <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Examples</a> directory for practical use cases.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Read our <a href="https://docs.axolotl.ai/docs/debugging.html" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Debugging Guide</a> for troubleshooting tips.</li>
<li style="margin-bottom: 0px;"><span style="font-size: 1.2em; color: #007bff;">✉</span> Need dedicated support? Please contact <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none; font-weight: bold;">wing@axolotl.ai</a> for professional assistance options.</li>
</ul>
## 📚 Documentation
<h2 style="color: #FF1493;"><span style="margin-right: 10px;">🌟</span> Contribute to Axolotl!</h2>
<p style="font-size: 1.1em;">
Contributions are always welcome and highly appreciated! Axolotl thrives on community support. Please see our <a href="https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md" style="color: #FF1493; text-decoration: none; font-weight: bold;">Contributing Guide</a> for details on how you can help make Axolotl even better.
</p>
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples
- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [Multipacking](https://docs.axolotl.ai/docs/multipack.html)
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
<div align="center" style="margin-top: 40px; padding: 25px; background-color: #f8f8f8; border-radius: 12px; border: 1px solid #eee;">
<h2 style="color: #FF69B4; margin-bottom: 20px;">❤️ Our Esteemed Sponsors</h2>
<p style="font-size: 1.1em; color: #555;">A huge thank you to our visionary sponsors who provide the essential resources to keep Axolotl at the forefront of LLM fine-tuning:</p>
<a href="https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl" target="_blank" style="display: inline-block; margin: 20px;">
<img src="https://assets-global.website-files.com/6247c4c1d68352614b7e87ae/63b27b3b44b82d02c8163f4f_logo-dark-square.png" alt="Modal Logo" width="180" style="vertical-align: middle; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.15);"/>
</a>
<p style="font-size: 0.9em; color: #777; margin-top: 20px;">
<strong>Modal:</strong> Revolutionizing cloud computing for Gen AI. Run jobs, deploy models, and fine-tune LLMs at scale with ease.
</p>
<p style="font-size: 1em; color: #555; margin-top: 30px;">
Interested in powering the future of Axolotl? <span style="font-weight: bold; color: #FF69B4;">Become a sponsor!</span> Contact us at <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none;">wing@axolotl.ai</a>
</p>
</div>
## 🤝 Getting Help
<h2 style="color: #6A5ACD;"><span style="margin-right: 10px;">📜</span> License</h2>
<p style="font-size: 1.1em;">
This project is proudly licensed under the <span style="font-weight: bold; color: #6A5ACD;">Apache 2.0 License</span>. See the <a href="LICENSE" style="color: #007bff; text-decoration: none;">LICENSE</a> file for full details.
</p>
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://docs.axolotl.ai/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
## 🌟 Contributing
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

View File

@@ -1,5 +1,6 @@
project:
type: website
pre-render: docs/scripts/generate_config_docs.py
quartodoc:
dir: docs/api
@@ -235,7 +236,7 @@ website:
- docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- docs/config-reference.qmd
- text: "API Reference"
href: docs/api

View File

@@ -51,5 +51,3 @@ pytest -v --durations=10 \
--cov=axolotl \
--cov-append \
--cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true

View File

@@ -1,20 +1,34 @@
"""Modal app to run axolotl GPU tests"""
import pathlib
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60, # 90 min
timeout=120 * 60, # 120 min
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
# Read the coverage file if it exists
coverage_file = pathlib.Path("/workspace/axolotl/e2e-coverage.xml")
if coverage_file.exists():
return coverage_file.read_text(encoding="utf-8")
return None
@app.local_entrypoint()
def main():
cicd_pytest.remote()
coverage = cicd_pytest.remote()
# Save the coverage file to the local filesystem if it was generated
if coverage:
with open("e2e-coverage.xml", "w", encoding="utf-8") as f:
f.write(coverage)

View File

@@ -69,7 +69,7 @@ def run_cmd(cmd: str, run_folder: str):
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
timeout=120 * 60,
cpu=16.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
@@ -77,7 +77,18 @@ def run_cmd(cmd: str, run_folder: str):
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
# Read the coverage file if it exists
coverage_file = pathlib.Path("/workspace/axolotl/multigpu-coverage.xml")
if coverage_file.exists():
return coverage_file.read_text(encoding="utf-8")
return None
@app.local_entrypoint()
def main():
cicd_pytest.remote()
coverage = cicd_pytest.remote()
# Save the coverage file to the local filesystem if it was generated
if coverage:
with open("multigpu-coverage.xml", "w", encoding="utf-8") as file:
file.write(coverage)

docs/.gitignore
View File

@@ -2,3 +2,4 @@
_site/
/api/*.qmd
/api/*.html
config-reference.qmd

View File

@@ -1,801 +0,0 @@
---
title: Config Reference
description: A complete list of all configuration options.
---
```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# This can also be a relative path to a model on disk
base_model: ./llama-7b-hf
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns:
# If the base_model repo on hf hub doesn't include configuration .json files,
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
tokenizer_use_mistral_common:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
embeddings_skip_upcast:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
# (Internal use only)
# Used to identify what the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
# optional overrides to the base model configuration
overrides_of_model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# quantization aware training
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
# Use bitsandbytes 4 bit
load_in_4bit:
# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# Note: if bf16 is set to 'auto' and fp16 is set to true, we will prefer the explicit fp16 setting
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://docs.axolotl.ai/docs/custom_integrations.html
plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
datasets:
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
name: # Optional[str] name of dataset configuration to load
split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
# Custom user instruction prompt
- path: repo
type:
# The below are defaults. only set what's needed if you use a different column name.
system_prompt: ""
system_format: "{system}"
field_system: system
field_instruction: instruction
field_input: input
field_output: output
# Customizable to be single line or multi-line
# Use {instruction}/{input} as key to be replaced
# 'format' can include {input}
format: |-
User: {instruction} {input}
Assistant:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# For `completion` datasets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
chat_template_jinja:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the tools (default: "tools")
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
field_tools: tools
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...
# Optional[Dict[str, List]]. Roles mapping in the messages.
# The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
# The default is:
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant"]
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
# See example at `docs/dataset-formats/conversation.qmd`
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 5 fields are empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: turn
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
# - all: train on all EOT tokens
# - turn: train on the EOT token at the end of each trainable turn
# - last: train on the last EOT token in the conversation
# If not specified, defaults to the value of train_on_eos for backward compatibility.
train_on_eot:
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
- path: /workspace/data/eval.jsonl
ds_type: json
# You need to specify a split. For "json" datasets the default split is called "train".
split: train
type: completion
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
rl_beta: # Optional[float]. The beta parameter for the RL training.
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
top_k: # Optional[int]. Top-k sampling for the generation policy.
min_p: # Optional[float]. Minimum probability for the generation policy.
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
# reward modelling: `True` or `False`
reward_model:
# process reward modelling: `True` or `False`
process_reward_model:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
# These tokens mark the boundaries between conversation turns.
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
# If not specified, defaults to just the model's eos_token.
# This is useful for templates that use multiple delimiter tokens.
eot_tokens:
# - "</s>"
# - "[/INST]"
# - "[/SYSTEM_PROMPT]"
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# Keep dataset in memory while preprocessing
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # private repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# Required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
# Index of shard to use for whole dataset
dataset_shard_idx:
# The maximum length of an input to train with; this should typically be less than 2048,
# as most models have a token/context limit of 2048
sequence_len: 2048
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing:
# You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# Increasing the following values helps with packing, but usually only slightly (<1%).
# The number of samples packed at a time.
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
# - gate_proj
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:
# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:
# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save:
# - embed_tokens
# - lm_head
lora_fan_in_fan_out: false
# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://docs.axolotl.ai/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
peft:
# Configuration options for loftq initialization for LoRA
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
loftq_config:
loftq_bits: # typically 4 bits
# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
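# Illustrative ReLoRA sketch (values are assumptions, not recommended defaults):
# relora_steps: 150
# relora_warmup_steps: 10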
# wandb configuration if you're using it
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Defaults to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Tensorboard
use_tensorboard: # Optional[bool]
# Where to save the full-finetuned model
output_dir: ./completed-model
# Whether to use torch.compile and which backend to use
# setting to `auto` will enable torch compile when torch>=2.5.1
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
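# Illustrative sketch (assumed values; 'inductor' is the stock torch.compile backend):
# torch_compile: auto
# torch_compile_backend: inductor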
# Training hyperparameters
# If greater than 1, the optimizer step is deferred and gradients are accumulated for the given number of steps.
gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2
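# Worked example (assuming 4 GPUs): micro_batch_size: 2 with gradient_accumulation_steps: 4
# gives a per-GPU batch of 2 and an effective global batch of 2 * 4 * 4 = 32 samples per optimizer step.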
eval_batch_size:
num_epochs: 4
warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Maximum number of checkpoints to keep at a time
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
# Maximum number of iterations to train for. Takes precedence over num_epochs: if both are set,
# num_epochs is not guaranteed to complete.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for only 100 steps
max_steps:
# bool of whether to include tokens per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
# snapshots can be visualized @ https://pytorch.org/memory_viz
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
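# Illustrative sketch (assumed values): abort if loss stays above 4.0 for 3 steps in a row
# loss_watchdog_threshold: 4.0
# loss_watchdog_patience: 3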
# Save model as safetensors (requires the safetensors package). Default True
save_safetensors:
# Whether to train on the human's prompt (true) or mask it out of the training labels (false)
train_on_inputs: false
# Group similarly sized data to minimize padding.
# May be slower to start, as it must download and sort the entire dataset.
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: true
# Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
# Valid values are driven by the Transformers SchedulerType class, see:
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
# Valid values include
# - 'linear'
# - 'cosine' (default)
# - 'cosine_with_restarts'
# - 'polynomial'
# - 'constant'
# - 'constant_with_warmup'
# - 'inverse_sqrt'
# - 'reduce_lr_on_plateau'
# - 'cosine_with_min_lr'
# - 'warmup_stable_decay'
# Additional schedulers include:
# - 'one_cycle'
# - 'rex'
lr_scheduler:
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
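# Illustrative sketch (assumed values): cosine schedule decaying to 10% of the peak lr
# lr_scheduler: cosine
# cosine_min_lr_ratio: 0.1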
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_torch
# - adamw_torch_fused (default)
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - adamw_torch_4bit
# - ademamix
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - adamw_8bit # alias for adamw_bnb_8bit
# - ademamix_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_ademamix_32bit
# - paged_ademamix_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - rmsprop
# - rmsprop_bnb
# - rmsprop_bnb_8bit
# - rmsprop_bnb_32bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
# - lomo
# - adalomo
# - grokadamw
# - schedule_free_adamw
# - schedule_free_sgd
# - apollo_adamw
# - apollo_adamw_layerwise
#
# Additional custom optimizers include:
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank:  # type: int
# update_proj_gap:  # type: int
# scale:  # type: float
# proj_type:  # type: str, default = std
# The target modules to optimize, i.e. the module names that you would like to train. Currently this is only used by the GaLore algorithm.
optim_target_modules:
# - self_attn # for llama
# - mlp
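# Illustrative GaLore sketch (all values are assumptions; tune for your model):
# optimizer: galore_adamw
# optim_args:
#   rank: 256
#   update_proj_gap: 200
#   scale: 0.25
# optim_target_modules:
#   - self_attn
#   - mlp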
# Specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:
# Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
neftune_noise_alpha:
# Optional[bool]. Whether to use BetterTransformer
flash_optimum:
# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Optional[str]. Resume from a specific checkpoint dir
resume_from_checkpoint:
# Optional[bool]. Resume from the most recent checkpoint when resume_from_checkpoint isn't set and you simply want training to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
## Multimodal section
# int | tuple[int, int] | None. Size to resize images to (width x height).
# Will read from model/processor config if not set.
image_size:
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
image_resize_algorithm: 'bilinear'
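# Illustrative sketch (assumed values):
# image_size: 448
# image_resize_algorithm: bicubic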
## End of multimodal section
# Don't mess with this, it's here for accelerate and torchrun
local_rank:
# Add or change special tokens.
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens:
# bos_token: "<s>"
# eos_token: "</s>"
# unk_token: "<unk>"
# pad_token: "[PAD]"
# Optional[list[str]]. Add extra tokens to the tokenizer.
tokens:
# - "<|startoftext|>"
# - "<|endoftext|>"
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# You can check whether a token is an added_token by looking at added_tokens in tokenizer.json.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:
# Deepspeed config path. e.g., deepspeed_configs/zero3.json
deepspeed:
# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
# in the sample packing case, and "batch_ring" in the non-sample packing case.
ring_attn_func:
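# Illustrative sketch (assuming 4 GPUs): split each sequence across 2 GPUs
# sequence_parallel_degree: 2
# heads_k_stride: 1
# ring_attn_func: varlen_llama3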
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
# Set to an HF dataset with type: 'completion' to stream data during training instead of pre-tokenizing
pretraining_dataset:
# Debug mode
debug:
# Seed
seed:
# Allow overwriting yml config values from the CLI
strict:
```

View File

@@ -12,7 +12,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
{"conversations": [{"role": "...", "content": "..."}]}
```
See [configs](../config.qmd) for full configs and supported templates.
See [configs](../config-reference.qmd) for full configs and supported templates.
### Migrating from sharegpt
@@ -130,13 +130,13 @@ datasets:
```
::: {.callout-tip}
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
:::
::: {.callout-note}
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.
:::
- Continuing from the previous example, if you want to train on the EOT tokens of all trainable turns but only the last EOS token, set `train_on_eos: last`, as in the sketch below.
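A minimal dataset sketch of that combination (the dataset path is a placeholder):

```yaml
datasets:
  - path: your/dataset  # placeholder
    type: chat_template
    train_on_eos: last
```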

View File

@@ -186,4 +186,4 @@ datasets:
no_input_format: "[INST] {instruction} [/INST]"
```
See full config options under [here](../config.qmd).
See full config options under [here](../config-reference.qmd).

View File

@@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface
For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading).
For full details on the config, see [config.qmd](config.qmd).
For full details on the config, see [config-reference.qmd](config-reference.qmd).
::: {.callout-note}

View File

@@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details.
See our [config options](config-reference.qmd) for more details.
### Training {#sec-training}
@@ -179,7 +179,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Configuration Guide](config-reference.qmd) - Full configuration options
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)

View File

@@ -14,7 +14,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Requirements {#sec-requirements}
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- Python ≥3.11
- PyTorch ≥2.5.1
## Installation Methods {#sec-installation-methods}
@@ -153,7 +153,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
### Conda/Pip venv {#sec-conda}
1. Install Python ≥3.10
1. Install Python ≥3.11
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}

View File

@@ -32,7 +32,7 @@ output_dir: # The path to the output directory.
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which
you used to train the model:
```yaml

View File

@@ -0,0 +1,752 @@
# type: ignore
"""
Quarto documentation generation from Pydantic models. Uses Pydantic model source code
to automatically group fields, including inherited fields from parent classes.
"""
import ast
import inspect
import textwrap
import types
import typing
from typing import Any, FrozenSet, Type, Union
from pydantic import BaseModel
from axolotl.utils.schemas.config import AxolotlInputConfig
class QuartoGenerator:
"""Generate Quarto documentation from Pydantic models."""
def __init__(self):
self._class_fields_cache = {}
self._inheritance_map_cache = {}
self._nested_models_cache = {}
def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]:
"""Get fields defined directly in a single class (not inherited)."""
if cls in self._class_fields_cache:
return self._class_fields_cache[cls]
fields = set()
# Get annotated fields
if hasattr(cls, "__annotations__"):
fields.update(cls.__annotations__.keys())
# Filter out private/special methods
fields = {f for f in fields if not f.startswith("_")}
result = frozenset(fields)
self._class_fields_cache[cls] = result
return result
def _is_pydantic_model(self, type_obj) -> bool:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
origin = typing.get_origin(field_type)
args = typing.get_args(field_type)
if origin is not None:
# Handle Annotated[SomeType, ...] - extract the first argument
if hasattr(typing, "Annotated") and origin is typing.Annotated:
if args:
return self._extract_nested_type(
args[0]
) # Recursively process the actual type
# Handle list[SomeType], List[SomeType], etc.
elif origin in (list, typing.List):
if args:
return self._extract_nested_type(
args[0]
) # Extract element type
# Handle Union types (including | syntax)
elif origin is typing.Union:
# Get non-None types from the Union
non_none_types = [arg for arg in args if arg is not type(None)]
if len(non_none_types) >= 1:
# Prioritize Pydantic models over primitive types
pydantic_models = [
arg
for arg in non_none_types
if self._is_pydantic_model(arg)
]
if pydantic_models:
# Return the first Pydantic model found
return self._extract_nested_type(pydantic_models[0])
# No Pydantic models, return the first non-None type
return self._extract_nested_type(non_none_types[0])
# Handle new Python 3.10+ union syntax (PeftConfig | None)
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
# Get non-None types from the Union
non_none_types = [
arg for arg in field_type.__args__ if arg is not type(None)
]
if len(non_none_types) >= 1:
# Prioritize Pydantic models over primitive types
pydantic_models = [
arg for arg in non_none_types if self._is_pydantic_model(arg)
]
if pydantic_models:
return self._extract_nested_type(pydantic_models[0])
return self._extract_nested_type(non_none_types[0])
# Handle old typing.Union syntax (fallback)
if hasattr(field_type, "__origin__"):
if field_type.__origin__ is Union:
# Get non-None types from the Union
non_none_types = [
arg for arg in field_type.__args__ if arg is not type(None)
]
if len(non_none_types) >= 1:
# Prioritize Pydantic models over primitive types
pydantic_models = [
arg for arg in non_none_types if self._is_pydantic_model(arg)
]
if pydantic_models:
return self._extract_nested_type(pydantic_models[0])
return self._extract_nested_type(non_none_types[0])
# Handle other generic types like dict[str, Any], etc.
elif hasattr(field_type, "__args__"):
return field_type
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
"""Extract all Pydantic models from a type annotation, including from Unions."""
models = []
if field_type is None:
return models
# Handle Annotated types
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
origin = typing.get_origin(field_type)
args = typing.get_args(field_type)
if origin is not None:
# Handle Annotated[SomeType, ...] - extract from the first argument
if hasattr(typing, "Annotated") and origin is typing.Annotated:
if args:
models.extend(
self._extract_all_pydantic_models_from_type(args[0])
)
return models
# Handle list[SomeType], List[SomeType], etc.
if origin in (list, typing.List):
if args:
models.extend(
self._extract_all_pydantic_models_from_type(args[0])
)
return models
# Handle Union types
if origin is typing.Union:
for arg in args:
if arg is not type(None): # Skip None type
models.extend(
self._extract_all_pydantic_models_from_type(arg)
)
return models
# Handle new Python 3.10+ union syntax
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
for arg in field_type.__args__:
if arg is not type(None): # Skip None type
models.extend(self._extract_all_pydantic_models_from_type(arg))
return models
# Handle old typing.Union syntax (fallback)
if hasattr(field_type, "__origin__") and field_type.__origin__ is Union:
for arg in field_type.__args__:
if arg is not type(None): # Skip None type
models.extend(self._extract_all_pydantic_models_from_type(arg))
return models
# Check if this type itself is a Pydantic model
if self._is_pydantic_model(field_type):
models.append(field_type)
return models
def _get_nested_models(
self, model_class: type[BaseModel], visited=None
) -> dict[str, type[BaseModel]]:
"""Get all nested Pydantic models from a model class."""
if visited is None:
visited = set()
# Avoid infinite recursion
if model_class in visited:
return {}
if model_class in self._nested_models_cache:
return self._nested_models_cache[model_class]
visited.add(model_class)
nested_models = {}
# Check all fields in the model
for field_info in model_class.model_fields.values():
field_type = self._extract_nested_type(field_info.annotation)
if self._is_pydantic_model(field_type):
nested_models[field_type.__name__] = field_type
# Recursively get nested models from this nested model
deeper_nested = self._get_nested_models(field_type, visited.copy())
nested_models.update(deeper_nested)
self._nested_models_cache[model_class] = nested_models
return nested_models
def _build_inheritance_map(self, child_class: Type[BaseModel]):
"""Build inheritance map for a class and all its parents."""
if child_class in self._inheritance_map_cache:
return self._inheritance_map_cache[child_class]
inheritance_map = {}
# Get MRO and filter out BaseModel and object
mro_classes = [
cls
for cls in child_class.__mro__
if cls not in (BaseModel, object) and hasattr(cls, "__annotations__")
]
# Process each class in the MRO
for cls in mro_classes:
inheritance_map[cls] = self._get_direct_fields(cls)
self._inheritance_map_cache[child_class] = inheritance_map
return inheritance_map
def _wrap_comment(self, text: str, width: int = 88) -> list[str]:
"""Wrap a comment to specified width, accounting for '# ' prefix."""
if not text.strip():
return ["#"]
# Account for "# " prefix (2 characters)
content_width = width - 2
wrapped_lines = textwrap.wrap(text, width=content_width)
return [f"# {line}" for line in wrapped_lines]
def _extract_type_from_source(
self, model_class: type[BaseModel], field_name: str
) -> str:
"""Extract the actual type annotation text from source code, checking inheritance chain."""
# Use inheritance map to check classes efficiently
inheritance_map = self._build_inheritance_map(model_class)
# Check classes in MRO order
for cls in model_class.__mro__:
if cls in inheritance_map and field_name in inheritance_map[cls]:
type_annotation = self._get_type_from_class_source(cls, field_name)
if type_annotation != "unknown":
return type_annotation
return "unknown"
def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str:
"""Extract type annotation from a specific class's source code."""
try:
source = inspect.getsource(class_obj)
tree = ast.parse(source)
except (OSError, TypeError):
return "unknown"
# Find the class definition
for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
# Find the field assignment
for body_node in node.body:
if isinstance(body_node, ast.AnnAssign) and isinstance(
body_node.target, ast.Name
):
if body_node.target.id == field_name and body_node.annotation:
return ast.unparse(body_node.annotation)
break
return "unknown"
def _extract_field_groups_from_all_classes(
self, model_class: type[BaseModel]
) -> list[dict]:
"""Extract field groups from all classes in the inheritance hierarchy."""
all_groups = []
inheritance_map = self._build_inheritance_map(model_class)
# Get all Pydantic base classes in MRO order (most specific first)
# This puts AxolotlInputConfig fields first, then parent class fields
pydantic_classes = [
cls
for cls in model_class.__mro__
if cls in inheritance_map and inheritance_map[cls]
]
# Extract groups from each class
for cls in pydantic_classes:
class_groups = self._extract_field_groups_from_source(cls)
for group in class_groups:
all_groups.append(group)
# If no groups found, create a default grouping by class
if not all_groups:
for cls in pydantic_classes:
fields_in_class = inheritance_map[cls]
if fields_in_class:
all_groups.append(
{
"fields": list(fields_in_class),
}
)
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
"""Extract field groups from source code based on blank lines and comments."""
try:
source = inspect.getsource(model_class)
tree = ast.parse(source)
except (OSError, TypeError):
# Fallback if we can't get source code
fields_in_class = self._get_direct_fields(model_class)
if fields_in_class:
return [
{
"fields": list(fields_in_class),
}
]
return []
groups = []
current_group_fields = []
current_group_comment = None
# Find the class definition
class_node = None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == model_class.__name__:
class_node = node
break
if not class_node:
fields_in_class = self._get_direct_fields(model_class)
if fields_in_class:
return [
{
"fields": list(fields_in_class),
}
]
return []
# Parse the source lines to detect groupings
source_lines = source.split("\n")
# Get fields that are actually defined in this specific class
fields_in_class = self._get_direct_fields(model_class)
# Find assignments that correspond to model fields for THIS class only
field_assignments = []
for node in class_node.body:
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
field_name = node.target.id
if field_name in fields_in_class:
field_assignments.append(
{
"name": field_name,
"lineno": node.lineno,
"end_lineno": getattr(node, "end_lineno", node.lineno),
}
)
if not field_assignments:
if fields_in_class:
return [
{
"fields": list(fields_in_class),
}
]
return []
# Sort by line number
field_assignments.sort(key=lambda x: x["lineno"])
# Group fields based on blank lines and comments
for i, field_info in enumerate(field_assignments):
field_name = field_info["name"]
current_line = field_info["lineno"]
# Check if this starts a new group (blank line before or significant gap)
is_new_group = False
if i == 0:
is_new_group = True
else:
prev_end_line = field_assignments[i - 1]["end_lineno"]
# Check for blank lines or comments between fields
lines_between = source_lines[prev_end_line : current_line - 1]
has_blank_line = any(line.strip() == "" for line in lines_between)
has_comment = any(
line.strip().startswith("#") for line in lines_between
)
# Start new group if there's a blank line or comment, or significant gap
if has_blank_line or has_comment or (current_line - prev_end_line > 3):
is_new_group = True
if is_new_group and current_group_fields:
# Save the previous group
groups.append(
{
"fields": current_group_fields.copy(),
"description": current_group_comment,
}
)
current_group_fields = []
current_group_comment = None
current_group_fields.append(field_name)
# Add the final group
if current_group_fields:
groups.append(
{
"fields": current_group_fields,
"description": current_group_comment,
}
)
return groups
def _generate_field_documentation(
self,
model_class: type[BaseModel],
field_name: str,
field_info: dict,
field_type_str: str,
is_required: bool,
indent_level: int = 0,
visited_models: set = None,
) -> list[str]:
"""Generate documentation for a single field, expanding nested models inline."""
if visited_models is None:
visited_models = set()
lines = []
indent = " " * indent_level
# Get the actual field type for nested model detection
if field_name in model_class.model_fields:
pydantic_field_info = model_class.model_fields[field_name]
actual_field_type = pydantic_field_info.annotation
else:
actual_field_type = None
# Add description comment if available
description = field_info.get("description", "")
if description:
wrapped_lines = self._wrap_comment(description, width=88 - len(indent))
for line in wrapped_lines:
lines.append(f"{indent}{line}")
# Extract nested Pydantic models from the type annotation
nested_models = self._extract_all_pydantic_models_from_type(actual_field_type)
# Filter out already visited models to prevent infinite recursion
expandable_models = [
model for model in nested_models if model not in visited_models
]
if expandable_models:
# This field contains Pydantic models that can be expanded
# Show the field with its full type annotation
field_line = f"{indent}{field_name}: {field_type_str}"
if field_info.get("default") is not None:
field_line += f" = {field_info['default']}"
if is_required:
field_line += " (required)"
lines.append(field_line)
# Add to visited to prevent infinite recursion
new_visited = visited_models.copy()
new_visited.update(expandable_models)
# Expand each nested Pydantic model
for i, nested_model in enumerate(expandable_models):
if i > 0:
lines.append("\n")
lines.append(f"{indent} # For {nested_model.__name__}:")
# Get nested model schema
try:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception: # pylint: disable=broad-exception-caught
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
for (
nested_field_name,
nested_field_info,
) in nested_model.model_fields.items():
nested_description = ""
if (
hasattr(nested_field_info, "json_schema_extra")
and nested_field_info.json_schema_extra
):
nested_description = (
nested_field_info.json_schema_extra.get(
"description", ""
)
)
elif (
hasattr(nested_field_info, "description")
and nested_field_info.description
):
nested_description = nested_field_info.description
nested_default_val = None
if (
hasattr(nested_field_info, "default")
and nested_field_info.default is not None
):
if str(nested_field_info.default) != "PydanticUndefined":
nested_default_val = nested_field_info.default
nested_properties[nested_field_name] = {
"type": "unknown",
"description": nested_description,
"default": nested_default_val,
}
if nested_field_info.is_required():
nested_required.append(nested_field_name)
# Get field groups for the nested model
nested_field_groups = self._extract_field_groups_from_all_classes(
nested_model
)
# Generate nested fields with increased indentation
for i, group in enumerate(nested_field_groups):
if not group["fields"]:
continue
# Add blank line between groups (except before first group)
if i > 0:
lines.append("")
# Process nested fields
for nested_field_name in group["fields"]:
if nested_field_name not in nested_properties:
continue
nested_field_info = nested_properties[nested_field_name]
nested_field_type = self._extract_type_from_source(
nested_model, nested_field_name
)
nested_is_required = nested_field_name in nested_required
# Recursively generate documentation for nested field
nested_lines = self._generate_field_documentation(
nested_model,
nested_field_name,
nested_field_info,
nested_field_type,
nested_is_required,
indent_level + 1,
new_visited,
)
lines.extend(nested_lines)
else:
# Regular field (no expandable nested models)
field_line = f"{indent}{field_name}: {field_type_str}"
if field_info.get("default") is not None:
field_line += f" = {field_info['default']}"
if is_required:
field_line += " (required)"
lines.append(field_line)
return lines
def generate_qmd(
self,
model_class: type[BaseModel],
title: str | None = None,
expand_nested: bool = True,
) -> str:
"""Auto-generate config reference documentation including inherited fields."""
if title is None:
title = f"{model_class.__name__} Reference"
# Try to get JSON schema, with fallback for serialization issues
try:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e: # pylint: disable=broad-exception-caught
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)
# Fallback: use model fields directly
properties = {}
required = []
for field_name, field_info in model_class.model_fields.items():
# Extract description from json_schema_extra or field info
description = ""
if (
hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra
):
description = field_info.json_schema_extra.get("description", "")
elif hasattr(field_info, "description") and field_info.description:
description = field_info.description
# Get default value
default_val = None
if hasattr(field_info, "default") and field_info.default is not None:
# Handle special Pydantic default markers
if str(field_info.default) != "PydanticUndefined":
default_val = field_info.default
properties[field_name] = {
"type": "unknown",
"description": description,
"default": default_val,
}
if field_info.is_required():
required.append(field_name)
# Extract field groups from all classes in inheritance hierarchy
field_groups = self._extract_field_groups_from_all_classes(model_class)
# Start building QMD content
qmd_lines = [
"---",
f"title: {title}",
"description: A complete list of all configuration options.",
"---",
"",
]
# Generate one big code block with all fields (inline nested expansion)
qmd_lines.append("```yaml")
for i, group in enumerate(field_groups):
if not group["fields"]:
continue
# Add blank line between groups (except before first group)
if i > 0:
qmd_lines.append("")
# Process fields in the order they appear in source
for field_name in group["fields"]:
if field_name not in properties:
continue
field_info = properties[field_name]
field_type = self._extract_type_from_source(model_class, field_name)
is_required = field_name in required
if expand_nested:
# Check if this field has nested models
if field_name in model_class.model_fields:
pydantic_field_info = model_class.model_fields[field_name]
nested_models = self._extract_all_pydantic_models_from_type(
pydantic_field_info.annotation
)
has_nested = bool(nested_models)
else:
has_nested = False
# Add blank line before nested config
if has_nested:
qmd_lines.append("")
# Use the new inline generation method
field_lines = self._generate_field_documentation(
model_class,
field_name,
field_info,
field_type,
is_required,
indent_level=0,
visited_models=set(),
)
qmd_lines.extend(field_lines)
# Add blank line after nested config
if has_nested:
qmd_lines.append("")
else:
# Original simple approach
description = field_info.get("description", "")
default = field_info.get("default")
# Add wrapped comment for description
if description:
wrapped_lines = self._wrap_comment(description)
qmd_lines.extend(wrapped_lines)
line = f"{field_name}: {field_type}"
if default is not None:
line += f" = {default}"
if is_required:
line += " (required)"
qmd_lines.append(line)
qmd_lines.append("```")
# Join all lines and clean up any double newlines
content = "\n".join(qmd_lines)
# Replace multiple consecutive newlines with just two newlines (one blank line)
import re
content = re.sub(r"\n{3,}", "\n\n", content)
# Ensure single newline at the very end
content = content.rstrip("\n") + "\n"
return content
def main():
generator = QuartoGenerator()
print("Generating config reference content...")
qmd_content = generator.generate_qmd(AxolotlInputConfig, "Config Reference", True)
print("Writing to file...")
with open("docs/config-reference.qmd", "w", encoding="utf-8") as f:
f.write(qmd_content)
print("Done!")
if __name__ == "__main__":
main()

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.52.3
transformers==4.52.4
tokenizers>=0.21.1
accelerate==1.7.0
datasets==3.6.0

View File

@@ -118,7 +118,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.0",
"deepspeed==0.17.1",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0"
__version__ = "0.11.0.dev"

View File

@@ -2,7 +2,6 @@
model patcher for chunked top-k kl-div
"""
from types import MethodType
from typing import Optional, Union, Unpack
import torch
@@ -95,4 +94,4 @@ def apply_kernel(model_type):
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = MethodType(kldiv_forward_llama_like, model_cls)
model_cls.forward = kldiv_forward_llama_like

View File

@@ -25,12 +25,20 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.axolotl_level = logging.getLevelNamesMapping()[
os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL)
]
self.other_level = logging.getLevelNamesMapping()[
os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
]
axolotl_log_level = os.getenv(
"AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL
).upper()
other_log_level = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper()
try:
# py311+ only
level_mapping = logging.getLevelNamesMapping()
self.axolotl_level = level_mapping[axolotl_log_level]
self.other_level = level_mapping[other_log_level]
except AttributeError:
# For py310, use getLevelName directly
self.axolotl_level = logging.getLevelName(axolotl_log_level)
self.other_level = logging.getLevelName(other_log_level)
def filter(self, record: LogRecord) -> bool:
# General filter
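The py310 fallback above leans on `logging.getLevelName` accepting a level *name* and returning its numeric value; a minimal sketch of that assumption:

```python
import logging

# On Python 3.10, passing an upper-cased level name returns its numeric value.
assert logging.getLevelName("DEBUG") == 10
assert logging.getLevelName("WARNING") == 30
```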

View File

@@ -596,11 +596,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if (
turn_idx == 0
and turns[0].get("role") == "system"
and (
"mistral" in self.tokenizer.name_or_path.lower()
or "gemma"
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
)
and ("mistral" in self.tokenizer.name_or_path.lower())
):
return -1, -1

View File

@@ -3,6 +3,7 @@
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes
from jinja2.ext import Extension
class JinjaTemplateAnalysis(TypedDict):
@@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict):
iteration_target: Optional[Union[str, list[str]]]
class GenerationTagIgnore(Extension):
"""
Ignores the generation and endgeneration tags in Jinja templates.
"""
tags = {"generation", "endgeneration"}
def parse(self, parser):
parser.stream.skip(1)
return nodes.Const("")
class JinjaTemplateAnalyzer:
"""
Analyzes Jinja templates to extract information about variable usage,
@@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer:
"""
def __init__(self, template: str):
self.env: Environment = Environment(autoescape=True)
self.env: Environment = Environment(
autoescape=True, extensions=[GenerationTagIgnore]
)
self.property_access: Dict[str, Set[str]] = {}
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
self.index_access: Dict[str, Set[Union[int, float]]] = {}
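A minimal sketch of what the extension buys (the template string is an assumption): with `GenerationTagIgnore` registered, jinja2 parses templates that wrap assistant turns in `{% generation %}` markers instead of raising `TemplateSyntaxError`.

```python
from jinja2 import Environment

# GenerationTagIgnore as defined above
env = Environment(autoescape=True, extensions=[GenerationTagIgnore])
# Parsing succeeds; each generation/endgeneration tag is reduced to an empty constant.
env.parse(
    "{% for m in messages %}"
    "{% generation %}{{ m['content'] }}{% endgeneration %}"
    "{% endfor %}"
)
```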

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -1,6 +1,8 @@
"""Pydantic models for datasets-related configuration"""
from pydantic import BaseModel, model_validator
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from axolotl.utils.schemas.enums import ChatTemplate
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
@@ -9,57 +11,178 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
class UserDefinedPrompterType(BaseModel):
"""Structure for user defined prompt types"""
system_prompt: str | None = None
system_format: str | None = None
system_prompt: str | None = Field(
default=None,
json_schema_extra={"description": "Custom user instruction prompt"},
)
system_format: str | None = Field(
default=None,
json_schema_extra={"description": "Use {system} as key to be replaced"},
)
field_system: str | None = None
field_instruction: str | None = None
field_input: str | None = None
field_output: str | None = None
format: str | None = None
no_input_format: str | None = None
field: str | None = None
format: str | None = Field(
default=None,
json_schema_extra={
"description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}"
},
)
no_input_format: str | None = Field(
default=None,
json_schema_extra={"description": "'no_input_format' cannot include {input}"},
)
field: str | None = Field(
default=None,
json_schema_extra={
"description": "For `completion` datsets only, uses the provided field instead of `text` column"
},
)
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: str | None = None
split: str | None = None
type: str | UserDefinedPrompterType | None = None
path: str | None = Field(
default=None,
json_schema_extra={
"description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory"
},
)
split: str | None = Field(
default=None,
json_schema_extra={"description": "name of dataset split to load from"},
)
type: str | UserDefinedPrompterType | None = Field(
default=None,
json_schema_extra={
"description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]"
},
)
input_transform: str | None = None
shards: int | None = None
shards_idx: int | None = None
preprocess_shards: int | None = None
shards: int | None = Field(
default=None,
json_schema_extra={
"description": "split dataset into N pieces (use with shards_idx)"
},
)
shards_idx: int | None = Field(
default=None,
json_schema_extra={"description": "the index of sharded dataset to use"},
)
preprocess_shards: int | None = Field(
default=None,
json_schema_extra={
"description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)"
},
)
conversation: str | None = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: ChatTemplate | str | None = None
chat_template_jinja: str | None = None
data_files: str | list[str] | None = None
chat_template: ChatTemplate | str | None = Field(
default=None,
json_schema_extra={
"description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field."
},
)
chat_template_jinja: str | None = Field(
default=None,
json_schema_extra={
"description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty."
},
)
data_files: str | list[str] | None = Field(
default=None, json_schema_extra={"description": "path to source data files"}
)
input_format: str | None = None
name: str | None = None
ds_type: str | None = None
name: str | None = Field(
default=None,
json_schema_extra={"description": "name of dataset configuration to load"},
)
ds_type: str | None = Field(
default=None,
json_schema_extra={"description": "defines the datatype when path is a file"},
)
field: str | None = None
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
field_tools: str | None = None
field_messages: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the messages (default: "messages")'
},
)
field_tools: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
},
)
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings
message_field_content: str | None = None
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
split_thinking: bool | None = None
message_property_mappings: dict[str, str] | None = Field(
default=None,
json_schema_extra={
"description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template."
},
)
message_field_training: str | None = Field(
default=None,
json_schema_extra={
"description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`."
},
)
message_field_training_detail: str | None = Field(
default=None,
json_schema_extra={
"description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)."
},
)
split_thinking: bool | None = Field(
default=None,
json_schema_extra={
"description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags"
},
)
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
train_on_eos: str | None = None
roles: dict[str, list[str]] | None = None
drop_system_message: bool | None = None
trust_remote_code: bool | None = False
revision: str | None = None
roles_to_train: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
},
)
train_on_eos: Literal["all", "turn", "last"] | None = Field(
default=None,
json_schema_extra={
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
},
)
roles: dict[str, list[str]] | None = Field(
default=None,
json_schema_extra={
"description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]'
},
)
drop_system_message: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content."
},
)
trust_remote_code: bool | None = Field(
default=False,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
revision: str | None = Field(
default=None,
json_schema_extra={
"description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets."
},
)
@model_validator(mode="before")
@classmethod
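A minimal dataset-config sketch of the `message_property_mappings` mapping described above (the dataset path and the `field_messages` value are placeholders):

```yaml
datasets:
  - path: your/sharegpt-style-dataset  # placeholder
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
```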

View File

@@ -60,10 +60,30 @@ class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""
overrides_of_model_config: dict[str, Any] | None = Field(
default=None, alias="model_config"
default=None,
alias="model_config",
json_schema_extra={
"description": "optional overrides to the base model configuration"
},
)
overrides_of_model_kwargs: dict[str, Any] | None = Field(
default=None, alias="model_kwargs"
default=None,
alias="model_kwargs",
json_schema_extra={
"description": "optional overrides the base model loading from_pretrained"
},
)
type_of_model: str | None = Field(
default=None,
alias="model_type",
json_schema_extra={
"description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too"
},
)
revision_of_model: str | None = Field(
default=None,
alias="model_revision",
json_schema_extra={
"description": "You can specify to choose a specific model revision from huggingface hub"
},
)
type_of_model: str | None = Field(default=None, alias="model_type")
revision_of_model: str | None = Field(default=None, alias="model_revision")

View File

@@ -1,5 +1,7 @@
"""Enums for Axolotl input config"""
# pylint: disable=invalid-name
from enum import Enum
import torch
@@ -8,81 +10,81 @@ import torch
class TorchIntDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name
uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name
uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name
uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name
uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name
uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name
uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name
int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name
int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name
uint1 = getattr(torch, "uint1", None)
uint2 = getattr(torch, "uint2", None)
uint3 = getattr(torch, "uint3", None)
uint4 = getattr(torch, "uint4", None)
uint5 = getattr(torch, "uint5", None)
uint6 = getattr(torch, "uint6", None)
uint7 = getattr(torch, "uint7", None)
int4 = getattr(torch, "int4", None)
int8 = getattr(torch, "int8", None)
class RLType(str, Enum):
"""RL trainer type configuration subset"""
DPO = "dpo" # pylint: disable=invalid-name
GRPO = "grpo" # pylint: disable=invalid-name
IPO = "ipo" # pylint: disable=invalid-name
ORPO = "orpo" # pylint: disable=invalid-name
KTO = "kto" # pylint: disable=invalid-name
SIMPO = "simpo" # pylint: disable=invalid-name
DPO = "dpo"
GRPO = "grpo"
IPO = "ipo"
ORPO = "orpo"
KTO = "kto"
SIMPO = "simpo"
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
qwen3 = "qwen3" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
pixtral = "pixtral" # pylint: disable=invalid-name
llava = "llava" # pylint: disable=invalid-name
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
gemma3 = "gemma3" # pylint: disable=invalid-name
command_a = "command_a" # pylint: disable=invalid-name
command_a_tool_use = "command_a_tool_use" # pylint: disable=invalid-name
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
aya = "aya" # pylint: disable=invalid-name
alpaca = "alpaca"
chatml = "chatml"
mistral_v1 = "mistral_v1"
mistral_v2v3 = "mistral_v2v3"
mistral_v3_tekken = "mistral_v3_tekken"
mistral_v7_tekken = "mistral_v7_tekken"
gemma = "gemma"
cohere = "cohere"
llama3 = "llama3"
llama3_2_vision = "llama3_2_vision"
llama4 = "llama4"
phi_3 = "phi_3"
phi_35 = "phi_35"
deepseek_v2 = "deepseek_v2"
deepseek_v3 = "deepseek_v3"
jamba = "jamba"
jinja = "jinja"
qwen_25 = "qwen_25"
qwen3 = "qwen3"
tokenizer_default = "tokenizer_default"
exaone = "exaone"
metharme = "metharme"
pixtral = "pixtral"
llava = "llava"
qwen2_vl = "qwen2_vl"
gemma3 = "gemma3"
command_a = "command_a"
command_a_tool_use = "command_a_tool_use"
command_a_rag = "command_a_rag"
aya = "aya"
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
optimi_adamw = "optimi_adamw"
ao_adamw_4bit = "ao_adamw_4bit"
ao_adamw_8bit = "ao_adamw_8bit"
ao_adamw_fp8 = "ao_adamw_fp8"
adopt_adamw = "adopt_adamw"
came_pytorch = "came_pytorch"
muon = "muon"
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
# BATCH_ZIGZAG = "batch_zigzag"
# BATCH_STRIPE = "batch_stripe"

View File

@@ -13,10 +13,21 @@ class MLFlowConfig(BaseModel):
"""MLFlow configuration subset"""
use_mlflow: bool | None = None
mlflow_tracking_uri: str | None = None
mlflow_experiment_name: str | None = None
mlflow_run_name: str | None = None
hf_mlflow_log_artifacts: bool | None = None
mlflow_tracking_uri: str | None = Field(
default=None, json_schema_extra={"description": "URI to mlflow"}
)
mlflow_experiment_name: str | None = Field(
default=None, json_schema_extra={"description": "Your experiment name"}
)
mlflow_run_name: str | None = Field(
default=None, json_schema_extra={"description": "Your run name"}
)
hf_mlflow_log_artifacts: bool | None = Field(
default=None,
json_schema_extra={
"description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry"
},
)
class LISAConfig(BaseModel):
@@ -40,13 +51,33 @@ class WandbConfig(BaseModel):
"""Wandb configuration subset"""
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None
wandb_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your wandb run"},
)
wandb_run_id: str | None = Field(
default=None, json_schema_extra={"description": "Set the ID of your wandb run"}
)
wandb_mode: str | None = Field(
default=None,
json_schema_extra={
"description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb'
},
)
wandb_project: str | None = Field(
default=None, json_schema_extra={"description": "Your wandb project name"}
)
wandb_entity: str | None = Field(
default=None,
json_schema_extra={"description": "A wandb Team name if using a Team"},
)
wandb_watch: str | None = None
wandb_log_model: str | None = None
wandb_log_model: str | None = Field(
default=None,
json_schema_extra={
"description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training'
},
)
@model_validator(mode="before")
@classmethod
@@ -64,14 +95,52 @@ class WandbConfig(BaseModel):
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: bool | None = None
comet_api_key: str | None = None
comet_workspace: str | None = None
comet_project_name: str | None = None
comet_experiment_key: str | None = None
comet_mode: str | None = None
comet_online: bool | None = None
comet_experiment_config: dict[str, Any] | None = None
use_comet: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable or disable Comet integration."},
)
comet_api_key: str | None = Field(
default=None,
json_schema_extra={
"description": "API key for Comet. Recommended to set via `comet login`."
},
)
comet_workspace: str | None = Field(
default=None,
json_schema_extra={
"description": "Workspace name in Comet. Defaults to the user's default workspace."
},
)
comet_project_name: str | None = Field(
default=None,
json_schema_extra={
"description": "Project name in Comet. Defaults to Uncategorized."
},
)
comet_experiment_key: str | None = Field(
default=None,
json_schema_extra={
"description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key."
},
)
comet_mode: str | None = Field(
default=None,
json_schema_extra={
"description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.'
},
)
comet_online: bool | None = Field(
default=None,
json_schema_extra={
"description": "Set to True to log data to Comet server, or False for offline storage. Default is True."
},
)
comet_experiment_config: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Dictionary for additional configuration settings, see the doc for more details."
},
)
class GradioConfig(BaseModel):
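The practical effect of moving these strings into json_schema_extra is that they become part of each model's JSON schema, where the config doc generator can read them back. A minimal sketch, assuming pydantic v2 semantics (the real docs/scripts/generate_config_docs.py may do more):

# Minimal sketch, assuming pydantic v2: pull field descriptions back out of
# the JSON schema that the json_schema_extra entries above populate.
from pydantic import BaseModel

def field_descriptions(model_cls: type[BaseModel]) -> dict[str, str]:
    """Map field name -> description for every documented field of a config model."""
    props = model_cls.model_json_schema().get("properties", {})
    return {name: spec["description"] for name, spec in props.items() if "description" in spec}

# e.g. field_descriptions(WandbConfig)["wandb_project"] -> "Your wandb project name"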

View File

@@ -12,20 +12,55 @@ class ModelInputConfig(BaseModel):
model_config = {"protected_namespaces": ()}
base_model: str
base_model_config: str | None = None
base_model: str = Field(
json_schema_extra={
"description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk"
}
)
base_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
},
)
cls_model_config: str | None = None
tokenizer_config: str | None = None
tokenizer_use_fast: bool | None = None
tokenizer_legacy: bool | None = None
tokenizer_use_mistral_common: bool | None = None
tokenizer_config: str | None = Field(
default=None,
json_schema_extra={
"description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model"
},
)
tokenizer_use_fast: bool | None = Field(
default=None,
json_schema_extra={
"description": "use_fast option for tokenizer loading from_pretrained, default to True"
},
)
tokenizer_legacy: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use the legacy tokenizer setting, defaults to True"
},
)
tokenizer_use_mistral_common: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer."
},
)
tokenizer_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
default=None,
json_schema_extra={
"description": "Corresponding tokenizer for the model AutoTokenizer is a good choice"
},
)
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: bool | None = None
trust_remote_code: bool | None = Field(
default=None,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
@field_validator("trust_remote_code")
@classmethod
@@ -40,10 +75,23 @@ class ModelInputConfig(BaseModel):
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""
output_dir: str = Field(default="./model-out")
hub_model_id: str | None = None
hub_strategy: str | None = None
save_safetensors: bool | None = True
output_dir: str = Field(
default="./model-out",
json_schema_extra={"description": "Where to save the full-finetuned model to"},
)
hub_model_id: str | None = Field(
default=None, json_schema_extra={"description": "push checkpoints to hub"}
)
hub_strategy: str | None = Field(
default=None,
json_schema_extra={"description": "how to push checkpoints to hub"},
)
save_safetensors: bool | None = Field(
default=True,
json_schema_extra={
"description": "Save model as safetensors (require safetensors package). Default True"
},
)
class SpecialTokensConfig(BaseModel):

View File

@@ -9,7 +9,7 @@ class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
default=4, json_schema_extra={"description": "typically 4 bits"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
@@ -17,31 +17,78 @@ class LoftQConfig(BaseModel):
class PeftConfig(BaseModel):
"""peftq configuration subset"""
loftq_config: LoftQConfig | None = None
loftq_config: LoftQConfig | None = Field(
default=None,
json_schema_extra={
"description": "Configuration options for loftq initialization for LoRA"
},
)
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
load_in_8bit: bool | None = Field(
default=False,
json_schema_extra={
"description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer"
},
)
load_in_4bit: bool | None = Field(
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)
adapter: str | None = None
lora_model_dir: str | None = None
adapter: str | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
},
)
lora_model_dir: str | None = Field(
default=None,
json_schema_extra={
"description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`."
},
)
lora_r: int | None = None
lora_alpha: int | None = None
lora_fan_in_fan_out: bool | None = None
lora_target_modules: str | list[str] | None = None
lora_target_linear: bool | None = None
lora_modules_to_save: list[str] | None = None
lora_target_linear: bool | None = Field(
default=None,
json_schema_extra={"description": "If true, will target all linear modules"},
)
lora_modules_to_save: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities."
},
)
lora_dropout: float | None = 0.0
peft_layers_to_transform: list[int] | None = None
peft_layers_to_transform: list[int] | None = Field(
default=None,
json_schema_extra={
"description": "The layer indices to transform, otherwise, apply to all layers"
},
)
peft_layers_pattern: list[str] | None = None
peft: PeftConfig | None = None
peft_use_dora: bool | None = None
peft_use_rslora: bool | None = None
peft_layer_replication: list[tuple[int, int]] | None = None
peft_init_lora_weights: bool | str | None = None
peft_use_dora: bool | None = Field(
default=None, json_schema_extra={"description": "Whether to use DoRA."}
)
peft_use_rslora: bool | None = Field(
default=None, json_schema_extra={"description": "Whether to use RSLoRA."}
)
peft_layer_replication: list[tuple[int, int]] | None = Field(
default=None,
json_schema_extra={"description": "List of layer indices to replicate."},
)
peft_init_lora_weights: bool | str | None = Field(
default=None,
json_schema_extra={
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,
@@ -49,9 +96,24 @@ class LoraConfig(BaseModel):
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
lora_on_cpu: bool | None = Field(
default=None,
json_schema_extra={
"description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge"
},
)
gptq: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether you are training a 4-bit GPTQ quantized model"
},
)
bnb_config_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "optional overrides to the bnb 4bit quantization configuration"
},
)
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -62,7 +124,7 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding: float | None = Field(
default=1e-6,
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
"description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
},
)
@@ -125,8 +187,29 @@ class LoraConfig(BaseModel):
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
relora_steps: int | None = None
relora_warmup_steps: int | None = None
relora_anneal_steps: int | None = None
relora_prune_ratio: float | None = None
relora_cpu_offload: bool | None = None
relora_steps: int | None = Field(
default=None,
json_schema_extra={"description": "Number of steps per ReLoRA restart"},
)
relora_warmup_steps: int | None = Field(
default=None,
json_schema_extra={"description": "Number of per-restart warmup steps"},
)
relora_anneal_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of anneal steps for each relora cycle"
},
)
relora_prune_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "threshold for optimizer magnitude when pruning"
},
)
relora_cpu_offload: bool | None = Field(
default=None,
json_schema_extra={
"description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings"
},
)
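For orientation, these fields map one-to-one onto the user-facing YAML keys; a hypothetical direct instantiation of the model (illustrative values, and assuming no cross-field validator rejects this particular combination) looks like:

# Hypothetical usage sketch: the YAML keys users set are validated through
# this pydantic model. Values are illustrative, not recommendations.
cfg = LoraConfig(
    adapter="qlora",          # "lora", "qlora", or blank to train all parameters
    load_in_4bit=True,        # bitsandbytes 4-bit quantization
    lora_r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_linear=True,  # target all linear modules
)
print(cfg.adapter, cfg.lora_r)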

View File

@@ -15,17 +15,22 @@ class QATConfig(BaseModel):
"""
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
default=TorchIntDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
)
group_size: int | None = Field(default=32, description="Group size")
group_size: int | None = Field(
default=32,
description="The number of elements in each group for per-group fake quantization",
)
fake_quant_after_n_steps: int | None = Field(
default=None, description="Fake quant after n steps"
default=None, description="The number of steps to apply fake quantization after"
)
@field_validator("activation_dtype", "weight_dtype", mode="before")
@@ -44,15 +49,20 @@ class PTQConfig(BaseModel):
"""
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
default=TorchIntDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
)
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=None, description="Quantize embedding"
default=None, description="Whether to quantize the embedding layer."
)
group_size: int | None = Field(
default=32,
description="The number of elements in each group for per-group fake quantization",
)
group_size: int | None = Field(default=32, description="Group size")
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod

View File

@@ -23,10 +23,17 @@ class LrGroup(BaseModel):
class HyperparametersConfig(BaseModel):
"""Training hyperparams configuration subset"""
gradient_accumulation_steps: int | None = Field(default=1)
gradient_accumulation_steps: int | None = Field(
default=1,
json_schema_extra={
"description": "If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps."
},
)
micro_batch_size: int | None = Field(
default=1,
json_schema_extra={"description": "per gpu micro batch size for training"},
json_schema_extra={
"description": "The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps"
},
)
batch_size: int | None = Field(
default=None,
@@ -41,45 +48,99 @@ class HyperparametersConfig(BaseModel):
},
)
auto_find_batch_size: bool | None = None
auto_find_batch_size: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to find batch size that fits in memory. Passed to underlying transformers Trainer"
},
)
train_on_inputs: bool | None = False
group_by_length: bool | None = None
train_on_inputs: bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to mask out or include the human's prompt from the training labels"
},
)
group_by_length: bool | None = Field(
default=None,
json_schema_extra={
"description": "Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled."
},
)
learning_rate: str | float
embedding_lr: float | None = None
embedding_lr_scale: float | None = None
weight_decay: float | None = 0.0
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
OptimizerNames.ADAMW_TORCH_FUSED
weight_decay: float | None = Field(
default=0.0, json_schema_extra={"description": "Specify weight decay"}
)
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field(
default=OptimizerNames.ADAMW_TORCH_FUSED,
json_schema_extra={"description": "Specify optimizer"},
)
optim_args: (str | dict[str, Any]) | None = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
json_schema_extra={
"description": "Dictionary of arguments to pass to the optimizer"
},
)
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
default=None,
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
"description": "The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm"
},
)
torchdistx_path: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to torch distx for optim 'adamw_anyprecision'"
},
)
torchdistx_path: str | None = None
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
SchedulerType.COSINE
)
lr_scheduler_kwargs: dict[str, Any] | None = None
lr_scheduler_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Specify a scheduler and kwargs to use with the optimizer"
},
)
lr_quadratic_warmup: bool | None = None
cosine_min_lr_ratio: float | None = None
cosine_constant_lr_ratio: float | None = None
lr_div_factor: float | None = None
cosine_min_lr_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr"
},
)
cosine_constant_lr_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step"
},
)
lr_div_factor: float | None = Field(
default=None, json_schema_extra={"description": "Learning rate div factor"}
)
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
adam_epsilon: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_epsilon2: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
adam_beta1: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_beta2: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")

View File

@@ -10,12 +10,14 @@ class TRLConfig(BaseModel):
beta: float | None = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
json_schema_extra={
"description": "Beta parameter for the RL training. Same as `rl_beta`. Use"
},
)
max_completion_length: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
"description": "Maximum length of the completion for RL training."
},
)
@@ -23,81 +25,69 @@ class TRLConfig(BaseModel):
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
json_schema_extra={"description": "Whether to use VLLM for RL training."},
)
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to"},
json_schema_extra={"description": "Host of the vLLM server to connect to."},
)
vllm_server_port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to connect to"},
json_schema_extra={"description": "Port of the vLLM server to connect to."},
)
vllm_server_timeout: int | None = Field(
default=None,
json_schema_extra={
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
"description": "Total timeout (in seconds) to wait for the vLLM server to respond."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
},
json_schema_extra={"description": "Regex for vLLM guided decoding."},
)
reward_funcs: list[str] | None = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
json_schema_extra={
"description": "List of reward functions to load. Paths must be importable from current dir."
},
)
reward_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
"description": "List of reward weights for the reward functions."
},
)
num_generations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
json_schema_extra={"description": "Number of generations to sample."},
)
log_completions: bool | None = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
json_schema_extra={"description": "Whether to log completions."},
)
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
"description": "Number of completions to print when log_completions is True."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
json_schema_extra={"description": "Whether to sync the reference model."},
)
ref_model_mixup_alpha: float | None = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
json_schema_extra={"description": "Mixup alpha for the reference model."},
)
ref_model_sync_steps: int | None = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
json_schema_extra={"description": "Sync steps for the reference model."},
)
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
"description": "Whether to scale rewards by their standard deviation."
},
)
@@ -124,13 +114,13 @@ class TRLConfig(BaseModel):
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
"description": "Penalty for tokens that appear in prompt and generated text."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
"description": "Number of iterations per batch (μ) for GRPO."
},
)
epsilon: float | None = Field(
@@ -152,12 +142,12 @@ class TRLConfig(BaseModel):
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
"description": "Loss formulation to use. Supported values: grpo, bnpo, dr_grpo."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
"description": "Whether to exclude truncated completions from loss calculation."
},
)

File diff suppressed because it is too large

View File

@@ -91,7 +91,10 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
temp_dir + "/runs",
"train/train_loss",
threshold,
"Train Loss (%s) is too high",
)
@pytest.mark.parametrize(
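The test updates in this and the following files all make the same change: the assertion message gains a %s placeholder so the measured loss value is interpolated into the failure output. A rough sketch of how such a helper could use the placeholder (an assumption; the real check_tensorboard implementation is not shown in this diff):

# Rough sketch (assumption, hypothetical helper name): interpolate the
# measured value into the "%s" placeholder when the threshold check fails.
def check_metric(value: float, threshold: float, msg: str = "Train Loss (%s) is too high") -> None:
    assert value < threshold, msg % value

check_metric(1.9, 2.0)    # passes
# check_metric(2.6, 2.0)  # would fail with: Train Loss (2.6) is too high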

View File

@@ -85,5 +85,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -91,5 +91,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
)

View File

@@ -89,7 +89,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -154,7 +154,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
@@ -232,7 +232,7 @@ class TestMultiGPULlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)
def test_dpo_qlora_ddp(self, temp_dir):
@@ -310,7 +310,7 @@ class TestMultiGPULlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)
@pytest.mark.parametrize(
@@ -380,7 +380,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -452,7 +452,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_6_0
@@ -533,7 +533,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -613,7 +613,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -697,7 +697,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -771,7 +771,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -845,7 +845,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.skip(
@@ -912,5 +912,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
)

View File

@@ -75,7 +75,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_lt_2_6_0
@@ -133,5 +133,5 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -78,5 +78,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -73,7 +73,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -123,7 +123,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -178,5 +178,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -69,5 +69,5 @@ class TestPretrainLlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)

View File

@@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -129,5 +129,5 @@ class TestQATLlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)

View File

@@ -66,6 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
return tokenizer
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
def fixture_phi4_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
return tokenizer
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
def fixture_gemma2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")

View File

@@ -33,15 +33,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja",
"[/INST]",
),
# TODO: temporarily skip gemma due to gemma3 template
# Re-enable on new chat_template implementation for perf
# (
# "gemma2_tokenizer",
# "jinja",
# "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
(
"gemma2_tokenizer",
"jinja",
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
("phi35_tokenizer", "phi_35", None, "<|end|>"),
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
]
@@ -95,11 +94,7 @@ class TestChatTemplateConfigurations:
if (
turn_idx == 0
and turn.get("from") in ["system", "context"]
and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
and ("mistral" in tokenizer.name_or_path.lower())
):
assert (
start_idx == -1 and end_idx == -1
@@ -935,36 +930,14 @@ class TestChatTemplateConfigurations:
"messages",
)
if chat_template == "llama3":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "chatml":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "phi_35":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
# Special case for Mistral with additional tool variables
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
# Most chat templates use the standard role and content variables
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
):
expected_variables = {"role", "content"}
else:
LOG.warning(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
@@ -973,6 +946,12 @@ class TestChatTemplateConfigurations:
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
assert variables == expected_variables, (
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
def test_eot_tokens_conflict_with_eos_token(
self,
tokenizer,
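The refactored assertion above compares the set of message fields each chat template references. One plausible way to collect those names (an assumption about the helper; it is not shown in this diff) is to walk the parsed Jinja AST for attribute and string-keyed item accesses:

# Plausible sketch (assumption): collect field names such as "role",
# "content", or "tool_calls" that a Jinja chat template references.
from jinja2 import Environment
from jinja2.nodes import Const, Getattr, Getitem

def referenced_message_fields(chat_template: str) -> set[str]:
    ast = Environment().parse(chat_template)
    fields = {node.attr for node in ast.find_all(Getattr)}
    fields |= {
        node.arg.value
        for node in ast.find_all(Getitem)
        if isinstance(node.arg, Const) and isinstance(node.arg.value, str)
    }
    return fields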

View File

@@ -11,8 +11,6 @@ from axolotl.prompt_strategies.chat_template import (
)
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="messages_w_reasoning")
def messages_w_reasoning_fixture():
@@ -59,7 +57,6 @@ def messages_w_reasoning_fixture():
@pytest.fixture(name="qwen3_tokenizer")
@enable_hf_offline
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument