Compare commits

..

7 Commits

Author SHA1 Message Date
Dan Saunders  cbcc795bb3  commenting out unused  2025-06-16 01:53:13 +00:00
Dan Saunders  e34b6f4dfe  temp: trying another approach  2025-06-15 21:32:10 +00:00
Dan Saunders  f8f87321bd  progress  2025-06-14 17:40:21 +00:00
Dan Saunders  7a88de4fa8  finish basic impl; change naming from SP -> CP to match torch  2025-06-13 09:51:06 -04:00
Dan Saunders  aced809989  progress (messy :O)  2025-06-12 18:54:41 +00:00
Dan Saunders  ae73123eae  progress; move validation to pydantic model config  2025-06-07 06:58:59 +00:00
Dan Saunders  10d1e44943  SDPA context parallel  2025-06-06 00:34:12 +00:00
146 changed files with 2865 additions and 6834 deletions

View File

@@ -16,7 +16,6 @@ on:
jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
strategy:
@@ -48,14 +47,14 @@ jobs:
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
@@ -107,7 +106,6 @@ jobs:
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
@@ -124,7 +122,7 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:

View File

@@ -29,12 +29,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -97,12 +97,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -8,7 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/core/trainers/mixins/context_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
@@ -43,7 +43,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -52,7 +52,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -125,7 +125,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -188,7 +188,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist]
strategy:
@@ -238,7 +238,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
timeout-minutes: 90
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
@@ -262,13 +262,13 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:

275
README.md
View File

@@ -1,177 +1,152 @@
<div align="center">
<a href="https://github.com/axolotl-ai-cloud/axolotl">
<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/docs/logo.png" alt="Axolotl Logo" width="250" style="margin-bottom: 20px;"/>
</a>
<h1><span style="color: #4CAF50;">Axolotl: Fine-tune LLMs with Unprecedented Ease & Power!</span> 🚀</h1>
<p style="font-size: 1.1em; color: #555;">Your ultimate toolkit for efficient, scalable, and versatile large language model fine-tuning.</p>
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p>
<a href="https://discord.gg/HhrNrHJPRb" target="_blank">
<img src="https://img.shields.io/discord/1070542385153273887?label=Discord&logo=discord&logoColor=white&color=7289DA" alt="Discord Community" style="margin: 5px;">
</a>
<a href="https://docs.axolotl.ai/" target="_blank">
<img src="https://img.shields.io/badge/Documentation-blue?style=flat&logo=readthedocs&logoColor=white" alt="Official Documentation" style="margin: 5px;">
</a>
<a href="https://pypi.org/project/axolotl/" target="_blank">
<img src="https://img.shields.io/pypi/v/axolotl?label=PyPI&logo=pypi&logoColor=white&color=blue" alt="PyPI Package" style="margin: 5px;">
</a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases" target="_blank">
<img src="https://img.shields.io/github/downloads/axolotl-ai-cloud/axolotl/total?label=Downloads&color=green" alt="GitHub Downloads" style="margin: 5px;">
</a>
</p>
<br>
</div>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
---
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
<div style="background-color: #f0f8ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #d0e8ff;">
<h2 style="color: #0056b3; text-align: center; margin-top: 0;">🎉 Latest Innovations & Updates!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/06:</span> Magistral with mistral-common tokenizer support!</strong> Dive into <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral" style="color: #007bff; text-decoration: none;">examples</a> to train your own Magistral models.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/05:</span> Quantization Aware Training (QAT) support!</strong> Explore the <a href="https://docs.axolotl.ai/docs/qat.html" style="color: #007bff; text-decoration: none;">docs</a> to learn more.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/04:</span> Llama 4 support!</strong> See <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4" style="color: #007bff; text-decoration: none;">examples</a> to train Llama 4 with Axolotl's linearized version!</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> Sequence Parallelism (SP) support!</strong> Scale your context length. Read the <a href="https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://docs.axolotl.ai/docs/sequence_parallelism.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> (Beta) Fine-tuning Multimodal models!</strong> Check out the <a href="https://docs.axolotl.ai/docs/multimodal.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> LoRA optimizations!</strong> Reduce memory and improve speed. Jump into the <a href="https://docs.axolotl.ai/docs/lora_optims.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> GRPO support!</strong> Dive into our <a href="https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://github.com/axolotl-ai-cloud/grpo_code" style="color: #007bff; text-decoration: none;">GRPO example</a>.</li>
<li style="margin-bottom: 0px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/01:</span> Reward Modelling / Process Reward Modelling fine-tuning!</strong> See <a href="https://docs.axolotl.ai/docs/reward_modelling.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
</ul>
</div>
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
<h2 style="color: #FF5733;"><span style="margin-right: 10px;">✨</span> Axolotl Overview: Your LLM Fine-tuning Powerhouse!</h2>
Features:
<div style="background-color: #fffacd; padding: 20px; border-radius: 10px; margin-bottom: 30px; border: 1px solid #ffd700;">
<p style="font-size: 1.1em; color: #333; text-align: center;">Axolotl is a powerful, flexible, and user-friendly tool designed to supercharge your post-training workflows for a wide range of cutting-edge AI models.</p>
</div>
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
<div style="display: flex; flex-wrap: wrap; justify-content: space-around; gap: 20px; margin-bottom: 40px;">
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🤖</span> Broad Model Compatibility</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Train a vast array of models including LLaMA, Mistral, Mixtral, Pythia, and many more.</li>
<li>Fully compatible with HuggingFace transformers causal language models, ensuring wide adoption.</li>
</ul>
</div>
## 🚀 Quick Start
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🔧</span> Diverse Training Methodologies</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Full fine-tuning, LoRA, QLoRA, GPTQ, QAT.</li>
<li>Preference Tuning: DPO, IPO, KTO, ORPO.</li>
<li>Advanced RL: GRPO.</li>
<li>Multimodal and Reward Modelling (RM) / Process Reward Modelling (PRM).</li>
</ul>
</div>
**Requirements**:
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚙️</span> Streamlined Configuration</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Utilize a single, intuitive YAML file across dataset preprocess, training, evaluation, quantization, and inference.</li>
</ul>
</div>
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.5.1
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚡</span> Cutting-Edge Performance Optimizations</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #007bff;">Multipacking</a>, <a href="https://github.com/Dao-AILab/flash-attention" style="color: #007bff;">Flash Attention</a>, <a href="https://github.com/facebookresearch/xformers" style="color: #007bff;">Xformers</a>, <a href="https://pytorch.org/blog/flexattention/" style="color: #007bff;">Flex Attention</a>, <a href="https://github.com/linkedin/Liger-Kernel" style="color: #007bff;">Liger Kernel</a>, <a href="https://github.com/apple/ml-cross-entropy/tree/main" style="color: #007bff;">Cut Cross Entropy</a>.</li>
<li>Sequence Parallelism (SP), LoRA optimizations.</li>
<li>Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!</li>
</ul>
</div>
### Installation
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">📂</span> Flexible Data Handling</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Load datasets from local paths, HuggingFace Hub, and major cloud providers (S3, Azure, GCP, OCI).</li>
</ul>
</div>
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">☁️</span> Cloud-Ready & Deployable</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Official <a href="https://hub.docker.com/u/axolotlai" style="color: #007bff;">Docker images</a> and <a href="https://pypi.org/project/axolotl/" style="color: #007bff;">PyPI packages</a> for seamless integration on cloud platforms and local hardware.</li>
</ul>
</div>
</div>
<h2 style="color: #007bff;"><span style="margin-right: 10px;">🚀</span> Quick Start: Get Fine-tuning in Minutes!</h2>
<div style="background-color: #e6f7ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #cceeff;">
<h3 style="color: #0056b3; margin-top: 0;">Requirements:</h3>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ NVIDIA GPU</span> (Ampere or newer for `bf16` and Flash Attention) or AMD GPU</li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ Python 3.11</span></li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ PyTorch ≥2.5.1</span></li>
</ul>
<h3 style="color: #0056b3;">Installation:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;">pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL</code></pre>
<p style="font-size: 0.9em; color: #555;">Other installation approaches are described <a href="https://docs.axolotl.ai/docs/installation.html" style="color: #007bff; text-decoration: none;">here</a>.</p>
axolotl fetch deepspeed_configs # OPTIONAL
```
<h3 style="color: #0056b3;">Your First Fine-tune:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;"># Fetch axolotl examples
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune
```bash
# Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml</code></pre>
<p style="text-align: center; font-size: 1.1em; font-weight: bold; margin-top: 20px;">
That's it! Check out our <a href="https://docs.axolotl.ai/docs/getting-started.html" style="background-color: #28a745; color: white; padding: 12px 25px; border-radius: 8px; text-decoration: none; display: inline-block; transition: background-color 0.3s ease;"> Getting Started Guide ➜</a> for a more detailed walkthrough.
</p>
</div>
axolotl train examples/llama-3/lora-1b.yml
```
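The command above consumes a plain YAML config. As a rough sketch of what such a file contains (model, dataset, and hyperparameter values here are illustrative, not the contents of `examples/llama-3/lora-1b.yml`):
```yaml
# Minimal LoRA fine-tuning sketch with illustrative values
base_model: meta-llama/Llama-3.2-1B          # any HF causal LM
datasets:
  - path: fozziethebeat/alpaca_messages_2k_test  # example dataset
    type: chat_template
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 2048
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.0002
output_dir: ./outputs/lora-out
```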
<h2 style="color: #8A2BE2;"><span style="margin-right: 10px;">📚</span> Comprehensive Documentation: Unlock Axolotl's Full Potential</h2>
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
<div style="background-color: #f7f0ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #e0caff;">
<p style="text-align: center; font-size: 1.1em; color: #333;">Dive deep into Axolotl's capabilities with our extensive documentation:</p>
<ul style="list-style-type: none; padding-left: 0; text-align: center;">
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/installation.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Installation Options</a> - Detailed setup instructions for different environments</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/config.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Configuration Guide</a> - Full configuration options and examples</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset_loading.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Loading</a> - Loading datasets from various sources</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset-formats/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Guide</a> - Supported formats and how to use them</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-gpu.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-GPU Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-node.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-Node Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multipacking</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/api/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> API Reference</a> - Auto-generated code documentation</li>
<li style="margin-bottom: 0px;"><a href="https://docs.axolotl.ai/docs/faq.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;">❓ FAQ</a> - Frequently asked questions</li>
</ul>
</div>
## ✨ Key Features
<h2 style="color: #FF8C00;"><span style="margin-right: 10px;">🤝</span> Need Help? We're Here for You!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #7289DA;"></span> Join our vibrant <a href="https://discord.gg/HhrNrHJPRb" style="color: #7289DA; text-decoration: none; font-weight: bold;">Discord community</a> for real-time support and discussions.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Explore our <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Examples</a> directory for practical use cases.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Read our <a href="https://docs.axolotl.ai/docs/debugging.html" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Debugging Guide</a> for troubleshooting tips.</li>
<li style="margin-bottom: 0px;"><span style="font-size: 1.2em; color: #007bff;">✉</span> Need dedicated support? Please contact <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none; font-weight: bold;">wing@axolotl.ai</a> for professional assistance options.</li>
</ul>
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
<h2 style="color: #FF1493;"><span style="margin-right: 10px;">🌟</span> Contribute to Axolotl!</h2>
<p style="font-size: 1.1em;">
Contributions are always welcome and highly appreciated! Axolotl thrives on community support. Please see our <a href="https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md" style="color: #FF1493; text-decoration: none; font-weight: bold;">Contributing Guide</a> for details on how you can help make Axolotl even better.
</p>
## 📚 Documentation
<div align="center" style="margin-top: 40px; padding: 25px; background-color: #f8f8f8; border-radius: 12px; border: 1px solid #eee;">
<h2 style="color: #FF69B4; margin-bottom: 20px;">❤️ Our Esteemed Sponsors</h2>
<p style="font-size: 1.1em; color: #555;">A huge thank you to our visionary sponsors who provide the essential resources to keep Axolotl at the forefront of LLM fine-tuning:</p>
<a href="https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl" target="_blank" style="display: inline-block; margin: 20px;">
<img src="https://assets-global.website-files.com/6247c4c1d68352614b7e87ae/63b27b3b44b82d02c8163f4f_logo-dark-square.png" alt="Modal Logo" width="180" style="vertical-align: middle; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.15);"/>
</a>
<p style="font-size: 0.9em; color: #777; margin-top: 20px;">
<strong>Modal:</strong> Revolutionizing cloud computing for Gen AI. Run jobs, deploy models, and fine-tune LLMs at scale with ease.
</p>
<p style="font-size: 1em; color: #555; margin-top: 30px;">
Interested in powering the future of Axolotl? <span style="font-weight: bold; color: #FF69B4;">Become a sponsor!</span> Contact us at <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none;">wing@axolotl.ai</a>
</p>
</div>
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [Multipacking](https://docs.axolotl.ai/docs/multipack.html)
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
<h2 style="color: #6A5ACD;"><span style="margin-right: 10px;">📜</span> License</h2>
<p style="font-size: 1.1em;">
This project is proudly licensed under the <span style="font-weight: bold; color: #6A5ACD;">Apache 2.0 License</span>. See the <a href="LICENSE" style="color: #007bff; text-decoration: none;">LICENSE</a> file for full details.
</p>
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://docs.axolotl.ai/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
## 🌟 Contributing
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## Supported Models
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

View File

@@ -75,7 +75,7 @@ quartodoc:
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:
- utils.ctx_managers.sequence_parallel
- utils.ctx_managers.context_parallel
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
@@ -274,7 +274,7 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/context_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -1,31 +0,0 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -29,12 +29,8 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
RUN uv pip install packaging setuptools wheel \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -27,8 +27,6 @@ trust_remote_code:
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
tokenizer_use_mistral_common:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -175,10 +173,6 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the tools (default: "tools")
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
field_tools: tools
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
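As a sketch (the dataset path is illustrative), a `datasets` entry combining these field mappings might look like:
```yaml
datasets:
  - path: org/dataset-name        # illustrative path
    type: chat_template
    field_messages: messages      # key holding the conversation turns
    field_system: system          # key holding the system message, if present
```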
@@ -770,13 +764,13 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Context parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
context_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
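Taken together, a minimal sketch of these settings in a config (the degree value is illustrative) looks like:
```yaml
# Split each sequence across 4 GPUs (must be a divisor of the GPU count)
context_parallel_degree: 4
# Optional; must evenly divide the model's number of KV heads
heads_k_stride: 1
```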

View File

@@ -52,9 +52,7 @@ We recommend checking the examples below for other use cases.
### Examples
#### Training on last message
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -68,9 +66,7 @@ datasets:
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
#### Overriding default chat template
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
@@ -80,13 +76,7 @@ datasets:
roles_to_train: ["assistant"] # default value
```
::: {.callout-note}
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
:::
#### Using default chat template with fallback
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
@@ -95,9 +85,7 @@ datasets:
type: chat_template
```
#### Custom Jinja template
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
@@ -112,9 +100,7 @@ datasets:
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
#### Using template with different token for EOT and EOS
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
```yaml
eot_tokens:
@@ -139,7 +125,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
@@ -159,73 +145,7 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
:::
#### Using tool use
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
```json
{
"tools": [
{
"type": "...",
"function": {
"name": "...",
"description": "...",
"parameters": {
"type": "...",
"properties": {
// ...
},
"required": ["..."],
},
},
},
],
"messages": [
// ...
{
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"type": "function",
"function": {
"name": "...",
"arguments": {
"...": "...",
}
}
}
]
},
{
"role": "tool",
"name": "...",
"content": "..."
},
],
}
```
::: {.callout-note}
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
```yaml
chat_template: llama4
datasets:
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```
::: {.callout-tip}
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
:::
#### Using fine-grained control over token masking
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -276,9 +196,7 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
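As a hedged sketch (the field names on the right are illustrative; only the config keys come from the section above), wiring these options into a dataset entry might look like:
```yaml
datasets:
  - path: org/annotated-conversations   # illustrative path
    type: chat_template
    # per-turn flag controlling whether the turn is trained on
    message_field_training: training
    # finer-grained control over which tokens within a turn are trained
    message_field_training_detail: train_detail
```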
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8.
:::
## Base
@@ -32,8 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`

View File

@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- Context parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -80,14 +80,14 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
## Context parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
We support context parallelism (CP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
See our [dedicated guide](context_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -29,4 +29,4 @@ qat:
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize` command](./quantize.md) to do this.
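For orientation, a hedged sketch of a `qat:` block follows; only `fake_quant_after_n_steps` is taken from the excerpt above, while the remaining keys and values are assumptions based on typical QAT settings and may differ from Axolotl's actual schema:
```yaml
qat:
  weight_dtype: int8              # assumed key: dtype used to fake-quantize weights
  activation_dtype: int8          # assumed key: dtype used to fake-quantize activations
  group_size: 32                  # assumed key: quantization group size
  fake_quant_after_n_steps: 1000  # from the docs above: delay fake quantization
# After training, reuse the same quantization settings with the quantize command
```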

View File

@@ -500,7 +500,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:

View File

@@ -1,16 +1,16 @@
---
title: Sequence Parallelism
title: Context Parallelism
description: Train with long sequences split across multiple GPUs.
---
Sequence parallelism is a technique that splits sequences across multiple GPUs,
Context parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
## When to Use Context Parallelism
Use sequence parallelism when:
Use context parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
@@ -18,11 +18,11 @@ Use sequence parallelism when:
## Configuration
To enable sequence parallelism, add the following to your configuration file:
To enable context parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
When sequence parallelism is enabled:
When context parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
## Requirements
To use sequence parallelism, you need:
To use context parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -79,22 +79,22 @@ ring_attn_func:
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
## Sample Packing with Sequence Parallelism
## Sample Packing with Context Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
2. The packed sequences are then divided across GPUs in the context parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and no context parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
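The arithmetic from the example above can be written directly into a config sketch (the values mirror the 8-GPU example and are illustrative):
```yaml
# 8 GPUs total, grouped into 2 context-parallel groups of 4
context_parallel_degree: 4
micro_batch_size: 2             # per-GPU micro batch
gradient_accumulation_steps: 1
# effective global batch = (num_gpus / context_parallel_degree) * micro_batch_size * grad_accum
#                        = (8 / 4) * 2 * 1 = 4   (vs. 16 without context parallelism)
```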

View File

@@ -5,10 +5,6 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,71 +0,0 @@
# Finetune Magistral Small with Axolotl
Magistral Small is a 24B-parameter open-source model from MistralAI, available on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Magistral support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
```
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
```
This config uses about 24GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -1,72 +0,0 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true

View File

@@ -1,63 +0,0 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

Binary file not shown (image updated: 4.7 KiB before, 4.5 KiB after).

View File

@@ -67,5 +67,3 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.0

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0"
__version__ = "0.10.0.dev0"

View File

@@ -73,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,

View File

@@ -1,3 +1,5 @@
"""Various shared constants"""
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -3,13 +3,15 @@
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -28,7 +30,16 @@ class TrainDatasetMeta:
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
@@ -40,37 +51,44 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (
cfg.debug
or getattr(cli_args, "debug", False)
or getattr(cli_args, "debug_text_only", False)
or getattr(cli_args, "debug_num_examples", 0) > 0
or debug
):
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -95,10 +113,13 @@ def load_datasets(
def load_preference_datasets(
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
@@ -109,28 +130,23 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
total_num_steps = None
total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if (cli_args and cli_args.debug) or cfg.debug:
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, num_examples)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
dataset=train_samples,
tokenizer=tokenizer,
num_examples=num_examples,
text_only=text_only,
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
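
The hunks above size training with a simple formula, `total_num_steps = ceil(len(train_dataset) * num_epochs / batch_size)`, and skip the computation entirely for GRPO, which derives its own step count. A minimal worked example of that arithmetic (the dataset size, epoch count, and batch size below are made-up values for illustration):

```python
import math

# Illustrative values only; in axolotl these come from the dataset and the config.
num_train_samples = 10_000
num_epochs = 3
batch_size = 16

# Mirrors the hunk above: one optimizer step per batch, over all epochs.
total_num_steps = int(math.ceil(num_train_samples * num_epochs / batch_size))
print(total_num_steps)  # 1875

# For GRPO the trainer computes its own schedule, so the value stays None.
is_grpo = False
if is_grpo:
    total_num_steps = None
```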

View File

@@ -380,16 +380,14 @@ class TrainerBuilderBase(abc.ABC):
)
# eval_strategy and eval_steps
if not self.eval_dataset and self.cfg.val_set_size == 0:
# do not eval if no eval_dataset and val_set_size=0
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
training_args_kwargs["eval_on_start"] = True
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
training_args_kwargs["eval_on_start"] = True
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []
@@ -492,9 +490,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -21,12 +21,18 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -57,6 +63,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
@@ -123,9 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
"""
Gets the trainer class for the given configuration.
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
@@ -142,12 +146,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
@@ -316,12 +314,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_arguments_kwargs.update(plugin_training_args)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -375,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
@@ -402,10 +408,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer
def build_collator(
self,
training_args, # type: "AxolotlTrainingArguments" # type: ignore
is_eval=False,
**kwargs,
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
@@ -434,19 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
collator_args = [self.tokenizer]
collator_cls_and_kwargs = None
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls_and_kwargs:
collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
@@ -477,6 +468,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq

View File

@@ -12,9 +12,13 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -27,9 +31,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -53,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
context_parallel=self.cfg.context_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
@@ -78,12 +79,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Returns training_args and trainer_kwargs
"""
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
@@ -95,6 +90,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
training_args_kwargs["remove_unused_columns"] = False
# only rlhf
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
@@ -143,7 +142,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
# Not compatible with IPO
if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
@@ -151,12 +165,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,

View File

@@ -5,7 +5,7 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -25,7 +27,6 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
CheckpointSaveMixin,
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
@@ -34,16 +35,13 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -69,6 +67,32 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# SDPA device mesh init
import torch.distributed as dist
world_size = dist.get_world_size()
mesh_shape = (
world_size // 2,
2,
)
self.world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
) -> torch.Tensor:
ctx_manager = get_context_parallel_manager(
world_mesh=self.world_mesh,
model=model,
)
# Shard only tensor inputs that have a sequence dimension across the cp ranks
to_shard = {k: v for k, v in inputs.items() if isinstance(v, torch.Tensor) and v.ndim > 1}
with ctx_manager(list(to_shard.values())):
return super().training_step(model, inputs, num_items_in_batch)
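
The hunk above builds a 2-D device mesh with `dp`/`cp` dimensions and wraps the training step in a context-parallel manager that shards sequence-shaped inputs across the `cp` ranks. A stand-alone sketch of the same idea, written against PyTorch's experimental context-parallel API rather than the axolotl helper used above (mesh sizes, tensor shapes, and the `cp_degree` value are illustrative; assumes a `torchrun --nproc-per-node 4` launch on CUDA):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

world_size = dist.get_world_size()
cp_degree = 2  # the hunk above hard-codes a cp dimension of 2

# 2-D mesh: data-parallel groups x context-parallel groups, as in the hunk above.
mesh = init_device_mesh(
    "cuda", (world_size // cp_degree, cp_degree), mesh_dim_names=("dp", "cp")
)

# Fake batch: sequence-shaped tensors are sharded along their sequence dimension.
input_ids = torch.randint(0, 32_000, (1, 4096), device="cuda")
labels = input_ids.clone()

with context_parallel(
    mesh["cp"], buffers=[input_ids, labels], buffer_seq_dims=[1, 1]
):
    # Inside the manager each cp rank holds its 1/cp_degree slice of the sequence,
    # and SDPA calls are dispatched to a ring-attention implementation.
    pass

dist.destroy_process_group()
```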
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -105,7 +129,7 @@ class AxolotlTrainer(
)
batch_max_len = train_batch_size * self.args.max_seq_length
sampler = MultipackBatchSampler(
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -115,12 +139,8 @@ class AxolotlTrainer(
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
num_processes=self.args.dataset_num_proc,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
@@ -228,9 +248,7 @@ class AxolotlTrainer(
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):

View File

@@ -22,19 +22,10 @@ class DPOStrategy:
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
return training_args_kwargs
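
The IPO vs. label-smoothing interaction above is easy to get wrong, so here is a tiny sketch of just that branch (the helper name and arguments are illustrative):

```python
def dpo_loss_kwargs(rl: str, label_smoothing: float | None) -> dict:
    """Sketch of the loss-related branch above: IPO sets loss_type and skips smoothing."""
    kwargs: dict = {}
    if rl == "ipo":
        kwargs["loss_type"] = "ipo"
    # Label smoothing only applies to plain DPO; it is not compatible with IPO.
    if rl == "dpo" and label_smoothing:
        kwargs["label_smoothing"] = label_smoothing
    return kwargs


print(dpo_loss_kwargs("ipo", 0.1))  # {'loss_type': 'ipo'}
print(dpo_loss_kwargs("dpo", 0.1))  # {'label_smoothing': 0.1}
```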

View File

@@ -14,5 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
dpo_norm_loss: bool | None = False

View File

@@ -83,20 +83,3 @@ class AxolotlDPOTrainer(
gc.collect()
torch.cuda.empty_cache()
return loss
def concatenated_forward(
self,
model: nn.Module,
batch: dict[str, Union[list, torch.LongTensor]],
is_ref_model: bool = False,
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOContextParallelTrainer,
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
@@ -23,10 +23,10 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
cls, context_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
if context_parallel:
return AxolotlGRPOContextParallelTrainer
return AxolotlGRPOTrainer
@classmethod
@@ -69,8 +69,8 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if cfg.context_parallel_degree > 1:
grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None
context_parallel_degree: int | None = None

View File

@@ -1,7 +1,7 @@
"""Repeat random sampler (similar to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequence parallel group.
context parallelism functionality; i.e., duplicating data across ranks in the same
context parallel group.
"""
from typing import Iterator, Sized
@@ -10,26 +10,26 @@ import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
"""Sampler for GRPO training with sequence parallelism.
class ContextParallelRepeatRandomSampler(Sampler):
"""Sampler for GRPO training with context parallelism.
This sampler ensures:
- Ranks in the same sequence parallel (SP) group receive identical data.
- Ranks in the same context parallel (CP) group receive identical data.
- Each index is repeated multiple times for sampling different completions.
- Entire batches are repeated for reuse in multiple updates.
- Data is properly distributed across SP groups.
- Data is properly distributed across CP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
In the table below, the values represent dataset indices. Each CP group has
`context_parallel_degree = 2` GPUs working together on the same data. There are 2
CP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
Context Parallel Groups
| SP0 | SP1 |
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
global_step step <---> mini_repeat_count=3
<----------> batch_size=2 per SP group
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
<----------> batch_size=2 per CP group
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
|
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_degree: Number of ranks in a context parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_degree: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
@@ -76,16 +76,16 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.world_size = world_size
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
# Context parallelism parameters
self.context_parallel_degree = context_parallel_degree
self.num_sp_groups = world_size // context_parallel_degree
self.sp_group_id = rank // context_parallel_degree
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
self.total_size = self.num_samples
# Calculate effective number of samples per SP group
# Calculate effective number of samples per CP group
if (
self.drop_last
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
@@ -125,8 +125,8 @@ class SequenceParallelRepeatRandomSampler(Sampler):
padding = indices[: self.batch_size - len(indices) % self.batch_size]
indices += padding
# Subsample based on SP group ID
# Each SP group gets distinct batches of data
# Subsample based on CP group ID
# Each CP group gets distinct batches of data
batch_indices = []
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
start_idx = i + self.sp_group_id * self.batch_size
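
The docstring and hunk above describe the carving: ranks inside one CP group receive identical indices, different CP groups receive different batch-sized slices, and each index is repeated once per generation. A small pure-Python sketch of that layout for the 4-GPU, CP-degree-2 case from the table (the function and variable names here are illustrative, not the sampler's internals):

```python
def indices_for_rank(
    dataset_indices, rank, world_size, cp_degree, batch_size, mini_repeat_count
):
    """Return the (repeated) dataset indices one rank would receive."""
    num_cp_groups = world_size // cp_degree
    cp_group_id = rank // cp_degree  # ranks in the same CP group share this id

    out = []
    # Walk the dataset in super-batches of batch_size * num_cp_groups indices and
    # hand each CP group its own contiguous batch-sized slice.
    for start in range(0, len(dataset_indices), batch_size * num_cp_groups):
        begin = start + cp_group_id * batch_size
        for idx in dataset_indices[begin : begin + batch_size]:
            out.extend([idx] * mini_repeat_count)  # one copy per generation
    return out


indices = list(range(4))
for rank in range(4):
    print(rank, indices_for_rank(indices, rank, 4, 2, batch_size=2, mini_repeat_count=3))
# 0 [0, 0, 0, 1, 1, 1]   <- CP group 0 (ranks 0 and 1): identical data
# 1 [0, 0, 0, 1, 1, 1]
# 2 [2, 2, 2, 3, 3, 3]   <- CP group 1 (ranks 2 and 3): the next batch
# 3 [2, 2, 2, 3, 3, 3]
```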

View File

@@ -1,9 +1,8 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
"""Axolotl GRPO trainers (with and without context parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
import datasets
@@ -42,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -59,45 +58,9 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for context parallelism handling"""
def __init__(
self,
@@ -134,11 +97,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
)
# Get number of SP groups (number of processes divided by SP degree)
# Get number of CP groups (number of processes divided by CP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_degree
# Calculate batch size per SP group (not per process)
# Calculate batch size per CP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
possible_values = [
n_gen
@@ -148,7 +111,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x "
f"The batch size per CP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
@@ -156,7 +119,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
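
The train-time check above (the eval-time check below repeats the same pattern) requires the effective batch per CP group to be divisible by the number of generations per prompt, and reports the divisors as the valid choices. A small worked example with illustrative numbers (the real values come from the accelerator and the config):

```python
num_processes = 8
context_parallel_degree = 2
per_device_train_batch_size = 4
num_generations = 8

num_cp_groups = num_processes // context_parallel_degree            # 4
cp_group_batch_size = per_device_train_batch_size * num_cp_groups   # 16

# Mirrors the `possible_values` computation: valid generation counts are the
# divisors of the per-CP-group batch size.
possible_values = [
    n for n in range(2, cp_group_batch_size + 1) if cp_group_batch_size % n == 0
]
print(possible_values)                      # [2, 4, 8, 16]
print(num_generations in possible_values)   # True -> this configuration passes
```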
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
# If context parallelism is enabled, calculate batch size per CP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
@@ -166,8 +129,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"With context parallelism (degree {self.args.context_parallel_degree}), "
f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
f"the valid values for the number of generations are: {possible_values}."
@@ -180,7 +143,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.local_world_size = 1
def train(self, *args, **kwargs):
# Initialize the SP group
# Initialize the CP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@@ -196,16 +159,16 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
return ContextParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=self.world_size,
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_degree,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_degree=self.args.context_parallel_degree,
shuffle=True,
seed=self.args.seed,
drop_last=True,
@@ -263,11 +226,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# Return unprepared dataloader if using context parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
@@ -340,21 +303,21 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
# Calculate sequence parallel group information
if self.args.context_parallel_degree > 1:
# Calculate context parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_degree = self.args.context_parallel_degree
num_sp_groups = world_size // context_parallel_degree
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
# Since processes in the same CP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each CP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
# Get the first process from each CP group (typically the group leader)
group_leader_rank = sp_group_id * context_parallel_degree
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
# Extract prompts from this CP group, accounting for num_generations duplicates
# We only need prompts from one rank in each CP group
group_prompts = all_prompts_text[
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
@@ -367,7 +330,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_degree
]
with profiling_context(self, "vLLM.generate"):
@@ -384,28 +347,28 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_degree
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
# Determine the appropriate slice based on context parallelism
if self.args.context_parallel_degree > 1:
# Calculate CP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
# Calculate the start index for this CP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
# All ranks in the same CP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
completion_ids = completion_ids[process_slice]
else:
# Original behavior for non-sequence parallel case
# Original behavior for non-context parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
@@ -615,20 +578,20 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
if self.args.context_parallel_degree > 1:
# Calculate CP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
# Calculate the start index for this SP group
# Calculate the start index for this CP group
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# All ranks in the same SP group get the same data slice
# All ranks in the same CP group get the same data slice
process_slice = slice(
sp_group_start,
sp_group_start + len(prompts),
)
else:
# Original behavior for non-sequence parallel case
# Original behavior for non-context parallel case
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),

View File

@@ -3,7 +3,6 @@
# pylint: disable=unused-import
# flake8: noqa
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin

View File

@@ -1,21 +0,0 @@
"""Custom handling to not fail training if fsdp optimizer is not savable"""
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CheckpointSaveMixin(Trainer):
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
)

View File

@@ -2,17 +2,238 @@
extra axolotl specific training args
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Type
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args
AxolotlTrainingMixins: Type = merge_training_args()
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass

View File

@@ -1,224 +0,0 @@
"""
Base Axolotl Training Mixins shared across various trainer configs
"""
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "The number of processes to use for data processing"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
# kd_ce_alpha: Optional[float] = field(
# default=None,
# metadata={
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
# },
# )
#
# kd_alpha: Optional[float] = field(
# default=1.0,
# metadata={"help": "The alpha scaling parameter for KD loss"},
# )
#
# kd_temperature: Optional[float] = field(
# default=1.0,
# metadata={
# "help": "the temperature parameter for KL divergence loss when using KD"
# },
# )
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section

View File

@@ -1,6 +1,7 @@
"""Module containing Dataset functionality"""
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
@@ -19,21 +20,21 @@ LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset):
"""Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset: Dataset with text files.
process_count: Number of processes to use for tokenizing.
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
@@ -48,13 +49,6 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
@@ -82,14 +76,14 @@ class TokenizedPromptDataset(Dataset):
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset | IterableDataset,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = list(dataset.features.keys())
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
@@ -100,13 +94,12 @@ def wrap_dataset_for_tokenized_prompt(
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
@@ -117,7 +110,7 @@ class ConstantLengthDataset(IterableDataset):
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
@@ -181,10 +174,7 @@ class ConstantLengthDataset(IterableDataset):
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
@@ -16,7 +17,6 @@ from axolotl.train import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -84,11 +83,6 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
def get_training_args_mixin(self) -> str | None:
"""
Returns a dataclass model for the plugin's training arguments.
"""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -164,31 +158,6 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
Args:
cfg: The global axolotl configuration.
Returns:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
Args:
cfg: The global axolotl configuration.
is_eval: Whether this is an eval split.
Returns:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -309,7 +278,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
@@ -368,11 +337,8 @@ class PluginManager: # pylint: disable=too-many-public-methods
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError as exc:
except ImportError:
LOG.error(f"Failed to load plugin: {plugin_name}")
# print stacktrace
traceback.print_exc()
print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -387,20 +353,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
input_args.append(input_args_from_plugin)
return input_args
def get_training_args_mixin(self):
"""
Returns a list of dataclasses for all registered plugins' training args mixins'
Returns:
list[str]: A list of dataclasses
"""
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -490,42 +442,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
return trainer_cls
return None
def get_training_args(self, cfg):
"""
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The training arguments
"""
training_args_kwargs = {}
for plugin in self.plugins.values():
training_args = plugin.get_training_args(cfg)
if training_args is not None:
training_args_kwargs.update(training_args)
return training_args_kwargs
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.

View File

@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,43 +61,3 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
def merge_training_args() -> Type:
"""
Merges training arguments from registered plugins with the base TrainingArguments.
This function retrieves the training arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlTrainingMixins,
that inherit from the base configurations and include the training arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
mixin_classes = []
dynamic_input = ""
for plugin_args in training_args_mixins:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
mixin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase
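
The removed `merge_training_args` above composes a training-args mixin by generating class source from dotted plugin paths and `exec`-ing it. A rough equivalent of that mechanism without `exec`, using `importlib` and `type()` (the names below are illustrative, not the repo's API):

```python
import importlib
from dataclasses import dataclass


@dataclass
class BaseMixin:
    """Stand-in for the base training-args mixin."""

    max_seq_length: int = 2048


def merge_mixins(plugin_paths: list[str], base=BaseMixin):
    """Import each dotted 'module.Class' path and build one combined dataclass."""
    extra_bases = []
    for path in plugin_paths:
        module_name, cls_name = path.rsplit(".", 1)
        extra_bases.append(getattr(importlib.import_module(module_name), cls_name))
    if not extra_bases:
        return base
    # type() builds the same subclass the exec-based version wrote out as source.
    return dataclass(type("MergedTrainingMixins", (base, *extra_bases), {}))


# With no plugins registered, the base mixin is returned unchanged.
print(merge_mixins([]) is BaseMixin)  # True
```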

View File

@@ -24,14 +24,6 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
## Usage
**NOTE**: If you are training a VLM model, please use an older version of Axolotl, as upstream has applied a major VLM refactor and our patches have not been updated yet.
```bash
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
pip3 install --no-build-isolation -e .
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -15,12 +15,7 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -33,75 +28,9 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_training_args_mixin(self):
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None
def get_training_args(self, cfg):
return {
"kd_ce_alpha": cfg.kd_ce_alpha,
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None, None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
use_batch_sampler_collator = False
if is_eval is False and cfg.sample_packing:
use_batch_sampler_collator = True
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if cfg.kd_online_server_base_url:
from .collator_online_teacher import OnlineTeacherCollator
return OnlineTeacherCollator, {
"kd_online_server_base_url": cfg.kd_online_server_base_url,
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD, {}
def pre_model_load(self, cfg):
from .kernels.models import apply_kernel
apply_kernel(cfg.model_config_type)
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds the KD temperature scheduler callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the KD settings.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
callback = KDTemperatureSchedulerCallback(
cfg.kd_temperature,
cfg.kd_temperature_min,
trainer,
)
return [callback]
return []

View File

@@ -15,19 +15,9 @@
"""
Plugin args for KD support.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class InferenceServerType(str, Enum):
"""
Online inference server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
from pydantic import BaseModel
class KDArgs(BaseModel):
@@ -35,41 +25,13 @@ class KDArgs(BaseModel):
Input args for knowledge distillation.
"""
kd_trainer: float | None = None # whether to use KD trainer
kd_ce_alpha: float | None = (
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_timeout: int | None = 120
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@dataclass
class KDTrainingArgsMixin:
"""
Additional args for KD training.
"""
kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -1,36 +0,0 @@
"""
Transformers trainer callbacks to schedule the KD temperature during training
"""
import math
from transformers.trainer_callback import TrainerCallback
class KDTemperatureSchedulerCallback(TrainerCallback):
"""
KD temperature scheduler callback for the trainer.
"""
def __init__(self, temperature_start, temperature_min, trainer):
self.temperature_start = temperature_start
self.temperature_min = temperature_min
self.temperature = temperature_start
self.trainer = trainer
def on_step_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# cosine decay temperature over the max steps
progress = state.global_step / state.max_steps
# Cosine decay factor: 0.5 * (1 + cos(pi * progress))
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
self.temperature = self.temperature_start - (
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
)
if hasattr(self.trainer.data_collator, "kd_temperature"):
self.trainer.data_collator.kd_temperature = self.temperature
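The callback above decays the KD temperature with a cosine schedule from `temperature_start` down to `temperature_min` over `max_steps`. A minimal standalone sketch of the same arithmetic (function name and values below are illustrative only):

```python
import math

def cosine_kd_temperature(step: int, max_steps: int, t_start: float, t_min: float) -> float:
    # decay_factor goes from 1.0 at step 0 to 0.0 at max_steps
    progress = step / max_steps
    decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return t_start - (t_start - t_min) * (1.0 - decay_factor)

# e.g. decaying from 2.0 to 1.0 over 100 steps:
for step in (0, 25, 50, 75, 100):
    print(step, round(cosine_kd_temperature(step, 100, 2.0, 1.0), 3))
# step 0 -> 2.0, step 50 -> 1.5, step 100 -> 1.0
```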

View File

@@ -15,15 +15,12 @@
"""
Chat template prompt strategy loader with KD support
"""
import logging
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
LOG = logging.getLogger(__name__)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
@@ -104,8 +101,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
@@ -144,10 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
@@ -167,115 +162,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
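The comments above re-scale the teacher's top-k probabilities from the generation temperature T1 to the KD temperature T2 via the exponent trick p_{T2}(k) ∝ p_{T1}(k)^(T1/T2), followed by renormalization. A minimal sketch of that transform on a single position's top-k logprobs (helper name and tensor values are illustrative, not part of the codebase):

```python
import torch

def rescale_topk_logprobs(logprobs_t1: torch.Tensor, gen_temperature: float, kd_temperature: float) -> torch.Tensor:
    """Re-scale one position's top-k logprobs from temperature T1 (generation) to T2 (KD)."""
    probs_t1 = logprobs_t1.exp()
    # normalize in case the top-k slice does not already sum to 1
    probs_t1 = probs_t1 / probs_t1.sum()
    if kd_temperature != gen_temperature:
        # p_T2(k) = p_T1(k) ** (T1 / T2), then renormalize
        probs_t2 = probs_t1 ** (gen_temperature / kd_temperature)
    else:
        probs_t2 = probs_t1
    probs_t2 = probs_t2 / probs_t2.sum()
    return probs_t2.log()

position_logprobs = torch.log(torch.tensor([0.7, 0.2, 0.1]))
print(rescale_topk_logprobs(position_logprobs, gen_temperature=1.0, kd_temperature=2.0))
```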
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
"""
Strategy for datasets with complete structured KD logprob data
"""
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"]
):
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
token_pos_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(pos_target_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
@@ -285,10 +177,8 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
@@ -299,7 +189,7 @@ class KDStrategyLoader(StrategyLoader):
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
@@ -314,14 +204,4 @@ class KDStrategyLoader(StrategyLoader):
return strategy_params
class KDStrategyLoaderV2(KDStrategyLoader):
"""
Load KD chat template datasets with pre-tokenized logprob data
"""
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
return ChatTemplateStrategyWithKDv2
load_legacy = KDStrategyLoader()
load = KDStrategyLoaderV2()
load = KDStrategyLoader()

View File

@@ -47,16 +47,11 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
max_len = 0
# Pad labels and position_ids first
for feature_name, pad_token_id in [
@@ -107,9 +102,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
max_teacher_seq_len = max_len or max(
len(seq) for seq in target_logprobs_list
)
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
@@ -216,9 +209,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
features
):
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequence's features
# We'll merge them into out_features[i].
#
@@ -252,17 +243,10 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
if field_name in feat and isinstance(
feat[field_name], (list, torch.Tensor)
):
if isinstance(
feat[field_name][0], (dict, str)
): # pylint: disable=too-many-nested-blocks
continue
if field_name in feat:
arr = np.array(feat[field_name])
arrays.append(arr)
if arrays:
out_features[i][field_name] = np.concatenate(arrays)
out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids

View File

@@ -1,561 +0,0 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import hashlib
import hmac
import logging
from typing import Any, Dict, List, Optional
import requests
import torch
from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
"""
Create HMAC-SHA hash from a list of integers
Args:
int_list: List of integers
key: Secret key (string or bytes)
hash_func: Hash function (default: sha256)
Returns:
HMAC digest as hex string
"""
# Convert key to bytes if it's a string
if isinstance(key, str):
key = key.encode("utf-8")
# Convert list of ints to bytes
# Method 1: Convert each int to bytes and concatenate
data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
# Create HMAC
h = hmac.new(key, data, hash_func)
return h.hexdigest()
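A minimal usage sketch of the helper above (the key and token ids here are illustrative only):

```python
token_ids = [101, 2023, 2003, 102]
digest = hmac_sha_from_int_list(token_ids, key="secret-key")
print(digest)  # deterministic hex string; the same ids + key always hash to the same value
```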
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
self,
*args: Any,
kd_online_server_base_url: Optional[str] = None,
kd_online_topk: Optional[int] = None,
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if kd_online_server_base_url is None:
raise ValueError(
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
)
if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError(
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
)
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
if not raw_logprobs or self.kd_online_topk == 0:
return (
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
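`normalize_logprobs` is imported from `axolotl.integrations.kd.utils` and its body is not shown in this diff. As a rough sketch of what re-normalizing top-k logprobs typically amounts to (an assumption for illustration, not necessarily the library's exact implementation):

```python
import torch

def renormalize_topk_logprobs(raw_logprobs: torch.Tensor, top_k: int) -> torch.Tensor:
    # Treat the top-k logprobs as an unnormalized distribution, renormalize so
    # the k probabilities sum to 1, and convert back to log space.
    raw_logprobs = raw_logprobs[:top_k]
    return raw_logprobs - torch.logsumexp(raw_logprobs, dim=-1, keepdim=True)
```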
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
payload = {
"input_ids": batch_input_ids,
"return_logprob": True,
"top_logprobs_num": self.kd_online_topk,
"logprob_start_len": 0,
"return_text_in_logprobs": True,
"echo": True,
"sampling_params": {
"max_new_tokens": 0,
"temperature": self.kd_temperature,
"skip_special_tokens": False,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
response = self.http_session.post(
api_endpoint, json=payload, timeout=self.kd_online_timeout
)
response.raise_for_status()
api_data: list[dict] = response.json()
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
"input_top_logprobs", []
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, list)
and all(
isinstance(item, list) for item in pos_top_logprobs_data
)
and len(pos_top_logprobs_data) > 0
and len(pos_top_logprobs_data[0]) == 3
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
payload = {
"prompt": batch_input_ids,
"echo": True,
"logprobs": True,
"prompt_logprobs": self.kd_online_topk,
"top_logprobs": self.kd_online_topk,
"max_new_tokens": 0,
"skip_special_tokens": False,
"temperature": self.kd_temperature,
"sampling_params": {
"max_tokens": 0,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
response = self.http_session.post(
api_endpoint,
json=payload,
headers=headers,
timeout=self.kd_online_timeout,
)
response.raise_for_status()
api_data: dict = orjson.loads(response.content)
choices: list[dict] = api_data["choices"]
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
):
# seq_input_ids: List[int]
# seq_labels: List[int]
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
sequence_data.pop("prompt_logprobs", [])
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
seq_len = len(seq_input_ids)
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
continue
if (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, dict)
and all(
isinstance(item, dict)
for item in pos_top_logprobs_data.values()
)
and len(pos_top_logprobs_data.keys()) > 0
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_token_ids_str = list(pos_top_logprobs_data.keys())
pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [
int(token_id) for token_id in pos_token_ids_str
]
pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.warning(
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
# TODO save and load targets to disk for caching for next epoch
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
# if self.kd_cache_dir:
# hash_input_ids = hmac_sha_from_int_list(
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
# )
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
def __call__(
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
if not features:
return super().__call__(features, return_tensors=return_tensors)
for (
sub_batch_features
) in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features:
continue
input_ids_for_api_call: List[List[int]] = []
labels_for_api_call: List[List[int]] = []
# Store references to the original item dictionaries to update them in-place
items_for_api_call: List[Dict[str, Any]] = []
for item_dict in sub_batch_features:
if not isinstance(item_dict, dict):
LOG.warning(
f"Skipping non-dict item in sub_batch_features: {item_dict}"
)
continue
current_input_ids = item_dict.get("input_ids")
current_labels = item_dict.get("labels")
if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = (
current_input_ids.tolist()
if hasattr(current_input_ids, "tolist")
else list(current_input_ids)
)
labels_list = (
current_labels.tolist()
if hasattr(current_labels, "tolist")
else list(current_labels)
)
input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list)
items_for_api_call.append(item_dict)
else:
# This item will not get teacher logprobs from the API.
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
item_dict.setdefault("target_token_ids", [])
item_dict.setdefault("target_logprobs", [])
item_dict.setdefault("target_mask", [])
# print(items_for_api_call)
if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
input_ids_for_api_call, labels_for_api_call
)
else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
input_ids_for_api_call, labels_for_api_call
)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call):
# TODO: make sure we update the correct item of sub_batch_features within the original `features` object so the superclass can handle it properly.
if api_responses_for_sub_batch and i < len(
api_responses_for_sub_batch["target_token_ids"]
): # Check bounds
assert len(
api_responses_for_sub_batch["target_token_ids"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_logprobs"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_mask"][i]
) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = (
api_responses_for_sub_batch["target_token_ids"][i]
)
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
"target_logprobs"
][i]
item_to_update["target_mask"] = api_responses_for_sub_batch[
"target_mask"
][i]
else:
# API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists.
LOG.warning(
f" (index {i}), or API response was too short. "
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
)
item_to_update.setdefault("target_token_ids", [])
item_to_update.setdefault("target_logprobs", [])
item_to_update.setdefault("target_mask", [])
return super().__call__(features, return_tensors=return_tensors)

View File

@@ -1,8 +0,0 @@
"""
Liger Chunked loss optimizations module
"""
from .liger import LigerFusedLinearKLTopKLogprobLoss
from .models import apply_kernel
__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"]

View File

@@ -1,485 +0,0 @@
"""
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
Chunked kl-div loss for top-k logprobs
"""
@staticmethod
def distillation_loss_fn(
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
Args:
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
beta: Controls the type of KL divergence.
0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns:
Sum of KL divergence losses for the chunk.
"""
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
target_logprobs_chunk = target_logprobs_chunk.float()
# Gather student logits for the top-k teacher token IDs
# target_token_ids_chunk: [chunk_size, top_k]
# student_logits_topk_temp_scaled: [chunk_size, top_k]
student_logits_topk_temp_scaled = torch.gather(
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
)
# Student log-probabilities for the gathered top-k tokens
student_lse = torch.logsumexp(
student_logits_temp_scaled, dim=-1, keepdim=True
) # [chunk_size, 1]
student_logprobs_topk_temp_scaled = (
student_logits_topk_temp_scaled - student_lse
)
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = teacher_logprobs_valid.exp()
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
teacher_logprobs_valid - student_logprobs_topk_valid
)
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - teacher_logprobs_valid
)
kd_loss = rev_kl_per_token.sum()
else:
# JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss
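The beta handling above selects between forward KL (teacher-led), reverse KL (student-led), and a JSD-style interpolation. A small numeric sketch of the two pure cases on a toy two-token distribution (values are illustrative):

```python
import torch

teacher_logprobs = torch.log(torch.tensor([0.9, 0.1]))
student_logprobs = torch.log(torch.tensor([0.6, 0.4]))
teacher_probs, student_probs = teacher_logprobs.exp(), student_logprobs.exp()

# beta == 0.0: forward KL, sum p_teacher * (log p_teacher - log p_student)
fwd_kl = (teacher_probs * (teacher_logprobs - student_logprobs)).sum()
# beta == 1.0: reverse KL, sum p_student * (log p_student - log p_teacher)
rev_kl = (student_probs * (student_logprobs - teacher_logprobs)).sum()
print(fwd_kl.item(), rev_kl.item())
# forward KL penalizes mass the student fails to place on teacher modes
```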
@staticmethod
def _compute_loss_kl_topk(
student_input_chunk: torch.Tensor,
student_weight: torch.Tensor,
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
# or through `partial`. Let's make them explicit here for clarity.
target_token_ids_chunk: torch.Tensor,
target_logprobs_chunk: torch.Tensor,
target_mask_chunk: torch.Tensor,
target_chunk: torch.Tensor, # For hard loss (true labels)
student_bias: torch.Tensor = None, # This will be one of the grad targets
# Other params passed via `partial` from `forward`
distillation_loss_fn=None,
ignore_index: int = -100,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
compute_ce_loss: bool = True,
temperature: float = 1.0,
beta: float = 0.0,
normalize_topk: bool = True,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
# student_lm_head_weight: [vocab_size, hidden_dim]
# student_logits_chunk: [chunk_size, vocab_size]
student_logits_chunk = F.linear(
student_input_chunk, student_weight, student_bias
)
ce_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if compute_ce_loss and weight_hard_loss > 0.0:
ce_loss = F.cross_entropy(
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
soft_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if weight_soft_loss > 0.0:
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
soft_loss = distillation_loss_fn(
student_logits_chunk_temp_scaled,
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
beta=beta,
normalize_topk=normalize_topk,
)
return soft_loss, ce_loss
@classmethod
def forward(
cls,
ctx,
student_input: torch.Tensor, # [batch_size, seq_len, dim]
student_lm_head_weight: torch.Tensor, # [dim, vocab_size]
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
true_labels: torch.Tensor, # [batch_size, seq_len]
student_lm_head_bias: torch.Tensor = None,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
beta: float = 0.0,
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
torch.zeros_like(student_lm_head_bias)
if student_lm_head_bias is not None
else None
)
kd_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
ce_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
# This function will be what torch.func.grad_and_value differentiates.
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
def loss_fn_for_grad(
_student_input_chunk,
_student_lm_head_weight, # full weight
_student_lm_head_bias, # full bias
# Fixed arguments for a given chunk, not differentiated:
_target_token_ids_chunk,
_target_logprobs_chunk,
_target_mask_chunk,
_true_labels_chunk,
):
return cls._compute_loss_kl_topk(
student_input_chunk=_student_input_chunk,
student_weight=_student_lm_head_weight,
target_token_ids_chunk=_target_token_ids_chunk,
target_logprobs_chunk=_target_logprobs_chunk,
target_mask_chunk=_target_mask_chunk,
target_chunk=_true_labels_chunk,
student_bias=_student_lm_head_bias,
distillation_loss_fn=cls.distillation_loss_fn,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
beta=beta,
normalize_topk=normalize_topk,
)
def accumulate_chunk_grads(
student_input_chunk_ac,
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
):
# student_weight and student_bias are closed over from the outer scope (full tensors)
if student_lm_head_bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
student_lm_head_bias, # primals
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
) # non-primals
grad_bias_acc.add_(chunk_grad_bias)
else:
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
(
(chunk_grad_input, chunk_grad_weight), # No grad for bias
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
None, # Pass None for student_bias primal
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
)
grad_weight_acc.add_(chunk_grad_weight)
kd_loss_acc.add_(chunk_kd_loss)
ce_loss_acc.add_(chunk_ce_loss)
return chunk_grad_input
if compiled:
accumulate_chunk_grads_compiled = torch.compile(
accumulate_chunk_grads, dynamic=True, backend="inductor"
) # dynamic=True often helpful
else:
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
# pad and shift for cross entropy loss
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(
student_input_flat, chunks=num_chunks, dim=0
)
_target_token_ids_chunks = torch.chunk(
target_token_ids_flat, chunks=num_chunks, dim=0
)
_target_logprobs_chunks = torch.chunk(
target_logprobs_flat, chunks=num_chunks, dim=0
)
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
for i in range(num_chunks):
grad_input_chunk = accumulate_chunk_grads_compiled(
_student_input_chunks[i],
_target_token_ids_chunks[i],
_target_logprobs_chunks[i],
_target_mask_chunks[i],
_true_labels_chunks[i],
)
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
return final_loss
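The chunked forward above relies on `torch.func.grad_and_value(..., has_aux=True)`: the first value returned by the wrapped function is differentiated, and the second is passed through untouched. A self-contained toy sketch of that calling convention (the loss and shapes here are illustrative, assuming PyTorch 2.x with `torch.func`):

```python
import torch

def toy_loss(weight, inputs):
    pred = inputs @ weight
    main_loss = (pred ** 2).sum()        # differentiated output
    aux_metric = pred.detach().mean()    # auxiliary value, returned as-is
    return main_loss, aux_metric

weight = torch.randn(3, 2)
inputs = torch.randn(4, 3)
(grad_weight,), (loss_value, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0,), has_aux=True
)(weight, inputs)
print(grad_weight.shape, loss_value.item(), aux.item())
```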
@staticmethod
def backward(ctx, grad_output):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(
grad_output,
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
):
grad_input_flat = grad_input_flat * grad_output
grad_weight = grad_weight * grad_output
if grad_bias_maybe is not None:
grad_bias_maybe = grad_bias_maybe * grad_output
# Reshape grad_input_flat to match original student_input shape (B, N, D)
# ctx.orig_dims stores (B, N, D, K)
# We need the first three dimensions for student_input's shape.
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
if (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
and grad_input_flat.numel() == 0
):
# If original input was empty, gradient should also be empty with correct shape
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
elif grad_input_flat.numel() == 0 and not (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
):
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
# but as a safeguard:
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
else:
grad_input_reshaped = grad_input_flat.view(
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
)
nones_for_hyperparams = [None] * ctx.hyperparams_count
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
return (
grad_input_reshaped, # Gradient for student_input (reshaped)
grad_weight, # Gradient for student_lm_head_weight
None, # Gradient for target_token_ids
None, # Gradient for target_logprobs
None, # Gradient for target_mask
None, # Gradient for true_labels
grad_bias_return, # Gradient for student_lm_head_bias
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
)
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
"""
Wrapper module for the chunked top-k logprob KL divergence loss.
"""
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
temperature: float = 1.0, # This is the kd_temperature
beta: float = 1.0,
ignore_index: int = -100,
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
raise ValueError("Loss weights must be between 0.0 and 1.0.")
if temperature <= 0:
raise ValueError("Temperature must be positive.")
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.temperature = temperature
self.beta = beta
self.ignore_index = ignore_index
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
)
# self.weight_hard_loss = 0.0 # Or let user manage this
if self.weight_soft_loss == 0.0:
print(
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
)
def forward(
self,
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
true_labels: torch.Tensor,
student_bias: torch.Tensor = None,
) -> torch.Tensor:
return LigerFusedLinearKLTopKLogprobFunction.apply(
student_hidden_states,
lm_head_weight,
target_token_ids,
target_logprobs,
target_mask,
true_labels,
student_bias,
self.weight_hard_loss,
self.weight_soft_loss,
self.ignore_index,
self.temperature,
self.beta,
self.compiled,
self.chunk_size,
self.compute_ce_loss,
self.normalize_topk,
)

View File

@@ -1,98 +0,0 @@
"""
model patcher for chunked top-k kl-div
"""
from types import MethodType
from typing import Optional, Union, Unpack
import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
def kldiv_forward_llama_like(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
if num_items_in_batch is not None and num_items_in_batch > 0:
loss = loss / num_items_in_batch
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = MethodType(kldiv_forward_llama_like, model_cls)

View File

@@ -16,7 +16,40 @@
loss for top_k KL divergence
"""
import torch
from torch import nn
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
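A minimal usage sketch of `zscore_standardize` on a toy top-k logit tensor (shapes and values are illustrative):

```python
import torch

logits = torch.randn(2, 4, 8)  # [batch, seq_len, top_k]
z = zscore_standardize(logits, base_temperature=2.0)
# each [batch, seq] row of z now has ~zero mean over the top-k dimension,
# and std ~= 1 / base_temperature (0.5 here)
print(z.mean(dim=-1).abs().max().item(), z.std(dim=-1, unbiased=False).mean().item())
```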
@torch.jit.script
@@ -27,6 +60,7 @@ def loss(
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
@@ -43,6 +77,8 @@ def loss(
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
top_k_before_softmax (int, optional): Whether to gather the student's top-k logits before applying softmax, i.e. normalize over only the top-k entries rather than the full vocabulary
Default: 0
"""
target_logprobs = target_logprobs.float()
@@ -52,24 +88,46 @@ def loss(
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
if top_k_before_softmax:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Apply KD temperature to the student's top-k logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
else:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
@@ -86,6 +144,10 @@ def loss(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
@@ -96,74 +158,80 @@ def loss(
return kd_loss
class ChunkedTopKKDLoss(nn.Module):
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
super().__init__()
self.num_output_chunks = num_output_chunks
self.kd_temperature = kd_temperature
target_logprobs = target_logprobs.float()
def forward(
self,
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K]
target_mask: torch.Tensor, # [B, seq_len, K]
num_items_in_batch: int = -1, # optional batch size for normalization
) -> torch.Tensor:
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
# 1. Split along the "token" dimension (dim=1).
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
student_topk_logits = student_topk_logits.float()
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
# so that our final average is consistent with the entire sequence/batch.
total_loss = 0.0
total_valid_tokens = 0
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 2. Loop over each chunk and compute a chunk-specific loss.
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
):
# We pass num_items_in_batch=-1 so that the kd_loss
# will average over *this chunk's* valid tokens only.
chunk_loss = loss(
student_logits=st_chunk,
target_token_ids=tid_chunk,
target_logprobs=lp_chunk,
target_mask=msk_chunk,
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
kd_temperature=self.kd_temperature,
)
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# kd_loss returns an average over the chunk's valid tokens.
# We want a global average in the end, so we need to reweight
# by the number of valid tokens in this chunk and keep track of the total.
chunk_valid_mask = msk_chunk.to(torch.bool)
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
# Re-scale "chunk average" back to "chunk sum"
chunk_loss_sum = chunk_loss * chunk_valid_count
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
total_loss += chunk_loss_sum
total_valid_tokens += chunk_valid_count
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 3. Normalize *once* at the end.
if num_items_in_batch > 0:
# If the user gave us a manual denominator (e.g. total items in batch),
# we divide by it. Typically used if each item is of different length.
final_loss = total_loss / float(num_items_in_batch)
else:
# Otherwise, divide by total valid tokens across all chunks.
# to get the same result as a non-chunked approach.
final_loss = total_loss / float(total_valid_tokens)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
return final_loss
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss
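The `zscore_standardize` helper referenced above is not shown in this diff; a minimal sketch of what such a masked standardization could look like (an assumption, not necessarily the repo's implementation):
import torch

def zscore_standardize(logits, mask, base_temperature=1.0, eps=1e-6):
    # Illustrative sketch: standardize over the last (top-K) dim, ignoring masked entries
    mask = mask.to(logits.dtype)
    count = mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
    mean = (logits * mask).sum(dim=-1, keepdim=True) / count
    var = (((logits - mean) ** 2) * mask).sum(dim=-1, keepdim=True) / count
    std = (var + eps).sqrt()
    return (logits - mean) / (std * base_temperature)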

View File

@@ -18,7 +18,8 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
@@ -26,18 +27,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss
self.args.kd_temperature,
self.args.kd_beta or 0.0,
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -63,12 +52,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
if (
self.args.sample_packing
and hasattr(inputs, "attention_mask")
and hasattr(inputs, "position_ids")
):
del inputs["attention_mask"]
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -76,4 +65,49 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
return outputs[0]
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss
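A short worked note on the one-token shift applied above (illustrative values):
# With a teacher-provided seq_len of 5:
#   student_logits = outputs["logits"][:, :4, :]  -> predictions for tokens at positions 1..4
#   target_*[..., 1:, :]                          -> teacher top-K info for positions 1..4
# so student predictions and teacher targets line up position-by-position.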

View File

@@ -1,100 +0,0 @@
"""Helper KD utils"""
import math
from typing import List, Union
import numpy as np
import torch
from torch import FloatTensor, Tensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor
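A small usage sketch of the helper above (shapes are illustrative):
import torch

raw = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)  # top-4 teacher logprobs
normalized = normalize_logprobs(raw, topk=8)            # padded with -inf and re-normalized to top-8
assert normalized.shape == (2, 3, 8)
print(normalized.exp().sum(dim=-1))  # ~1.0 at every position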
def strided_chunk_views(
tensor: Union[np.ndarray, torch.Tensor],
chunks: int,
dim: int = 0,
stride: int = 1,
chunk_size: int | None = None,
) -> List[Union[np.ndarray, torch.Tensor]]:
"""
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
Args:
tensor: Input tensor (numpy array or torch tensor)
chunks: Number of chunks to create
dim: Dimension along which to chunk (default: 0)
stride: Stride between chunk starting positions (default: 1)
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
Returns:
List of tensor chunks (views when possible, copies when necessary)
"""
# Get the size of the specified dimension
dim_size = tensor.shape[dim]
# Calculate chunk size if not provided
if chunk_size is None:
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
chunks_list = []
for i in range(chunks):
start_idx = i * stride
end_idx = min(start_idx + chunk_size, dim_size)
# Break if we've gone beyond the tensor
if start_idx >= dim_size:
break
# Create slice objects for all dimensions
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start_idx, end_idx)
chunk = tensor[tuple(slices)]
chunks_list.append(chunk)
return chunks_list
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
dim_size = input_tensor.shape[dim]
stride = math.ceil(dim_size / chunks)
return strided_chunk_views(
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
)
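A usage sketch of the overlapping-chunk helper (values are illustrative):
import torch

x = torch.arange(10)
chunks = chunk_overlap(x, chunks=3, dim=0, overlap=1)
print([c.tolist() for c in chunks])
# [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8], [8, 9]] -- each chunk overlaps the next by one element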

View File

@@ -166,17 +166,6 @@ class PatchManager:
def _apply_self_attention_lora_patch(self):
"""Apply self-attention LoRA patches if configured."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
# Only patch if conditions are met
can_patch = (
self.cfg.lora_dropout == 0
if hasattr(self.cfg, "lora_dropout")
else True
) # default to True if lora_dropout is not set
if not can_patch:
LOG.warning("Cannot patch self-attention - requires no dropout")
return
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)

View File

@@ -7,14 +7,12 @@ import transformers
from transformers import (
AddedToken,
AutoTokenizer,
PreTrainedTokenizer,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
is_local_main_process,
@@ -119,21 +117,8 @@ def modify_tokenizer_files(
return tokenizer_dir
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
return tokenizer
if cfg.tokenizer_use_mistral_common:
return _load_mistral_common_tokenizer(cfg)
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -222,12 +207,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
)
and k != "pad_token"
):
lora_modules_to_save_str = ", ".join(
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
"when using an adapter and changing the special tokens."
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens(

View File

@@ -145,11 +145,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Qwen2Attention
if model_type == "mllama":
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
return MllamaTextSelfAttention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -274,29 +269,6 @@ def find_mlp_in_layer(
)
def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
"""
Get the layers of the model. Handles text-only and multimodal models.
Args:
model: A PEFT model.
Returns:
A list of layers.
"""
pretrained_model = model.model
# check for multimodal models first
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
return pretrained_model.model.layers
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
@@ -368,7 +340,17 @@ def apply_lora_kernel_patches(
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
layers = get_layers(model)
layers = []
# check for multimodal models first
pretrained_model = model.model
if hasattr(pretrained_model, "language_model"):
layers = pretrained_model.language_model.layers
elif hasattr(pretrained_model, "model"):
layers = pretrained_model.model.layers
else:
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
# Patch each layer
for layer in layers:

View File

@@ -2,10 +2,10 @@
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
their context parallel version of Flash Attention 2.
We also provide some patches for accelerate functions to prepare the dataloader for
sequence parallelism training.
context parallelism training.
"""
import inspect
@@ -13,9 +13,9 @@ import inspect
import accelerate
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
@@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
def register_ring_attn(
sequence_parallel_degree: int,
context_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
context_parallel_degree: Context parallelism factor.
heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
@@ -80,28 +80,18 @@ def register_ring_attn(
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
LOG.info(
"Enabling ring attention context parallelism: "
f"each sequence will be processed across {context_parallel_degree} GPUs"
)
# Assign ranks to sequence parallel groups
# Assign ranks to context parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
for i in range(world_size // context_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
i * context_parallel_degree,
(i + 1) * context_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -113,9 +103,7 @@ def register_ring_attn(
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
LOG.info(f"Context parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
@@ -150,7 +138,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree.
Raises:
RuntimeError: If source code to patch does not exist.
@@ -176,15 +164,15 @@ def patch_prepare_data_loader():
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int):
def patch_prepare_device_mesh(context_parallel_degree: int):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
that includes context parallelism with the specified degree.
Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use.
context_parallel_degree (int): The degree of context parallelism to use.
"""
def _prepare_device_mesh(self):
@@ -199,11 +187,11 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
):
return self.state.ds_device_mesh
# Create device mesh with sequence parallelism
# Create device mesh with context parallelism
world_size = dist.get_world_size()
mesh_shape = (
world_size // sequence_parallel_degree,
sequence_parallel_degree,
world_size // context_parallel_degree,
context_parallel_degree,
)
device_ids = list(range(world_size))
@@ -221,5 +209,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
LOG.info(
"Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}"
f"with context_parallel_degree={context_parallel_degree}"
)
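As a concrete example of the mesh created above (hypothetical values):
world_size = 8
context_parallel_degree = 4
mesh_shape = (world_size // context_parallel_degree, context_parallel_degree)  # (2, 4)
# ranks 0-3 form one context parallel group, ranks 4-7 form the other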

View File

@@ -17,10 +17,7 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
package = "axolotl.prompt_strategies"
if (
strategy.split(".")[-1].startswith("load_")
or strategy.split(".")[-1] == "load"
):
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
elif len(strategy.split(".")) > 1:

View File

@@ -2,10 +2,8 @@
HF Chat Templates prompt strategy
"""
# pylint: disable=too-many-lines
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
from typing import Any, Dict, List, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
@@ -17,9 +15,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
# Configure the logger
LOG = get_logger(__name__)
LOG.setLevel("INFO")
@@ -39,7 +34,6 @@ class ChatTemplatePrompter(Prompter):
message_field_training_detail: str | None = None,
field_messages: str = "messages",
field_system: str = "system",
field_tools: str = "tools",
roles: dict[str, list[str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False,
@@ -72,7 +66,6 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.field_tools = field_tools
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
@@ -84,38 +77,17 @@ class ChatTemplatePrompter(Prompter):
def chat_template_msg_variables(self) -> Set[str]:
return self._chat_template_msg_variables
def build_prompt(
self,
conversation: list[dict],
add_generation_prompt=False,
images=None,
tools=None,
):
"""
Build a prompt from a conversation.
Args:
conversation: A list of messages.
add_generation_prompt: Whether to add a generation prompt.
images: A list of images. (optional)
tools: A list of tools. (optional)
"""
chat_template_kwargs = {
"chat_template": self.chat_template,
"add_generation_prompt": add_generation_prompt,
}
if tools:
chat_template_kwargs["tools"] = tools
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
text = self.processor.apply_chat_template(
conversation,
chat_template=self.chat_template,
tokenize=False,
**chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
**self.chat_template_kwargs,
)
batch = self.processor(
text=text,
@@ -132,7 +104,9 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
**chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
def get_offsets_for_train_detail(
@@ -276,15 +250,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = []
if eot_tokens is not None:
self.eot_tokens = eot_tokens
elif (
hasattr(self.tokenizer, "eos_token")
and self.tokenizer.eos_token is not None
):
self.eot_tokens = [self.tokenizer.eos_token]
self.eot_tokens = (
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
)
self.split_thinking = split_thinking
self.images = "images"
@@ -408,7 +376,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and not self.prompter.message_field_training_detail # type: ignore
):
turns = self.get_conversation_thread(prompt)
images = self._get_images(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1],
add_generation_prompt=True,
@@ -437,8 +405,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = self.get_conversation_thread(prompt)
tools = self._get_tools(prompt)
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
input_ids = self.prompter.build_prompt(turns) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
@@ -477,9 +444,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
)
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
@@ -581,9 +546,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
):
def find_turn(self, turns: list[dict], turn_idx: int):
"""
Locate the starting and ending indices of the specified turn in a conversation.
"""
@@ -614,10 +577,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turns_with_content = turns[: turn_idx + 1]
# Generate the conversation up to the turn, with final turn replaced with dummy content
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
# Generate the conversation up to the turn, with final turn included
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
if not full_ids or not dummy_ids:
LOG.warning(f"Empty template generated for turn {turn_idx}")
@@ -670,10 +633,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
messages = self._get_messages(prompt)
possible_sys_turn = self.transform_message(messages[0])
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
@@ -681,7 +643,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in messages:
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)
turn = {
@@ -699,7 +661,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return turns
def transform_message(self, message: dict) -> dict:
def transform_message(self, message):
# Build the initial transformed message from the mappings
transformed_message = {}
for key, value in self.prompter.message_property_mappings.items():
@@ -776,135 +738,18 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return transformed_message
def _get_images(self, prompt):
def get_images(self, prompt):
return prompt.get(self.images, None)
def _get_tools(self, prompt) -> list[dict] | None:
"""Get tools from prompt if available."""
tools = prompt.get(self.prompter.field_tools, None)
if tools is None:
return None
if isinstance(tools, list):
return tools
raise ValueError(
"Unknown tools format. Please convert it into a list[dict].\n"
f"Current format: {type(tools)}"
)
def _get_messages(self, prompt):
messages = prompt.get(self.prompter.field_messages, None)
if messages is None:
raise ValueError("Messages is null. Please check `field_messages`.")
if isinstance(messages, list):
return messages
raise ValueError(
"Unknown messages format. Please convert it into a list[dict].\n"
f"Current format: {type(messages)}"
)
class MistralStrategy(ChatTemplateStrategy):
"""
Mistral strategy for chat template.
"""
def __init__(
self,
prompter: "ChatTemplatePrompter",
tokenizer: "HFMistralTokenizer",
train_on_inputs: bool,
sequence_len: int,
roles_to_train: list[str] | None = None,
train_on_eos: str | None = None,
train_on_eot: str | None = None,
eot_tokens: list[str] | None = None,
split_thinking: bool | None = False,
):
# Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation
# pylint: disable=non-parent-init-called,super-init-not-called
PromptTokenizingStrategy.__init__(
self, prompter, tokenizer, train_on_inputs, sequence_len
)
self.prompter: ChatTemplatePrompter = prompter
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.train_on_eos = train_on_eos
# Backward compatibility, load from train_on_eos
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = []
if eot_tokens is not None:
self.eot_tokens = eot_tokens
else:
# set eot_tokens to the eos_token
self.eot_tokens = [self.tokenizer.eos_token]
self.split_thinking = split_thinking
self.images = "images"
LOG.debug(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)
# Skip the validation that ChatTemplateStrategy calls
# TODO: address this in the future with mistral-specific checks
# self._validate_eot_and_eos_tokens()
@property
def supports_multiprocessing(self) -> bool:
"""
Whether this tokenizing strategy supports multiprocessing.
mistral_common tokenizers cannot be pickled for multiprocessing.
"""
return False
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# mistral-common tokenizer does not support eot_tokens
return self.find_first_eos_token(input_ids, start_idx)
class MistralPrompter(ChatTemplatePrompter):
"""
Mistral prompter for chat template.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"])
class StrategyLoader:
"""
Load chat template strategy based on configuration.
"""
def _get_strategy_cls(self, cfg):
if cfg.tokenizer_use_mistral_common:
return MistralStrategy
def _get_strategy_cls(self):
return ChatTemplateStrategy
def _get_prompter_cls(self, cfg):
if cfg.tokenizer_use_mistral_common:
return MistralPrompter
return ChatTemplatePrompter
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
return {
"train_on_inputs": cfg.train_on_inputs,
@@ -930,14 +775,9 @@ class StrategyLoader:
else:
dataset_config = ds_cfg
if cfg.tokenizer_use_mistral_common:
# mistral-common does not use this, so we pass an empty string
chat_template_string = ""
else:
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
@@ -963,11 +803,10 @@ class StrategyLoader:
}
strategy_params = self._get_strategy_params(cfg, dataset_config)
strategy_cls = self._get_strategy_cls(cfg)
prompter_cls = self._get_prompter_cls(cfg)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
prompter_cls(**prompter_params),
ChatTemplatePrompter(**prompter_params),
tokenizer=tokenizer,
**strategy_params,
)

View File

@@ -46,14 +46,6 @@ def default(
)
messages = sample[field_messages]
if isinstance(messages, str):
messages = [
{
message_property_mappings["role"]: "user",
message_property_mappings["content"]: messages,
}
]
messages = [
{
"role": role_map[m[message_property_mappings["role"]]],
@@ -61,35 +53,13 @@ def default(
}
for m in messages
]
chosen_raw = sample[field_chosen]
if isinstance(chosen_raw, str):
chosen_msg = {
message_property_mappings["role"]: "assistant",
message_property_mappings["content"]: chosen_raw,
}
elif isinstance(chosen_raw, dict):
chosen_msg = chosen_raw
else:
chosen_msg = chosen_raw[-1]
chosen = {
"role": role_map[chosen_msg[message_property_mappings["role"]]],
"content": chosen_msg[message_property_mappings["content"]],
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
"content": sample[field_chosen][message_property_mappings["content"]],
}
rejected_raw = sample[field_rejected]
if isinstance(rejected_raw, str):
rejected_msg = {
message_property_mappings["role"]: "assistant",
message_property_mappings["content"]: rejected_raw,
}
elif isinstance(rejected_raw, dict):
rejected_msg = rejected_raw
else:
rejected_msg = rejected_raw[-1]
rejected = {
"role": role_map[rejected_msg[message_property_mappings["role"]]],
"content": rejected_msg[message_property_mappings["content"]],
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
"content": sample[field_rejected][message_property_mappings["content"]],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

View File

@@ -32,3 +32,4 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -3,7 +3,6 @@
import abc
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import Dataset
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter
@@ -29,16 +28,6 @@ class DatasetWrappingStrategy(abc.ABC):
Abstract class for wrapping datasets for Chat Messages
"""
@abc.abstractmethod
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
pass
class PromptTokenizingStrategy(abc.ABC):
"""
@@ -70,14 +59,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self):
return False
@property
def supports_multiprocessing(self):
"""
Whether this tokenizing strategy supports multiprocessing.
Should return False if the tokenizer has unpicklable objects.
"""
return True
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:

View File

@@ -1,13 +1,10 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
from __future__ import annotations
import importlib
import inspect
import os
import signal
import sys
import typing
import weakref
from contextlib import ExitStack
from pathlib import Path
@@ -34,7 +31,7 @@ from axolotl.loaders import (
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.ctx_managers import ContextParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -47,9 +44,6 @@ try:
except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
@@ -58,8 +52,8 @@ def setup_model_and_tokenizer(
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""Load the tokenizer, processor (for multimodal models), and model based on
configuration.
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -153,7 +147,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
@@ -207,15 +201,20 @@ def execute_training(
)
)
if cfg.sequence_parallel_degree > 1:
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
# Models to enter context parallel manager for
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
# Attention backend
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
stack.enter_context(
SequenceParallelContextManager(
ContextParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
backend=backend,
context_parallel_degree=cfg.context_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
@@ -229,7 +228,7 @@ def execute_training(
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
safe_serialization: bool,
):
"""
@@ -380,7 +379,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
peft_config: PeftConfig | None,
processor: ProcessorMixin | None,
):
@@ -434,7 +433,7 @@ def setup_model_card(cfg: DictDefault):
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
model: PeftModel | PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
@@ -477,7 +476,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,

View File

@@ -52,10 +52,3 @@ def patch_optimized_env():
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value
"""
return value if value is not None else default

View File

@@ -53,6 +53,25 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__)
class EvalFirstStepCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""
Callback to trigger evals on the first step
"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any, List
from typing import Any
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -81,11 +81,9 @@ class DataCollatorForSeq2Seq:
padding_side = self.tokenizer.padding_side
for feature in features:
remainder_len = max_feature_length - len(feature[feature_name])
if feature_name == "position_ids":
remainder = list(range(remainder_len))
else:
remainder = [pad_token_id] * remainder_len
remainder = [pad_token_id] * (
max_feature_length - len(feature[feature_name])
)
if isinstance(feature[feature_name], list):
feature[feature_name] = (
feature[feature_name] + remainder
@@ -163,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features: List[List[dict]] = [features]
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():

View File

@@ -1,6 +1,5 @@
"""Init for context manager submodule"""
"""Init for context manager submodule."""
# pylint: disable=unused-import
# flake8: noqa
from .context_parallel.manager import ContextParallelContextManager
from .sequence_parallel import SequenceParallelContextManager
__all__ = ["ContextParallelContextManager"]

View File

@@ -0,0 +1,146 @@
# BSD 3-Clause License
# Copyright 2024 Meta
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,this list
# of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may
# be used to endorse or promote products derived from this software without specific
# prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
"""
Distributed utils for SDPA context parallel implementation. Slightly modified from
https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
"""
import contextlib
from typing import Callable, Generator, Optional, Union
import torch
from torch import nn
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask
def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
):
"""
Creates a context manager that confines SDPA to the flash/efficient/cuDNN attention backends.
Returns:
A context manager function that takes an optional context parallel context.
"""
@contextlib.contextmanager
def context(cp_context: Union[Generator[None, None, None], None] = None):
with contextlib.ExitStack() as stack:
if cp_context is not None:
stack.enter_context(
sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
]
)
)
stack.enter_context(cp_context)
yield
return context
def get_context_parallel_manager(
*,
world_mesh: torch.distributed.DeviceMesh,
model: nn.Module,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
"""
Context manager for applying context parallelism to a model. In addition to applying the
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
Args:
world_mesh: Global device mesh.
model: Model to apply context parallelism to.
Returns:
A context manager that applies context parallelism across the "cp" mesh dimension
and restricts SDPA to the flash/efficient/cuDNN backends.
Raises:
ValueError: if world_mesh does not contain a "cp" dimension.
"""
if "cp" not in world_mesh.mesh_dim_names:
raise ValueError(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work
# if not isinstance(model, TransformerDecoder):
# raise ValueError("Context parallel is only supported for text models")
# model_buffers = list(model.buffers())
# def get_all_buffers(module, prefix=""):
# buffers = {}
# for name, buffer in module.named_buffers(recurse=False):
# full_name = f"{prefix}.{name}" if prefix else name
# buffers[full_name] = buffer
# for name, child in module.named_children():
# child_prefix = f"{prefix}.{name}" if prefix else name
# buffers.update(get_all_buffers(child, child_prefix))
# return buffers
# model_buffers = get_all_buffers(model)
@contextlib.contextmanager
def context(model_inputs: list[torch.Tensor]):
# Create context parallel context if enabled
cp_context = None
if any([isinstance(input, BlockMask) for input in model_inputs]):
raise ValueError(
"Context parallel with flex attention is not yet supported"
)
set_rotate_method("allgather")
cp_context = context_parallel(
world_mesh["cp"],
# buffers=model_inputs + model_buffers,
buffers=model_inputs,
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
buffer_seq_dims=[1] * len(model_inputs),
no_restore_buffers=set(model_inputs),
)
# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()
with sdpa_context(cp_context):
yield
return context
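A usage sketch of the returned context manager, assuming `world_mesh`, `model`, and `batch` are already set up (names are illustrative, not taken from this diff):
cp_manager = get_context_parallel_manager(world_mesh=world_mesh, model=model)

# Shard these tensors along the sequence dimension and run SDPA under context parallelism
model_inputs = [batch["input_ids"], batch["labels"], batch["position_ids"]]
with cp_manager(model_inputs):
    outputs = model(**batch)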

View File

@@ -0,0 +1,216 @@
"""Module for Axolotl trainer context parallelism manager and utilities."""
import functools
import inspect
from typing import Callable, Literal
import torch
import torch.distributed as dist
from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
)
from axolotl.utils.ctx_managers.context_parallel.distributed import (
get_context_parallel_manager,
)
from axolotl.utils.ctx_managers.context_parallel.utils import (
AllGatherWithGrad,
apply_context_parallelism,
)
from axolotl.utils.schemas.enums import RingAttnFunc
class ContextParallelContextManager:
"""Context manager for context parallelism operations.
This class provides a context that will automatically apply context parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the context parallelism group using a post-forward hook.
Args:
models: List of models to apply context parallelism to pre- and post- forward
hooks.
backend: Which attention backend to use.
context_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
self,
models: list[PreTrainedModel],
backend: Literal["sdp_attention", "flash_attention"],
context_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.backend = backend
self.context_parallel_degree = context_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
if self.backend == "flash_attention":
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
else:
# SDPA device mesh init
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# SDPA context parallel managers
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
def __enter__(self):
self._register_model_hooks()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
if self.backend == "flash_attention":
# Initialize ring attn for context parallelism
register_ring_attn(
context_parallel_degree=self.context_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
def _register_model_hooks(self):
# Forward pre-hook to apply context parallelism
def cp_flash_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply context parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_context_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self._gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
if value.size(1) == self.original_seq_len + self.pad_len:
# Slice to remove padding
output[key] = value[:, : self.original_seq_len].contiguous()
return output
def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
to_shard = {k: v for k, v in updated_kwargs.items() if isinstance(v, torch.Tensor) and v.ndim > 1}
with self.context_parallel_managers[model_idx](list(to_shard.values())):
return remaining_args, updated_kwargs
return cp_sdpa_pre_hook
# Register both hooks
for i, model in enumerate(self.models):
if self.backend == "flash_attention":
self.hook_handles.append(
model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
)
self.hook_handles.append(
model.register_forward_hook(cp_flash_post_hook)
)
else:
self.hook_handles.append(
model.register_forward_pre_hook(
make_sdpa_pre_hook(i), with_kwargs=True
)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output

View File

@@ -1,28 +1,15 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import inspect
"""Utils for context parallel context manager."""
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
from axolotl.utils.schemas.enums import RingAttnFunc
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_sequence_parallelism(
def apply_context_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
@@ -30,15 +17,15 @@ def apply_sequence_parallelism(
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply sequence parallelism slicing to a batch.
Apply context parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates
to only keep the last N tokens in the sequence during generation.
to only keep the last N tokens in the input sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
local_rank: Local rank in the context parallel group.
local_world_size: World size of the context parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
@@ -133,7 +120,7 @@ def apply_sequence_parallelism(
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for sequence parallel
# Slice batch for context parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
@@ -159,144 +146,6 @@ def apply_sequence_parallelism(
return batch, original_seq_len, pad_len
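An illustrative view of the per-rank slicing (assuming the contiguous split used for the ring/llama3 functions; values are hypothetical):
# total_seq_len = 8, local_world_size = 2
# rank 0 keeps batch["input_ids"][:, 0:4]; rank 1 keeps batch["input_ids"][:, 4:8]
# outputs are later re-assembled across the group by the post-forward all-gather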
class SequenceParallelContextManager:
"""Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
Args:
models: List of models to apply sequence parallelism to pre- and post- forward
hooks.
sequence_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
# Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
self._register_model_hooks()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)
def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self._gather_outputs(output)
# Remove padding if it was added
if self.pad_len > 0:
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
if value.size(1) == self.original_seq_len + self.pad_len:
# Slice to remove padding
output[key] = value[:, : self.original_seq_len].contiguous()
return output
# Register both hooks
for model in self.models:
self.hook_handles.append(
model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output
class AllGatherWithGrad(torch.autograd.Function):
"""Custom autograd function for all-gather to preserve gradients."""

View File

@@ -1,21 +1,16 @@
"""Init for `axolotl.utils.data` module."""
"""
Data processing modules
"""
from axolotl.utils.data.pretraining import (
from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
prepare_datasets,
load_prepare_datasets,
load_tokenized_prepared_datasets,
prepare_dataset,
)
from axolotl.utils.data.utils import md5
__all__ = [
"encode_pretraining",
"wrap_pretraining_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",
"md5",
]
from axolotl.utils.data.utils import md5 # noqa: F401

View File

@@ -1,66 +0,0 @@
"""Logic for loading / preparing a dataset once over all processes."""
import time
from pathlib import Path
from typing import Any, Callable
from filelock import FileLock
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.dict import DictDefault
LOCK_FILE_NAME = "datasets_prep.lock"
READY_FILE_NAME = "datasets_ready.flag"
PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
class FileLockLoader:
"""
Simple class for abstracting single-process data loading / processing. The first
process that creates a lock file does the work; the remaining processes simply load
the preprocessed dataset once the first process is done.
"""
def __init__(self, cfg: DictDefault):
self.cfg = cfg
self.dataset_prepared_path = (
cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
)
self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME
def load(self, load_fn: Callable[[], Any]) -> Any:
with FileLock(str(self.lock_file_path)):
self._increment_counter()
if not self.ready_flag_path.exists():
result = load_fn()
self.ready_flag_path.touch()
return result
while not self.ready_flag_path.exists():
time.sleep(1)
return load_fn()
def _increment_counter(self):
"""Safely increment the process counter."""
if self.counter_path.exists():
count = int(self.counter_path.read_text().strip())
else:
count = 0
self.counter_path.write_text(str(count + 1))
def cleanup(self):
"""Clean up ready flag when last process is done."""
with FileLock(str(self.lock_file_path)):
count = int(self.counter_path.read_text().strip())
count -= 1
if count == 0:
# Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True)
else:
# Still have active processes
self.counter_path.write_text(str(count))
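A typical call pattern for this loader, sketched under the assumption that `cfg` is a populated `DictDefault` and `prepare_fn` is a placeholder for the real dataset-preparation callable:
loader = FileLockLoader(cfg)
try:
    dataset = loader.load(lambda: prepare_fn(cfg))  # prepare_fn is hypothetical
finally:
    loader.cleanup()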

View File

@@ -250,7 +250,7 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
train_dataset,

View File

@@ -1,117 +1,75 @@
"""Data handling specific to RL trainers."""
"""data handling specific to DPO"""
import inspect
from functools import partial
from typing import Any, Callable, Literal
from pathlib import Path
from typing import Any, List, Union
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.loaders import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
generate_dataset_hash_from_config,
load_dataset_with_config,
load_preprocessed_dataset,
merge_datasets,
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
retry_on_request_exceptions,
)
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_preference_datasets(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> tuple[Dataset, Dataset | None]:
"""Load and prepare preference datasets for RL training.
def _get_path(ds_hash, cfg):
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
)
Loads training and evaluation datasets, handling preprocessing, caching, and
deduplication as configured. Uses FileLock for distributed coordination.
Args:
cfg: Configuration object containing dataset and training settings.
tokenizer: Tokenizer to use for processing text.
Returns:
Tuple of (train_dataset, eval_dataset). eval_dataset may be None
if no evaluation dataset is configured.
"""
def _load_datasets():
# Load training dataset
train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train")
# Load or create evaluation dataset
eval_dataset: Dataset | None = None
if cfg.test_datasets:
eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test")
elif cfg.val_set_size:
# Create validation split from training data
train_dataset, eval_dataset = create_train_validation_split(
train_dataset, cfg, cfg.val_set_size
)
return train_dataset, eval_dataset
# Prepare datasets (with file locking logic for multiple ranks)
loader = FileLockLoader(cfg)
try:
train_dataset, eval_dataset = loader.load(_load_datasets)
finally:
loader.cleanup()
# Apply deduplication if configured
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=train_dataset, other_dataset=eval_dataset
)
return train_dataset, eval_dataset
return prepared_ds_path
def _map_dataset(
cfg: DictDefault,
dataset: Dataset | DatasetDict,
ds_transform_fn: Callable[..., Any],
tokenizer: Any | None = None,
**map_kwargs: Any,
) -> Dataset:
"""Apply transformation function to dataset.
def _load_preprocessed_ds(cfg, sub_cfg):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
dataset = None
Args:
cfg: Configuration object.
dataset: Dataset to transform.
ds_transform_fn: Transformation function to apply.
tokenizer: Optional tokenizer for transformation.
**map_kwargs: Additional arguments for dataset mapping.
# pylint: disable=duplicate-code
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
Returns:
Transformed dataset.
"""
return dataset
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(dataset, DatasetDict):
dataset = dataset["train"]
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
dataset = dataset.map(
data_set = data_set.map(
ds_transform_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
@@ -119,27 +77,13 @@ def _map_dataset(
**map_kwargs,
)
return dataset
return data_set
def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
"""Filter out samples that exceed maximum sequence length.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
True if sample should be kept, False if it should be dropped.
Raises:
ValueError: If required keys are missing or RL type is unknown.
"""
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -179,115 +123,132 @@ def _drop_long_sequences(
raise ValueError("Unknown RL type")
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
Args:
cfg: Configuration object containing dataset settings.
split: Dataset split to load ("train" or "test").
tokenizer = load_tokenizer(cfg)
Returns:
Combined and processed dataset for the specified split.
"""
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
split_datasets: list[Dataset | DatasetDict] = []
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
for dataset_config in datasets_with_name_generator(datasets_configs):
dataset: Dataset | DatasetDict = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=False
)
split_datasets.append(dataset)
tokenizer = load_tokenizer(cfg)
for i, dataset in enumerate(split_datasets):
_type = datasets_configs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if cfg.rl is RLType.ORPO:
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
elif cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
map_kwargs: dict[str, Any] = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = _map_dataset(
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
)
if not cfg.skip_prepare_dataset:
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
return combined_datasets
with zero_first(is_main_process()):
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen", and "rejected" already preprocessed
split_datasets[i] = dataset
train_dataset = load_split(cfg.datasets, cfg)
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
eval_dataset = None
if cfg.test_datasets:
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
# ensure we end up with the same fingerprint by running rank 0 first so the split can be cached
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if not cfg.skip_prepare_dataset:
# Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset
# pylint: disable=duplicate-code
def _load_or_create_dataset_split(
cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"]
) -> Dataset:
"""Load preprocessed dataset or create new one for given split.
Args:
cfg: Configuration object.
tokenizer: Tokenizer to use for processing text.
split: Dataset split to load.
Returns:
Tuple of (dataset, is_preprocessed).
"""
# Select correct dataset configuration based on split
datasets_config = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_config, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# Attempt to load preprocessed dataset
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# Otherwise, load it
if dataset is None:
dataset = _load_split(cfg, split=split)
return dataset
return train_dataset, eval_dataset
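When no `type` transform is configured, each row is expected to already carry the preference fields; an illustrative row (values invented) looks like:
row = {
    "prompt": "Summarize: The cat sat on the mat.",
    "chosen": "A cat sat on a mat.",
    "rejected": "Cats are animals.",
}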

File diff suppressed because it is too large Load Diff

View File

@@ -1,21 +1,11 @@
"""Dataset loading shared utils."""
"""
dataset loading shared utils
"""
from __future__ import annotations
import functools
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator
from typing import Optional, Union
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import (
HFValidationError,
@@ -23,141 +13,78 @@ from huggingface_hub.errors import (
RevisionNotFoundError,
)
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from adlfs import AzureBlobFileSystem
from gcsfs import GCSFileSystem
from ocifs import OCIFileSystem
from s3fs import S3FileSystem
LOG = get_logger(__name__)
EXTENSIONS_TO_DATASET_TYPES = {
".parquet": "parquet",
".arrow": "arrow",
".csv": "csv",
".txt": "text",
}
def get_dataset_type(dataset_config: DictDefault) -> str:
"""Get the dataset type from the path if it's not specified."""
if dataset_config.ds_type:
return dataset_config.ds_type
for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():
if extension in dataset_config.path:
return dataset_type
return "json"
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def datasets_with_name_generator(
dataset_configs: list[DictDefault],
) -> Generator[DictDefault, None, None]:
"""Yields expanded dataset configurations based on multiple names or preprocessing
shards.
When a dataset config has a list of names, it yields separate configs for each
name. When a dataset config specifies preprocessing shards, it yields configs for
each shard.
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
"""
Yields dataset configs handling multiple names or preprocess_shards
Args:
dataset_configs: List of dataset configuration objects.
Yields:
Individual dataset configurations, expanded as needed for names or shards.
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
"""
for config in dataset_configs:
if config.name and isinstance(config.name, list):
for name in config.name:
yield DictDefault({**config, "name": name})
elif config.preprocess_shards and not config.shards:
for shard_idx in range(config.preprocess_shards):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**config,
"shards": config.preprocess_shards,
"shards_idx": shard_idx,
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield config
yield dataset
def load_dataset_with_config(
dataset_config: DictDefault, use_auth_token: bool, streaming=False
) -> Dataset | IterableDataset:
"""Load a dataset from a config. Handles datasets that are stored locally, in the
HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
`data_files`.
def load_dataset_w_config(
config_dataset: DictDefault, use_auth_token: bool, streaming=False
) -> Union[Dataset, DatasetDict]:
"""
Load a dataset from a config
Args:
dataset_config: Single dataset config.
use_auth_token: Whether to use HF auth token.
streaming: Whether to stream the dataset.
Returns:
Loaded dataset.
config_dataset: single dataset config
use_auth_token: whether to use HF auth token
streaming: whether to stream the dataset
"""
# Set up common kwargs for dataset loading
load_dataset_kwargs = {
"split": dataset_config.split if dataset_config.split else None,
"name": dataset_config.name,
"streaming": streaming,
"trust_remote_code": dataset_config.trust_remote_code,
}
# First check if it's a local path
if Path(dataset_config.path).exists():
return _load_from_local_path(dataset_config, load_dataset_kwargs)
# Check if it's a HuggingFace dataset
is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)
# Check if it's a cloud storage path and get appropriate filesystem
remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
is_cloud_dataset = False
if remote_fs:
try:
is_cloud_dataset = remote_fs.exists(dataset_config.path)
except (FileNotFoundError, ConnectionError):
pass
# Load from appropriate source
if is_hub_dataset:
return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
if is_cloud_dataset:
return _load_from_cloud(
dataset_config, remote_fs, storage_options, load_dataset_kwargs
)
if dataset_config.path.startswith("https://"):
return _load_from_url(dataset_config, load_dataset_kwargs)
if dataset_config.data_files:
return _load_from_data_files(dataset_config, load_dataset_kwargs)
raise ValueError(
f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({dataset_config.path}). Try double-check your path / name / data_files. "
f"This is not caused by the dataset type."
)
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
"""Check if a dataset exists on the HuggingFace Hub."""
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
snapshot_download(
repo_id=dataset_config.path,
repo_id=config_dataset.path,
repo_type="dataset",
token=use_auth_token,
revision=dataset_config.revision,
revision=config_dataset.revision,
ignore_patterns=["*"],
)
return True
ds_from_hub = True
except (
RepositoryNotFoundError,
RevisionNotFoundError,
@@ -166,373 +93,198 @@ def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) ->
HFValidationError,
ValueError,
):
return False
pass
def _get_remote_filesystem(
path: str,
) -> tuple[
S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
]:
"""Get the appropriate filesystem for a remote path."""
if path.startswith("s3://"):
ds_from_cloud = False
storage_options: dict = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import s3fs
storage_options = {"anon": False}
return s3fs.S3FileSystem(**storage_options), storage_options
import s3fs # type: ignore
except ImportError as exc:
raise ImportError("s3:// paths require s3fs to be installed") from exc
elif path.startswith(("gs://", "gcs://")):
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
storage_options = {"anon": False}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
try:
import gcsfs
storage_options = {"token": None} # type: ignore
return gcsfs.GCSFileSystem(**storage_options), storage_options
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
elif path.startswith(("adl://", "abfs://", "az://")):
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
elif (
config_dataset.path.startswith("adl://")
or config_dataset.path.startswith("abfs://")
or config_dataset.path.startswith("az://")
):
try:
import adlfs
storage_options = {"anon": False}
return adlfs.AzureBlobFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(
"adl:// or abfs:// paths require adlfs to be installed"
) from exc
elif path.startswith("oci://"):
# # Ensure you have the following environment variables set:
# # Gen 1
# storage_options = {
# "tenant_id": AZURE_STORAGE_TENANT_ID,
# "client_id": AZURE_STORAGE_CLIENT_ID,
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
# }
# Reads env
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
storage_options = {"anon": False}
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
elif config_dataset.path.startswith("oci://"):
try:
import ocifs
storage_options = {}
return ocifs.OCIFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError("oci:// paths require ocifs to be installed") from exc
return None, {}
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
remote_file_system = ocifs.OCIFileSystem(**storage_options)
def _load_from_local_path(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from a local path."""
local_path = Path(dataset_config.path)
if local_path.is_dir():
if dataset_config.data_files:
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.data_files,
**load_dataset_kwargs,
)
try:
return load_from_disk(dataset_config.path)
except FileNotFoundError:
load_dataset_kwargs["streaming"] = False
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
load_dataset_kwargs["streaming"] = False
return load_dataset(
dataset_type,
data_files=dataset_config.path,
**load_dataset_kwargs,
)
else:
raise ValueError(
"Unhandled dataset load: local path exists, but is neither a directory or a file"
)
def _load_from_hub(
dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from the HuggingFace Hub."""
return load_dataset(
dataset_config.path,
data_files=dataset_config.data_files,
token=use_auth_token,
revision=dataset_config.revision,
**load_dataset_kwargs,
)
def _load_from_cloud(
dataset_config: DictDefault,
remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
storage_options: dict,
load_dataset_kwargs: dict,
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from cloud storage."""
if remote_fs.isdir(dataset_config.path):
return load_from_disk(
dataset_config.path,
storage_options=storage_options,
)
if remote_fs.isfile(dataset_config.path):
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.path,
storage_options=storage_options,
**load_dataset_kwargs,
)
raise ValueError(
f"Cloud path {dataset_config.path} is neither a directory nor a file"
)
def _load_from_url(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from a URL."""
dataset_type = get_dataset_type(dataset_config)
return load_dataset(
dataset_type,
data_files=dataset_config.path,
**load_dataset_kwargs,
)
def _load_from_data_files(
dataset_config: DictDefault, load_dataset_kwargs: dict
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
"""Load a dataset from data files."""
file_path = None
if isinstance(dataset_config.data_files, str):
file_path = hf_hub_download(
repo_id=dataset_config.path,
repo_type="dataset",
filename=dataset_config.data_files,
revision=dataset_config.revision,
)
elif isinstance(dataset_config.data_files, list):
file_path = [
hf_hub_download(
repo_id=dataset_config.path,
repo_type="dataset",
filename=file,
revision=dataset_config.revision,
)
for file in dataset_config.data_files
]
else:
raise ValueError("data_files must be either a string or list of strings")
return load_dataset("json", data_files=file_path, **load_dataset_kwargs)
def generate_split_fingerprints(
dataset: Dataset, val_set_size: int | float, seed: int
) -> tuple[str, str]:
"""Generate consistent fingerprints for train/test splits."""
fingerprint = dataset._fingerprint # pylint: disable=protected-access
train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}"
test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}"
train_fingerprint = md5(train_hash_input)
test_fingerprint = md5(test_hash_input)
return train_fingerprint, test_fingerprint
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
"""Get standardized path for prepared datasets.
Args:
cfg: Configuration object.
dataset_hash: Hash identifying the specific dataset configuration.
Returns:
Path where the prepared dataset should be stored.
"""
base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
return Path(base_path) / dataset_hash
def create_train_validation_split(
dataset: Dataset, cfg: DictDefault, val_set_size: int | float
) -> tuple[Dataset, Dataset]:
"""Create train/validation split with consistent fingerprinting.
Args:
dataset: Dataset to split.
cfg: Configuration object containing seed and other settings.
val_set_size: Size of validation set (absolute number or fraction).
Returns:
Tuple of (train_dataset, eval_dataset).
"""
train_fingerprint, test_fingerprint = generate_split_fingerprints(
dataset, val_set_size, cfg.seed
)
# Apply deduplication before splitting if configured
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
split_dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
seed=cfg.seed,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
return split_dataset["train"], split_dataset["test"]
def _generate_from_iterable_dataset(
dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
) -> Generator[Any, None, None]:
"""Generator function to correctly split the dataset for each worker"""
for i, item in enumerate(dataset):
if i % num_workers[0] == worker_id[0]:
yield item
def save_preprocessed_dataset(
cfg: DictDefault,
dataset: Dataset,
dataset_hash: str,
split: str,
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
features=dataset.features,
num_proc=num_workers,
split=split,
gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
"Pushing merged prepared dataset to Huggingface hub at "
f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
main_process_only=False,
)
dataset.push_to_hub(
cfg.push_dataset_to_hub,
dataset_hash,
private=True,
)
def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
"""Load preprocessed dataset from disk if available.
Args:
cfg: Configuration object.
dataset_hash: Hash identifying the dataset configuration.
Returns:
Loaded dataset if found and conditions are met, None otherwise.
"""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.skip_prepare_dataset
and not cfg.is_preprocess
):
LOG.info(
f"Loading prepared dataset from disk at {prepared_ds_path}...",
main_process_only=False,
)
return load_from_disk(str(prepared_ds_path))
LOG.info(
f"Unable to find prepared dataset in {prepared_ds_path}",
main_process_only=False,
)
return None
def try_load_from_hub(
cfg: DictDefault, dataset_hash: str, split: str
) -> Dataset | None:
"""Try to load the prepared dataset from HuggingFace Hub."""
try:
LOG.info(
"Attempting to load prepared dataset from HuggingFace Hub at "
f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
)
dataset = load_dataset(
cfg.push_dataset_to_hub,
dataset_hash,
token=cfg.hf_use_auth_token,
)
return dataset[split]
except Exception: # pylint: disable=broad-except # nosec
LOG.info("Unable to find prepared dataset in HuggingFace Hub")
return None
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
def generate_dataset_hash_from_config(
cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
) -> str:
"""Generate a hash to uniquely identify a dataset configuration for SFT.
Args:
cfg: Main configuration object.
cfg_datasets: List of dataset configurations.
tokenizer_name: Name of the tokenizer being used.
Returns:
MD5 hash string representing the configuration.
"""
config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}"
)
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
datasets: List of datasets to merge.
cfg: Configuration object containing shuffle settings.
Returns:
Merged dataset.
"""
if len(datasets) == 1:
return datasets[0]
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
LOG.debug("Not shuffling merged datasets.")
load_ds_kwargs["split"] = None
return merged_dataset
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=streaming,
**load_ds_kwargs,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=streaming,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.data_files:
fp: str | list[str] | None = None
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=streaming,
**load_ds_kwargs,
)
if not ds:
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
)
return ds
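As a rough sketch of what a single dataset config looks like when it reaches this loader (path and values are invented; field names follow the code above, and either spelling of the loader shown in this diff takes the same shape of argument):
from axolotl.utils.dict import DictDefault

config_dataset = DictDefault(
    {
        "path": "s3://my-bucket/datasets/train.parquet",  # invented path
        "ds_type": None,  # inferred as "parquet" from the extension
        "name": None,
        "split": "train",
        "data_files": None,
        "revision": None,
        "trust_remote_code": False,
    }
)
dataset = load_dataset_w_config(config_dataset, use_auth_token=False, streaming=False)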

View File

@@ -1,11 +1,9 @@
"""Data handling helpers"""
"""data handling helpers"""
import contextlib
import functools
import hashlib
import time
from enum import Enum
from typing import Callable
import huggingface_hub
import numpy as np
@@ -21,7 +19,9 @@ LOG = get_logger(__name__)
class RetryStrategy(Enum):
"""Enum for retry strategies."""
"""
Enum for retry strategies.
"""
CONSTANT = 1
LINEAR = 2
@@ -30,18 +30,7 @@ class RetryStrategy(Enum):
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
) -> Callable:
"""Decorator that retries function calls on specific request exceptions.
Args:
max_retries: Maximum number of retry attempts.
delay: Base delay between retries in seconds.
retry_strategy: Strategy for calculating retry delays.
Returns:
Decorated function with retry logic.
"""
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
@@ -51,7 +40,6 @@ def retry_on_request_exceptions(
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
@@ -71,7 +59,6 @@ def retry_on_request_exceptions(
def md5(to_hash: str, encoding: str = "utf-8") -> str:
"""Generate MD5 hash of a string."""
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
@@ -79,89 +66,102 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
"""Generate SHA256 hash of a string."""
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
def _deduplicate_dataset(
dataset: Dataset,
seen_hashes: set[str] | None = None,
) -> tuple[Dataset, set[str]]:
"""Remove duplicate rows from a dataset using SHA256 hashes.
Args:
dataset: Dataset to deduplicate.
seen_hashes: Set of previously seen row hashes (for cross-deduplication).
Returns:
Tuple of deduplicated dataset and the set of seen hashes.
"""
if seen_hashes is None:
seen_hashes = set()
def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
) -> Dataset:
unique_indices = []
for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance
if row_hash not in seen_hashes:
seen_hashes.add(row_hash)
unique_indices.append(idx)
return dataset.select(unique_indices), seen_hashes
for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)
def deduplicate_and_log_datasets(
dataset: Dataset,
other_dataset: Dataset | None = None,
dataset_name: str | None = "train",
other_name: str | None = "eval",
) -> tuple[Dataset, Dataset | None]:
"""Deduplicate datasets, with optional cross-dataset deduplication.
Args:
dataset: Primary dataset to deduplicate.
other_dataset: Optional second dataset to deduplicate against the first.
dataset_name: Name for the primary dataset (for logging).
other_name: Name for the second dataset (for logging).
*,
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
Returns:
Tuple of (deduplicated_dataset, deduplicated_other_dataset).
tuple: Deduplicated train, eval, and additional datasets.
"""
# Deduplicate primary dataset
LOG.info(
f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
)
dataset, seen_rows = _deduplicate_dataset(dataset)
LOG.info(
f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
)
seen_hashes: dict[str, list[int]] = {}
# Deduplicate second dataset if provided
if other_dataset is not None:
# Handle cases where datasets are None
if train_dataset is not None:
LOG.info(
f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
)
other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
LOG.info(
f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")
if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")
if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
)
return dataset, other_dataset
return train_dataset, eval_dataset, dataset
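The deduplication key is just a hash of the stringified row; conceptually, the filtering loop reduces to the following simplified sketch (not additional code from this diff; `rows` stands in for any iterable of dataset rows):
import hashlib

def row_key(row: dict) -> str:
    # Same idea as sha256(str(row)) above.
    return hashlib.sha256(str(row).encode("utf-8")).hexdigest()

seen: set[str] = set()
unique_indices = []
for idx, row in enumerate(rows):
    key = row_key(row)
    if key not in seen:
        seen.add(key)
        unique_indices.append(idx)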
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
dataset: Dataset to filter.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Filtered dataset with long sequences removed.
"""
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
@@ -171,14 +171,20 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
min_sequence_len=cfg.min_sample_len,
)
with contextlib.suppress(AttributeError):
try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError:
pass
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):

View File

@@ -1,425 +0,0 @@
"""Data handling specific to SFT."""
import logging
from typing import Any, NoReturn, cast
from datasets import (
Dataset,
IterableDataset,
Sequence,
Value,
)
from transformers import PreTrainedTokenizer
from transformers.processing_utils import ProcessorMixin
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
PromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter,
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:
"""Raise error for unknown dataset strategy."""
ds_type = dataset_config.type
suffix = ""
if ":load_" in ds_type:
suffix = f"Did you mean {ds_type.replace(':load_', '.load_')}?"
error_message = f"unhandled prompt tokenization strategy: {ds_type}. {suffix}"
LOG.error(error_message)
raise ValueError(error_message)
# pylint: disable=too-many-return-statements
def get_dataset_wrapper(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset_base_type: str | None,
dataset: Dataset | IterableDataset,
dataset_prompt_style: str | None = None,
processor: ProcessorMixin | None = None, # pylint: disable=unused-argument
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Create an appropriate dataset wrapper and prompter based on dataset
configuration.
Args:
dataset_config: Configuration for the dataset.
tokenizer: Tokenizer to use for processing text.
cfg: Global configuration object.
dataset_base_type: The base type of the dataset.
dataset: The actual dataset object.
dataset_prompt_style: Optional prompt style specification.
processor: Optional processor for multimodal datasets.
Returns:
tuple of (dataset_wrapper, dataset_prompter).
"""
# Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_processes,
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}
LOG.info(
f"Loading dataset: {dataset_config['path']} with base_type: "
f"{dataset_base_type} and prompt_style: {dataset_prompt_style}"
)
# Dataset is already tokenized
if _is_dataset_already_tokenized(dataset):
return dataset, UnsupportedPrompter()
# Custom dataset type definition
if isinstance(dataset_config.type, DictDefault):
return _handle_custom_dataset_type(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Skip preparation if configured
if cfg.skip_prepare_dataset:
return dataset, None
# Bradley-Terry dataset
if dataset_config.type.startswith("bradley_terry"):
return _handle_bradley_terry_dataset(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Stepwise supervised dataset
if dataset_config.type.startswith("stepwise_supervised"):
return _handle_stepwise_supervised_dataset(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Try to load prompt tokenizer / dataset wrapper strategy from registry
dataset_strategy = load(
dataset_config.type, tokenizer, cfg, dataset_config, processor=processor
)
if dataset_strategy:
return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs)
# Known dataset types with specific handling
if dataset_base_type in DATASET_HANDLERS:
handler = DATASET_HANDLERS[dataset_base_type]
return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs)
# Unhandled dataset type
handle_unknown_dataset_strategy(dataset_config)
def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:
"""Check if the dataset is already tokenized."""
return (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
)
def _handle_custom_dataset_type(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a custom dataset type defined in the configuration."""
dataset_strategy = cast(
PromptTokenizingStrategy,
load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()),
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_bradley_terry_dataset(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Handle a Bradley-Terry dataset."""
bt_type = dataset_config.type.split(".", 1)[1]
dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)
if not dataset_strategy:
handle_unknown_dataset_strategy(dataset_config)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_stepwise_supervised_dataset(
dataset_config: DictDefault,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a stepwise supervised dataset."""
dataset_prompter = UnsupportedPrompter()
dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)
# We need to explicitly cast boolean labels to int
# for compatibility with how trl's PRMTrainer works
if isinstance(dataset, Dataset):
dataset = dataset.cast_column("labels", Sequence(Value("int64")))
dataset_wrapper = TokenizedPromptDataset(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_loaded_strategy(
dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Handle a dataset with a strategy loaded from the registry."""
if isinstance(dataset_strategy, DatasetWrappingStrategy):
return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_alpaca_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an Alpaca dataset."""
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
dataset_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_explainchoice_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an ExplainChoice dataset."""
dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style)
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_concisechoice_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a ConciseChoice dataset."""
dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style)
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_summarizetldr_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a SummarizeTLDR dataset."""
dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style)
dataset_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_jeopardy_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a Jeopardy dataset."""
dataset_prompter = JeopardyPrompter(dataset_prompt_style)
dataset_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_oasst_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle an OpenAssistant dataset."""
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
dataset_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_gpteacher_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a GPTeacher dataset."""
dataset_prompter = GPTeacherPrompter(dataset_prompt_style)
dataset_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
def _handle_reflection_dataset(
dataset_prompt_style: str | None,
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
dataset: Dataset | IterableDataset,
dataset_kwargs: dict[str, Any],
) -> tuple[Dataset | IterableDataset, Prompter]:
"""Handle a Reflection dataset."""
dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style)
dataset_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
dataset_strategy,
dataset,
**dataset_kwargs,
)
return dataset_wrapper, dataset_prompter
DATASET_HANDLERS = {
"alpaca": _handle_alpaca_dataset,
"explainchoice": _handle_explainchoice_dataset,
"concisechoice": _handle_concisechoice_dataset,
"summarizetldr": _handle_summarizetldr_dataset,
"jeopardy": _handle_jeopardy_dataset,
"oasst": _handle_oasst_dataset,
"gpteacher": _handle_gpteacher_dataset,
"reflection": _handle_reflection_dataset,
}

View File

@@ -1,567 +0,0 @@
"""Wrapper for MistralTokenizer from mistral-common"""
import math
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.core import IGNORE_INDEX
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import ChatCompletionRequest
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
"""Get the file path from local or HF Hub"""
if os.path.exists(path_or_repo_id):
maybe_file_path = os.path.join(path_or_repo_id, filename)
if os.path.exists(maybe_file_path):
return maybe_file_path
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
class HFMistralTokenizer:
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
"""
def __init__(
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
):
"""
Args:
mistral: The mistral-common tokenizer to wrap.
name_or_path: The name or path to the tokenizer files or the repo id.
tokenizer_path: The resolved path to the tokenizer model file (e.g. tekken.json).
"""
self._mistral = mistral
self._padding_side = "right"
self._name_or_path = name_or_path
self._tokenizer_path = tokenizer_path
# Manually switch the request validator to finetuning mode
from mistral_common.protocol.instruct.validator import (
MistralRequestValidator,
ValidationMode,
)
# Check if MistralRequestValidator has a _mode attribute.
# This is a private API and may change in the future.
# pylint: disable=protected-access
if not (
hasattr(self._mistral, "_chat_completion_request_validator")
and isinstance(
self._mistral._chat_completion_request_validator,
MistralRequestValidator,
)
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
):
raise RuntimeError(
"Unable to switch mistral tokenizer to finetuning mode "
"private API `_chat_completion_request_validator._mode` missing."
)
self._mistral._chat_completion_request_validator._mode = (
ValidationMode.finetuning
)
def _load_system_prompt(self, path_or_repo_id: str) -> str:
"""Load system prompt from local or HF Hub.
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
not provide one.
Args:
path_or_repo_id: The path to the tokenizer files or the repo id.
Returns:
The system prompt.
"""
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
if not os.path.exists(file_path):
raise FileNotFoundError(f"System prompt file not found at {file_path}")
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
@property
def bos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.bos_id
@property
def eos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.eos_id
@property
def pad_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.pad_id
@property
def unk_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.unk_id
@property
def bos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
@property
def eos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
@property
def pad_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
@property
def unk_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
@property
def padding_side(self) -> str:
return self._padding_side
@property
def name_or_path(self) -> str:
return self._name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
def __len__(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.n_words
@classmethod
def from_pretrained(
cls,
name_or_path: str,
*,
revision: Optional[str] = None,
**kwargs, # pylint: disable=unused-argument
) -> "HFMistralTokenizer":
"""
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
Args:
name_or_path: The path to the tokenizer files or the repo id.
revision: The revision of the tokenizer to download.
kwargs: Additional keyword arguments.
Returns:
A HFMistralTokenizer instance.
"""
if revision:
raise NotImplementedError(
"Revision not supported yet for mistral-common tokenizer"
)
# only support Tekken tokenizer for now
# downloads from HF Hub if not local
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
base = MistralTokenizer.from_file(tokenizer_path)
return cls(
base,
name_or_path=name_or_path,
tokenizer_path=tokenizer_path,
)
def save_pretrained(self, save_directory: str) -> None:
"""
Save the Tekken tokenizer model file so that from_pretrained can pick it up again.
Only Tekken models are supported.
Args:
save_directory: The directory to save the tokenizer files.
"""
inner = self._mistral.instruct_tokenizer.tokenizer
if isinstance(inner, Tekkenizer):
# Create the directory and save the model
try:
os.makedirs(save_directory, exist_ok=True)
# Verify directory was created
if not os.path.exists(save_directory):
raise RuntimeError(f"Failed to create directory: {save_directory}")
# Verify source file exists
if not os.path.exists(self._tokenizer_path):
raise FileNotFoundError(
f"Source tokenizer file not found: {self._tokenizer_path}"
)
destination_path = os.path.join(save_directory, "tekken.json")
copyfile(self._tokenizer_path, destination_path)
except Exception as e:
raise RuntimeError(
f"Failed to save tokenizer to {save_directory}: {e}. "
f"Source path: {self._tokenizer_path}, "
f"Directory exists: {os.path.exists(save_directory)}"
) from e
else:
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
"""
Encode a text string into a list of token IDs.
Args:
text: The text string to encode.
add_special_tokens: Whether to add special tokens to the encoded tokens.
Returns:
A list of token IDs.
"""
return self._mistral.instruct_tokenizer.tokenizer.encode(
text,
bos=add_special_tokens,
eos=add_special_tokens,
)
def decode(
self, token_ids: int | list[int], skip_special_tokens: bool = False
) -> str:
"""
Decode a list of token IDs into a text string.
Args:
token_ids: The int or list of token IDs to decode.
skip_special_tokens: Whether to skip special tokens in the decoded text.
Returns:
The decoded text string.
"""
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
# to_string returns a string with special tokens
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None
) -> "ChatCompletionRequest":
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
[]
)
for turn in conversation:
role = turn.get("role")
if role == "user":
messages.append(UserMessage(content=turn["content"]))
elif role == "assistant":
messages.append(
AssistantMessage(
content=turn.get("content"),
tool_calls=turn.get("tool_calls"),
)
)
elif role == "tool":
messages.append(
ToolMessage(
content=turn.get("content"),
tool_call_id=turn.get("tool_call_id"),
name=turn.get("name"),
)
)
elif role == "system":
messages.append(SystemMessage(content=turn["content"]))
else:
raise ValueError(
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
)
tool_calls: list[Tool] = []
if tools:
# convert to Tool
for tool in tools:
if tool["type"] != "function":
continue
function = tool["function"]
tool_calls.append(
Tool(
function=Function(
name=function["name"],
description=function["description"],
# set parameters to empty dict if not provided
parameters=function.get("parameters", {}),
)
)
)
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
messages=messages,
tools=tool_calls,
)
return chat_completion
def apply_chat_template(
self,
messages: list[dict],
tokenize: bool = True,
tools: list[dict] | None = None,
chat_template: str | None = None, # pylint: disable=unused-argument
add_generation_prompt: bool = False, # pylint: disable=unused-argument
) -> list[int] | str:
if chat_template:
raise NotImplementedError("chat_template not supported yet")
if add_generation_prompt:
raise NotImplementedError("add_generation_prompt not supported yet")
chat_completion: ChatCompletionRequest = (
self._create_mistral_chat_completion_request(messages, tools)
)
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
if tokenize:
return tokens
return self.decode(tokens)
def pad(
self,
features: list[dict[str, list[int] | np.ndarray]],
*,
padding: bool | str | PaddingStrategy = True,
max_length: int | None = None,
pad_to_multiple_of: int | None = None,
return_tensors: str | None = None, # "np", "pt", or "tf"
) -> dict[str, np.ndarray | Tensor]:
"""
HF-style pad method that properly handles all sequence-related features:
- pad 'input_ids' & 'labels' to the longest (or to max_length)
"""
import torch
from torch.nn import functional as F
# Check for unsupported fields
if any("token_type_ids" in f for f in features):
raise ValueError("token_type_ids is not supported by this tokenizer")
# Determine desired sequence length
lengths = [len(f["input_ids"]) for f in features]
if padding in (True, "longest", PaddingStrategy.LONGEST):
target_length = max(lengths)
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
if max_length is None:
raise ValueError("max_length must be set for 'max_length' padding")
target_length = max_length
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
target_length = None
else:
raise ValueError(f"Unknown padding strategy: {padding}")
# Apply pad_to_multiple_of
if target_length is not None and pad_to_multiple_of is not None:
target_length = (
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
)
# If no padding requested, just stack tensors
do_pad = target_length is not None
# Pad sequences using torch.nn.utils.rnn.pad_sequence
input_ids = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
)
labels = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=IGNORE_INDEX,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=0,
)
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
if "position_ids" in features[0]:
if self.padding_side == "left":
# Likely not needed, but keeping for now
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
position_ids = torch.nn.utils.rnn.pad_sequence(
[
torch.tensor(x["position_ids"], dtype=torch.long)
for x in features
],
batch_first=True,
padding_value=0,
)
else:
# For right padding, continue the sequence
max_pos_len = max(len(f["position_ids"]) for f in features)
position_ids_list = []
for f in features:
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
if len(pos_seq) < max_pos_len:
# Continue the sequence
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
pad_len = max_pos_len - len(pos_seq)
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
pos_seq = torch.cat([pos_seq, pad_positions])
position_ids_list.append(pos_seq)
position_ids = torch.stack(position_ids_list)
else:
# Create position_ids if not present
seq_len = input_ids.size(1)
position_ids = (
torch.arange(seq_len, dtype=torch.long)
.unsqueeze(0)
.expand(input_ids.size(0), -1)
)
# Ensure all tensors have the same sequence length
max_seq_len = max(
input_ids.size(1),
labels.size(1),
attention_mask.size(1),
position_ids.size(1),
)
# TODO: check whether trimming is needed and correct.
if do_pad and target_length is not None:
max_seq_len = target_length
# Pad all tensors to the same length
if input_ids.size(1) < max_seq_len:
pad_len = max_seq_len - input_ids.size(1)
if self.padding_side == "right":
input_ids = F.pad(
input_ids,
(0, pad_len),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
else:
input_ids = F.pad(
input_ids,
(pad_len, 0),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
elif input_ids.size(1) > max_seq_len:
input_ids = input_ids[:, :max_seq_len]
if labels.size(1) < max_seq_len:
pad_len = max_seq_len - labels.size(1)
if self.padding_side == "right":
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
else:
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
elif labels.size(1) > max_seq_len:
labels = labels[:, :max_seq_len]
if attention_mask.size(1) < max_seq_len:
pad_len = max_seq_len - attention_mask.size(1)
if self.padding_side == "right":
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
else:
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
elif attention_mask.size(1) > max_seq_len:
attention_mask = attention_mask[:, :max_seq_len]
if position_ids.size(1) < max_seq_len:
pad_len = max_seq_len - position_ids.size(1)
if self.padding_side == "right":
batch_size = position_ids.size(0)
new_position_ids = []
for i in range(batch_size):
seq = position_ids[i]
if len(seq) > 0:
# get last position and pad with sequential values
last_pos = seq[-1].item()
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
new_seq = torch.cat([seq, pad_positions])
else:
new_seq = torch.arange(pad_len, dtype=torch.long)
new_position_ids.append(new_seq)
position_ids = torch.stack(new_position_ids)
else:
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
elif position_ids.size(1) > max_seq_len:
position_ids = position_ids[:, :max_seq_len]
final_batch = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
# Handle non-sequence fields (raise error)
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
for f in features:
for key in f.keys():
if key not in sequence_fields:
raise NotImplementedError(
f"Non-sequence field {key} not handled yet"
)
# Convert to requested tensor type
if return_tensors is None or return_tensors == "np":
result = {}
for k, v in final_batch.items():
if isinstance(v, torch.Tensor):
result[k] = v.numpy().astype(np.int64)
else:
result[k] = v
return result
if return_tensors == "pt":
return final_batch
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
"""
Convert a list of token IDs to a list of tokens.
Args:
ids: The list of token IDs to convert.
Returns:
The list of tokens.
"""
return [
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
]
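
To make the wrapper's surface area concrete, a minimal usage sketch; the repo id and messages are placeholders, and any repository that ships a tekken.json should work:
# Illustrative only: "mistralai/Magistral-Small-2506" is a placeholder repo id.
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506")
token_ids = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Summarize ring attention in one sentence."},
        {"role": "assistant", "content": "It shards attention across devices."},
    ],
    tokenize=True,
)
text = tokenizer.decode(token_ids)  # round-trips to a string including special tokens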

View File

@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import gc
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -146,7 +145,7 @@ def pack_parallel(
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count(), 16))
num_processes = max(1, min(num_items // group_size, cpu_count()))
# Create tasks for parallel processing
tasks = []
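As a worked example of the worker-count heuristic above: 1,000,000 sequence lengths with group_size=100,000 give num_items // group_size = 10 candidate workers, which is then capped by cpu_count() (and by the extra bound of 16 in the variant that includes it), with a floor of 1.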
@@ -259,8 +258,8 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 8, # Number of times to estimate batch count
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
@@ -350,7 +349,6 @@ class MultipackBatchSampler(BatchSampler):
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
del all_bins
# Group bins into batches (each batch contains batch_size bins)
batches = [
@@ -370,7 +368,6 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots += total_slots
self._batches = batches
gc.collect()
return batches
def __iter__(self) -> Iterator[list[list[int]]]:
@@ -446,18 +443,10 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
len_batches = min(_sampled_lens)
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
if self._len_across_ranks is None:
self._len_across_ranks = self.gather_len_batches(len_batches)
else:
self._len_across_ranks = min(
self._len_across_ranks, self.gather_len_batches(len_batches)
)
self._len_across_ranks = self.gather_len_batches(len_batches)
return self._len_across_ranks

View File

@@ -102,8 +102,6 @@ class AxolotlInputConfig(
dpo_use_weighting: bool | None = None
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_padding_free: bool | None = None
datasets: (
Annotated[
@@ -264,7 +262,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
context_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: RingAttnFunc | None = None
@@ -338,14 +336,6 @@ class AxolotlInputConfig(
plugins: list[str] | None = Field(default=None)
@field_validator("seed", mode="after")
@classmethod
def set_default_seed(cls, seed):
if seed is None:
LOG.info("`seed` not set in config; setting to 42")
seed = 42
return seed
@field_validator("datasets", mode="before")
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
@@ -1189,47 +1179,63 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_grpo_liger_sequence_parallel(cls, data):
def check_grpo_liger_context_parallel(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1
and data.get("context_parallel_degree", 1) > 1
):
raise ValueError("GRPO + SP + Liger not currently supported")
raise ValueError("GRPO + CP + Liger not currently supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
def check_context_parallel_degree(self):
if not self.context_parallel_degree:
self.context_parallel_degree = 1
elif self.context_parallel_degree > 1:
import torch
world_size = torch.cuda.device_count()
if not world_size >= self.context_parallel_degree:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
f"World size ({world_size}) must be greater "
f"than or equal to CP degree ({self.context_parallel_degree})"
)
if not world_size % self.context_parallel_degree == 0:
raise ValueError(
f"SP degree ({self.context_parallel_degree}) "
f"must evenly divide world size ({world_size})"
)
if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1:
if not (self.flash_attention or self.sdp_attention):
raise ValueError(
"flash_attention: true or sdp_attention: true "
"must be set with context_parallel_degree > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"
)
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
if self.flash_attention:
try:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"context_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# TODO: monkeypatch / callback to average losses correctly across CP ranks
# / fix gradient scaling across CP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Context parallelism (SP) is enabled with "
f"context_parallel_degree={self.context_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1240,7 +1246,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
if getattr(self, "context_parallel_degree", 1) == 1:
return self
if self.ring_attn_func is not None:
@@ -1267,68 +1273,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_tokenizer_use_mistral_common(cls, data):
if data.get("tokenizer_use_mistral_common") is None:
if any(
"magistral" in name.lower()
for name in [
data.get("base_model", ""),
data.get("base_model_config", ""),
data.get("tokenizer_config", ""),
]
):
LOG.warning(
"tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer."
)
data["tokenizer_use_mistral_common"] = True
return data
@field_validator("tokenizer_use_mistral_common", mode="after")
@classmethod
def check_mistral_common_import(cls, tokenizer_use_mistral_common):
if tokenizer_use_mistral_common:
try:
import mistral_common # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`."
) from exception
return tokenizer_use_mistral_common
@model_validator(mode="before")
@classmethod
def check_mistral_common_incompatible_options(cls, data):
if not data.get("tokenizer_use_mistral_common"):
return data
# NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment
if data.get("added_tokens_overrides"):
raise ValueError(
"added_tokens_overrides is not supported with mistral-common tokenizer"
)
if data.get("special_tokens"):
raise ValueError(
"special_tokens override is not supported with mistral-common tokenizer"
)
if data.get("tokens"):
raise ValueError(
"tokens override is not supported with mistral-common tokenizer"
)
if data.get("chat_template"):
raise ValueError(
"Setting chat_template is not supported with mistral-common tokenizer"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -43,7 +43,6 @@ class SFTDataset(BaseModel):
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
field_tools: str | None = None
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings

View File

@@ -18,7 +18,6 @@ class ModelInputConfig(BaseModel):
tokenizer_config: str | None = None
tokenizer_use_fast: bool | None = None
tokenizer_legacy: bool | None = None
tokenizer_use_mistral_common: bool | None = None
tokenizer_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
)

View File

@@ -16,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -441,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -466,7 +467,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
)
data_loader = DataLoader(
@@ -479,12 +479,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -505,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_degree
/ cfg.batch_size
)
)
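As a worked example of the non-packing step count above: 10,000 training samples, 2 epochs, context_parallel_degree=2, and batch_size=16 give ceil(10,000 * 2 * 2 / 16) = 2,500 total steps.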
@@ -632,8 +629,6 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if (
cfg.torch_compile
and cfg.fsdp_config

View File

@@ -12,7 +12,7 @@ from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
"context_parallel_degree": 1,
# Dtype
"fp16": False,
"bf16": False,
@@ -451,19 +451,15 @@ def rand_reward_func(prompts, completions) -> list[float]:
# Only use mock for the commented out configs
if dataset_name is not None:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
"axolotl.utils.data.rl.load_dataset_w_config"
) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue(
dataset_name
)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
else:
# Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset

View File

@@ -4,6 +4,7 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
@@ -58,7 +59,8 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -103,7 +105,8 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -131,7 +134,8 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):

View File

@@ -5,6 +5,7 @@ e2e tests to make sure all the hooks are fired on the plugin
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
@@ -159,7 +160,8 @@ class TestPluginHooks:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,9 +5,11 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@@ -16,8 +18,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "Qwen/Qwen3-0.6B",
"tokenizer_config": "winglian/qwen3-14b-math",
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
@@ -30,22 +32,20 @@ def min_cfg(temp_dir):
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 1.0,
"kd_beta": 0.0,
"kd_normalize_topk": True,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
"type": "chat_template",
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"split": "train",
"split_thinking": True,
"eot_tokens": ["<|im_end|>"],
"data_files": ["train/batch-000000.parquet"],
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
},
],
"skip_prepare_dataset": True,
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
@@ -81,29 +81,18 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -123,22 +112,13 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"

View File

@@ -2,6 +2,7 @@
Simple end-to-end test for Liger integration
"""
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -56,7 +57,8 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -102,7 +104,8 @@ class LigerIntegrationTestCase:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,6 +6,7 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -87,7 +88,8 @@ class TestLLMCompressorIntegration:
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try:
train(cfg=cfg, dataset_meta=dataset_meta)

Some files were not shown because too many files have changed in this diff.