Compare commits

1 Commit

Author: Wing Lian
SHA1: e1c7a61243
Message: fix reentrant when using offloading
Date: 2025-09-14 10:42:15 -04:00
64 changed files with 493 additions and 3502 deletions

View File

@@ -44,7 +44,7 @@ jobs:
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -304,7 +304,7 @@ jobs:
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -14,7 +14,7 @@ repos:
rev: v0.12.12
hooks:
- id: ruff
args: [--fix, --select, I]
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1

View File

@@ -1,6 +1,6 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Open Source LLM Post-Training"
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"

View File

@@ -5,9 +5,6 @@
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p align="center">
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
</p>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
@@ -53,21 +50,20 @@
## ✨ Overview
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Axolotl is a tool designed to streamline post-training for various AI models.
Features:
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start - LLM Fine-tuning in Minutes
## 🚀 Quick Start
**Requirements**:
@@ -164,7 +160,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Open Source LLM Post-Training},
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},

View File

@@ -51,11 +51,3 @@ axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::
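The naming rule in this note can be sketched in a few lines. This is a minimal illustration, assuming the schema suffix (here `nvfp4w`) is already known to the caller; `quantized_hub_name` is a hypothetical helper name and not part of Axolotl's API.

```python
def quantized_hub_name(hub_model_id: str, schema_suffix: str) -> str:
    """Append a quantization schema suffix to a hub model id.

    Hypothetical helper for illustration only. Trailing dashes are
    stripped first so the result never contains a double dash.
    """
    return hub_model_id.rstrip("-") + f"-{schema_suffix}"


# Mirrors the example given in the note above.
print(quantized_hub_name("axolotl-ai-cloud/qat-nvfp4-llama3B", "nvfp4w"))
```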

View File

@@ -176,8 +176,8 @@
}
],
"source": [
"from axolotl.cli.config import load_cfg\n",
"from axolotl.utils.dict import DictDefault\n",
"from axolotl.cli.config import load_cfg\n",
"\n",
"# Axolotl provides full control and transparency over model and training configuration\n",
"config = DictDefault(\n",

View File

@@ -20,13 +20,7 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml

View File

@@ -106,16 +106,6 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### Thinking and chat_template masking conflict
OpenAI's Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl's `chat_template` masking.
If your dataset has `thinking` content mid-turn, there are two paths we recommend:
- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).
- Adjust your dataset to only have `thinking` content in the last turn.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -1,85 +0,0 @@
# Finetune HunYuan with Axolotl
Tencent released a family of open-source models called HunYuan at parameter scales of 0.5B, 1.8B, 4B, and 7B, in both Pre-trained and Instruct variants. The models can be found on [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune them with Axolotl using multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). HunYuan support is currently available only on main (nightly), so either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
```
This config uses about 4.7 GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Dataset
HunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance on fine-tuning their Instruct models, your dataset should be adjusted to match their pattern.
```python
# fast think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
# slow think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/think What color is the sun?" },
{"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
```
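Both patterns above share one assistant layout: a `<think>` block (empty for fast think) followed by an `<answer>` block. A minimal sketch of a formatter for that layout follows; `format_hunyuan_reply` is a hypothetical helper introduced here for illustration, and the exact whitespace is taken from the two examples above rather than from any official specification.

```python
def format_hunyuan_reply(answer: str, reasoning: str = "") -> str:
    """Build assistant content in the think/answer layout shown above.

    Hypothetical helper: an empty `reasoning` yields the fast think
    pattern, a non-empty one yields the slow think pattern.
    """
    return f"<think>\n{reasoning}\n</think>\n<answer>\n{answer}\n</answer>"


# Fast think: the think block is left empty.
fast = format_hunyuan_reply("The sun is yellow.")

# Slow think: the reasoning is placed inside the think block.
slow = format_hunyuan_reply(
    "The sun is yellow.",
    reasoning="The user is asking about the color of the sun.",
)
```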
### TIPS
- For inference, the official Tencent team recommends
```json
{
"do_sample": true,
"top_k": 20,
"top_p": 0.8,
"repetition_penalty": 1.05,
"temperature": 0.7
}
```
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: tencent/Hunyuan-0.5B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,64 +0,0 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,18 +15,20 @@ liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: false
sequence_len: 8192
flash_attention: true
sample_packing: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
@@ -65,7 +67,7 @@ fsdp:
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
@@ -74,6 +76,6 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,56 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 128
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
warmup_ratio: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
sdp_attention: true
bf16: auto
tf32: true
logging_steps: 1
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,59 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
warmup_steps: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
sdp_attention: true
logging_steps: 1
save_strategy: best
eval_strategy: epoch
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -18,13 +18,7 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml

View File

@@ -1,54 +0,0 @@
# Finetune ByteDance's Seed-OSS with Axolotl
[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) is a series of 36B-parameter open-source models trained by ByteDance's Seed Team.
This guide shows how to fine-tune them with Axolotl using multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Seed-OSS support is currently available only on main (nightly), so either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml
```
This config uses about 27.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [ByteDance Seed Website](https://seed.bytedance.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,56 +0,0 @@
base_model: ByteDance-Seed/Seed-OSS-36B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -22,9 +22,6 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# audio
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:

View File

@@ -64,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

View File

@@ -162,7 +162,6 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

View File

@@ -115,7 +115,6 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass

View File

@@ -14,13 +14,6 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
render_html,
run_diffusion,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -36,7 +29,6 @@ def get_multi_line_input() -> str:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
print("=" * 80)
instruction = ""
for line in sys.stdin:
@@ -51,9 +43,9 @@ def do_inference(
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat
template is (optionally) applied, and the model specified in the `axolotl` config is
used to generate completions according to a default generation config.
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -72,28 +64,16 @@ def do_inference(
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
print("=" * 80)
print("Commands:")
print(":complete N -> completion mode with N tokens (default 64)")
print(":mask R -> random masking with ratio R (0.0-1.0)")
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
@@ -123,19 +103,9 @@ def do_inference(
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 80)
print("=" * 40)
model.eval()
with torch.no_grad():
if is_diffusion:
diffusion_inference(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
)
continue
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
@@ -158,7 +128,7 @@ def do_inference(
generation_config=generation_config,
streamer=streamer,
)
print("=" * 80)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -191,30 +161,13 @@ def do_inference_gradio(
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
launch_diffusion_gradio_ui(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompter_module=prompter_module,
chat_template_str=chat_template_str,
)
return
def generate(instruction):
if not instruction:
return

View File

@@ -5,17 +5,12 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from transformers import AutoModelForCausalLM
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
@@ -48,13 +43,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("base_model") or cfg.output_dir
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -62,15 +57,10 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}.")
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
@@ -80,21 +70,11 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model(
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
@@ -106,14 +86,4 @@ def do_quantize(
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -1,375 +0,0 @@
"""Helpers for diffusion-mode inference in CLI and Gradio."""
from __future__ import annotations
import gradio as gr
import torch
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
from axolotl.utils.dict import DictDefault
def diffusion_inference(
model,
tokenizer,
cfg,
prompt: str,
chat_template_str: str | None = None,
):
"""Diffusion inference helper method."""
mode = "random"
completion_tokens = 0
target_mask_ratio = None
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
if cleaned:
prompt = cleaned
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=mode,
target_mask_ratio=target_mask_ratio,
completion_tokens=completion_tokens,
)
masked_text = info["masked_text"]
mask_ratio = info["mask_ratio"]
generated_ids = info["generated_ids"]
masked_positions = info["masked_positions"]
orig_ids = info["orig_ids"]
# Display with masked preview and colored diff
if masked_text is not None and mask_ratio is not None:
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
if generated_ids is not None:
# Compute per-token style
styles: list[str] = []
for i, tid in enumerate(generated_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green") # correct fill
elif i < len(orig_ids):
styles.append("red") # incorrect fill
else:
styles.append("normal") # appended
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
# Group contiguous spans by style
styled_spans: list[tuple[str, int, int]] = []
if generated_ids:
current_style = styles[0]
start = 0
for i in range(1, len(generated_ids)):
s = styles[i]
if s != current_style:
styled_spans.append((current_style, start, i))
current_style, start = s, i
styled_spans.append((current_style, start, len(generated_ids)))
out_parts = []
for style_name, a, b in styled_spans:
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
out_parts.append(chunk_text)
print("Generated:\n" + "".join(out_parts))
else:
print("Generated:\n(no output)")
def _parse_commands(text: str):
"""
Parse leading diffusion commands.
Supported at start of input (can be chained):
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R in [0, 1]
"""
tokens = text.strip().split()
i = 0
mode = "random"
completion_tokens = 0
target_mask_ratio = None
consumed = 0
while i < len(tokens) and tokens[i].startswith(":"):
cmd = tokens[i]
i += 1
consumed = i
if cmd == ":complete":
mode = "completion"
if i < len(tokens):
try:
completion_tokens = int(tokens[i])
i += 1
consumed = i
except Exception:
completion_tokens = 64
else:
completion_tokens = 64
elif cmd == ":mask":
mode = "random"
if i < len(tokens):
try:
target_mask_ratio = float(tokens[i])
i += 1
consumed = i
except Exception:
target_mask_ratio = None
else:
i -= 1
consumed = i
break
cleaned = " ".join(tokens[consumed:])
return mode, completion_tokens, target_mask_ratio, cleaned
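The command grammar handled by `_parse_commands` can be restated more compactly. The sketch below is a simplified re-statement for illustration, not the removed implementation: it assumes only the two commands documented in the docstring, and it accepts only non-negative integers after `:complete` (the original used `int()` directly). The name `parse_commands` is hypothetical.

```python
def parse_commands(text: str):
    """Parse leading ':complete N' / ':mask R' commands from a prompt.

    Returns (mode, completion_tokens, target_mask_ratio, remaining_prompt).
    Any other token, including an unknown ':command', ends parsing and
    stays in the remaining prompt.
    """
    tokens = text.strip().split()
    mode, completion_tokens, mask_ratio = "random", 0, None
    i = 0
    while i < len(tokens) and tokens[i] in (":complete", ":mask"):
        cmd = tokens[i]
        i += 1
        if cmd == ":complete":
            # Default completion length is 64 tokens when N is missing.
            mode, completion_tokens = "completion", 64
            if i < len(tokens) and tokens[i].isdigit():
                completion_tokens = int(tokens[i])
                i += 1
        else:  # ":mask"
            mode = "random"
            try:
                mask_ratio = float(tokens[i])
                i += 1
            except (IndexError, ValueError):
                # Missing or non-numeric ratio: leave it unset.
                mask_ratio = None
    return mode, completion_tokens, mask_ratio, " ".join(tokens[i:])


mode, n_tokens, ratio, prompt = parse_commands(":complete 32 Hello world")
```

Commands can be chained, so `":mask 0.5 :complete 8 hi"` ends in completion mode while still recording the mask ratio.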
def run_diffusion(
*,
model,
tokenizer,
cfg: DictDefault,
prompt: str,
chat_template_str: str | None,
mode: str = "random",
target_mask_ratio: float | None = None,
completion_tokens: int = 0,
):
"""Run a single diffusion generation and return a structured result dict."""
if chat_template_str:
batch = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
seq = batch["input_ids"].to(cfg.device)
gen_mode = "completion" if mode == "completion" else "random"
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
result = generate(
model,
tokenizer,
original_sequence=seq[:1],
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
temperature=cfg.diffusion.generation_temperature,
mask_token_id=int(mask_token_id),
mode=gen_mode, # type: ignore[arg-type]
completion_tokens=comp_tokens,
target_mask_ratio=target_mask_ratio,
)
masked_text = result.get("masked") if isinstance(result, dict) else None
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
masked_positions = (
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
)
orig_ids = seq[0].detach().cpu().tolist()
return {
"masked_text": masked_text,
"mask_ratio": mask_ratio,
"generated_ids": generated_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids,
}
def render_html(
*,
generated_ids: list[int] | None,
orig_ids: list[int],
masked_positions: set[int],
tokenizer,
) -> str:
"""Render HTML visualizing diffusion outputs."""
if not generated_ids:
return "<pre>Generated:\n(no output)</pre>"
def _style_for(i: int, tid: int) -> str:
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
return "green"
if i < len(orig_ids):
return "red"
return "normal"
same = i < len(orig_ids) and tid == orig_ids[i]
return "dim" if same else "normal"
# Group contiguous spans by style to reduce HTML size
spans: list[tuple[str, int, int]] = []
if generated_ids:
cur = _style_for(0, generated_ids[0])
start = 0
for i in range(1, len(generated_ids)):
s = _style_for(i, generated_ids[i])
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(generated_ids)))
html_parts = []
for style_name, a, b in spans:
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
elif style_name == "red":
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
elif style_name == "dim":
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
else:
html_parts.append(txt)
legend = (
'<div style="font-size:0.9em;margin-bottom:4px">'
'<span style="color:#2e7d32">correct</span>, '
'<span style="color:#c62828">incorrect</span>, '
'<span style="opacity:0.6">unchanged</span>'
"</div>"
)
return (
legend
+ '<pre style="white-space:pre-wrap">Generated:\n'
+ "".join(html_parts)
+ "</pre>"
)
def launch_diffusion_gradio_ui(
*,
model,
tokenizer,
cfg: DictDefault,
prompter_module=None,
chat_template_str: str | None = None,
):
"""Build and launch a simple Gradio UI for diffusion inference."""
with gr.Blocks(
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
) as demo:
gr.Markdown(
"""
## Axolotl Diffusion Inference
- Mode "Random" masks tokens at a target ratio and fills them.
- Mode "Completion" appends N masked tokens at the end and fills them.
"""
)
with gr.Row():
mode = gr.Radio(
choices=["random", "completion"],
value="random",
label="Mode",
)
mask_ratio = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.4,
label="Mask ratio (random mode)",
interactive=True,
)
completion_tokens = gr.Number(
value=64,
precision=0,
label="Completion tokens (completion mode)",
interactive=True,
visible=False,
)
instruction = gr.Textbox(label="Instruction", lines=6)
run_btn = gr.Button("Generate")
masked_preview = gr.Textbox(label="Masked preview", lines=6)
html_out = gr.HTML(label="Generated")
def _toggle_controls(selected_mode: str):
return (
gr.update(visible=(selected_mode == "random")),
gr.update(visible=(selected_mode == "completion")),
)
mode.change(
_toggle_controls,
inputs=[mode],
outputs=[mask_ratio, completion_tokens],
)
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
if not instruction_text:
return "", "<pre>Generated:\n(no output)</pre>"
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(
instruction=instruction_text.strip("\n")
)
)
else:
prompt = instruction_text.strip()
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=selected_mode,
target_mask_ratio=mratio if selected_mode == "random" else None,
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
)
masked_text = info.get("masked_text")
mask_ratio_val = info.get("mask_ratio")
generated_ids = info.get("generated_ids")
masked_positions = info.get("masked_positions") or set()
orig_ids = info.get("orig_ids") or []
preview = (
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
if masked_text is not None and mask_ratio_val is not None
else ""
)
html = render_html(
generated_ids=generated_ids,
orig_ids=orig_ids,
masked_positions=masked_positions,
tokenizer=tokenizer,
)
return preview, html
run_btn.click(
_gen,
inputs=[instruction, mode, mask_ratio, completion_tokens],
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)

View File

@@ -7,11 +7,7 @@ from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
)
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
@@ -27,16 +23,15 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -44,6 +39,7 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
@@ -395,11 +391,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None

View File

@@ -49,13 +49,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
"max": torch.max,
"sum": torch.sum,
}
class AxolotlTrainer(
PackingMixin,
@@ -96,9 +89,7 @@ class AxolotlTrainer(
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -371,11 +362,6 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
@override
def evaluate(self, *args, **kwargs):
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
@@ -599,17 +585,9 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
reduction_type = metric_data["reduction"]
fn = REDUCTION_FNS.get(reduction_type)
if fn is None:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(fn(values).item(), 4)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if is_main_process():
# Add memory usage
@@ -633,27 +611,10 @@ class AxolotlTrainer(
return super().log(logs, start_time)
def store_metrics(
self,
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
"""
for key, value in metrics.items():
if isinstance(value, tuple):
value, _reduction = value # type: ignore[assignment]
else:
value, _reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(value)
self._stored_metrics[train_eval][key]["reduction"] = _reduction
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey

View File

@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from functools import partial
from peft import PeftModel
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
if use_reentrant:
checkpoint_wrapper_fn = partial(
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
)
else:
checkpoint_wrapper_fn = checkpoint_wrapper
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
auto_wrap_policy=auto_wrap_policy,
**kwargs,
)
def get_lora_act_offloading_ctx_manager(

View File

@@ -142,7 +142,7 @@ class BasePlugin:
model: The loaded model.
"""
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
Args:

View File

@@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
def merge_input_args():

View File

@@ -34,7 +34,6 @@ plugins:
- arcee
- cohere
- cohere2
- deepseek_v3
- gemma
- gemma2
- gemma3
@@ -43,7 +42,6 @@ plugins:
- gemma3n_text
- glm
- glm4
- glm4_moe
- gpt_oss
- granite
- granitemoe
@@ -66,7 +64,6 @@ plugins:
- qwen3
- qwen3_moe
- smollm3
- seed_oss
- voxtral
## Citation

View File

@@ -1,154 +0,0 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training using an approach inspired by
LLaDA (Large Language Diffusion Models) within Axolotl.
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to attend to the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. See our
[installation docs](https://docs.axolotl.ai/docs/installation.html).
## Quickstart
Train with an example config (Llama3.2 1B):
- Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
- SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
### Basic Configuration
You can also modify your existing configs to enable / customize diffusion training.
Add the following to your Axolotl config:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
```
Then configure the nested `diffusion` block (defaults shown):
```yaml
diffusion:
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
# Mask token (training auto-adds if missing, avoid pad/eos)
mask_token_str: "<|diffusion_mask|>"
# Or use an existing special token id (e.g., 128002 for Llama-3.x)
# mask_token_id: 128002
# Sample generation during training (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
```
## Supported Models
Any model that supports 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
## How It Works
### Random Masking
During training, tokens are randomly masked:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
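
The three steps above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper name `mask_tokens` and its signature are hypothetical, and it uses the standard library instead of the tensor operations the plugin itself relies on.

```python
import random


def mask_tokens(token_ids, mask_token_id, eps=1e-3, rng=None):
    """Hypothetical helper: forward masking as described in the list above."""
    rng = rng or random.Random()
    # Sample timestep t uniformly from [0, 1)
    t = rng.random()
    # Masking probability: p = (1 - eps) * t + eps
    p_mask = (1.0 - eps) * t + eps
    masked, is_masked = [], []
    for tid in token_ids:
        # Mask each token independently with probability p_mask
        hit = rng.random() < p_mask
        masked.append(mask_token_id if hit else tid)
        is_masked.append(hit)
    return masked, is_masked, p_mask
```

Because `eps` bounds `p_mask` away from zero, dividing by it in the loss stays numerically safe.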
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
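
A minimal sketch of that formula on plain Python lists, assuming `log_probs[i]` holds the model's log-probability of the true token at position `i`. The function name is hypothetical, and the unweighted branch (same normalization, no division by `p_mask`) is an assumption about what happens when `importance_weighting` is off, not the plugin's confirmed behavior.

```python
import math


def diffusion_loss(log_probs, is_masked, p_mask, importance_weighting=True):
    """Hypothetical helper: cross-entropy over masked positions only."""
    total_tokens = len(log_probs)
    loss = 0.0
    for lp, hit in zip(log_probs, is_masked):
        if not hit:
            continue  # unmasked tokens contribute nothing
        ce = -lp  # cross-entropy of the true token
        # Importance weighting divides by the masking probability
        loss += ce / p_mask if importance_weighting else ce
    return loss / max(total_tokens, 1)


# Two of four tokens masked, each predicted with probability 0.5
example = diffusion_loss(
    [math.log(0.5)] * 4, [True, False, True, False], p_mask=0.5
)
```

With half the tokens masked at `p_mask = 0.5`, the weighting exactly compensates for the positions that were skipped.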
## Sample Generation
When `diffusion.generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to the console and, if enabled, to wandb.
## Inference
Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
you trained with and run:
```
axolotl inference path/to/your-config.yaml
```
Optionally, pass `--gradio` to use a simple web interface.
Interactive controls (prefix the prompt with commands):
- `:complete N` → completion mode with N new masked tokens appended (default 64)
- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
Example session:
```
================================================================================
Commands:
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R (0.0-1.0)
================================================================================
Give me an instruction (Ctrl + D to submit):
:mask 0.4 The quick brown fox jumps over the lazy dog
Masked (40.0%):
The [MASK] brown [MASK] jumps over the [MASK] dog
Generated:
The quick brown fox jumps over the loud dog
```
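
The command prefixes shown above could be parsed roughly as follows. This is a simplified illustration, not the parser the CLI ships with; the function name `parse_prompt_commands` and its return shape are assumptions.

```python
def parse_prompt_commands(text):
    """Hypothetical parser for the `:complete N` and `:mask R` prefixes."""
    tokens = text.strip().split()
    mode, completion_tokens, mask_ratio = "random", 0, None
    if tokens and tokens[0] == ":complete":
        mode, completion_tokens = "completion", 64  # default N
        tokens = tokens[1:]
        if tokens and tokens[0].isdigit():
            completion_tokens = int(tokens[0])
            tokens = tokens[1:]
    elif tokens and tokens[0] == ":mask":
        tokens = tokens[1:]
        if tokens:
            try:
                mask_ratio = float(tokens[0])
                tokens = tokens[1:]
            except ValueError:
                pass  # no ratio given; keep the word as part of the prompt
    # Whatever is left is the prompt itself
    return mode, completion_tokens, mask_ratio, " ".join(tokens)
```

Returning the cleaned prompt alongside the settings keeps the command text out of what the model sees.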
## Metrics and Monitoring
The plugin adds (or modifies) several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
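
As a rough illustration of what the masked-token metrics measure, the sketch below computes accuracy, mask ratio, and masked-token count for one sequence. The helper name and the per-sequence formulation are assumptions; the plugin may aggregate these values differently across a batch.

```python
def masked_token_metrics(pred_ids, target_ids, is_masked):
    """Hypothetical helper: metrics restricted to masked positions."""
    num_masked = sum(is_masked)
    correct = sum(
        1 for p, t, m in zip(pred_ids, target_ids, is_masked) if m and p == t
    )
    return {
        # Fraction of masked tokens predicted correctly
        "accuracy": correct / num_masked if num_masked else 0.0,
        # Fraction of all tokens that were masked
        "mask_ratio": num_masked / len(target_ids) if target_ids else 0.0,
        "num_masked_tokens": num_masked,
    }


metrics = masked_token_metrics([1, 2, 9, 4], [1, 2, 3, 4], [True, False, True, False])
```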
## Limitations
- No flash attention support
- No RL training support
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)
- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)

View File

@@ -1,19 +0,0 @@
"""Diffusion LM training plugin init."""
from .args import DiffusionArgs, DiffusionConfig
from .callbacks import DiffusionGenerationCallback
from .generation import generate
from .plugin import DiffusionPlugin
from .trainer import DiffusionTrainer
from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionTrainer",
"generate",
"resolve_mask_token_id",
"create_bidirectional_attention_mask",
"DiffusionGenerationCallback",
"DiffusionConfig",
]

View File

@@ -1,95 +0,0 @@
"""Config args for diffusion LM training (nested under `diffusion:`)."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
class DiffusionConfig(BaseModel):
"""Nested diffusion configuration available under the `diffusion` key."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int | None = Field(
default=None,
description=(
"Token ID to use for masking. Unset by default; can use one of the "
"tokenizer's special tokens here."
),
)
mask_token_str: str | None = Field(
default=None,
description=(
"Token string to use as a mask. If `mask_token_id` is invalid or unset, "
"this token will be ensured to exist as an additional special token and "
"used. If absent, a default '<|diffusion_mask|>' will be added."
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)
@model_validator(mode="after")
def _validate_mask_ratios(self) -> "DiffusionConfig":
if self.min_mask_ratio > self.max_mask_ratio:
raise ValueError("min_mask_ratio must be ≤ max_mask_ratio")
return self
class DiffusionArgs(BaseModel):
"""Plugin entry that exposes the nested `diffusion` block to the core config."""
diffusion: DiffusionConfig = Field(
default_factory=DiffusionConfig,
description="Diffusion training configuration. Only nested block is supported.",
)

View File

@@ -1,174 +0,0 @@
"""Callbacks for diffusion training."""
import logging
import sys
import wandb
from colorama import Fore, Style
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from .generation import generate_samples
# Simpler logger for more readable sample generation
logger = logging.getLogger(__name__)
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
if (
state.global_step > 0
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
):
if not self.trainer.state.is_world_process_zero:
return
# Use eval dataloader if available, otherwise use train dataloader
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
except Exception:
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
# Generate samples
diffusion_cfg = self.trainer.cfg.diffusion
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=diffusion_cfg.num_generation_samples,
max_length=diffusion_cfg.generation_max_length,
num_diffusion_steps=diffusion_cfg.generation_steps,
temperature=diffusion_cfg.generation_temperature,
mask_token_id=diffusion_cfg.mask_token_id,
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
logger.info("=" * 60)
logger.info("GENERATED SAMPLES")
logger.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
logger.info(f"\nSample {i}:")
logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
logger.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
try:
gen_ids = sample_data.get("generated_ids")
orig_ids = sample_data.get("orig_ids")
masked_positions = set(sample_data.get("masked_positions") or [])
if isinstance(gen_ids, list) and isinstance(orig_ids, list):
styles: list[str] = []
for i, tid in enumerate(gen_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green")
elif i < len(orig_ids):
styles.append("red")
else:
styles.append("normal")
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
spans: list[tuple[str, int, int]] = []
if gen_ids:
cur = styles[0]
start = 0
for i in range(1, len(gen_ids)):
s = styles[i]
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(gen_ids)))
parts = []
for style_name, a, b in spans:
chunk_text = self.trainer.processing_class.decode(
gen_ids[a:b], skip_special_tokens=False
)
if style_name == "green":
parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
parts.append(chunk_text)
logger.info("\tGenerated:\n%s", "".join(parts))
else:
logger.info(f"\tGenerated: {generated}")
except Exception:
logger.info(f"\tGenerated: {generated}")
logger.info("=" * 60)
if self.trainer.cfg.use_wandb:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -1,409 +0,0 @@
"""Sample generation utilities for diffusion training."""
import re
from typing import Any, List, Literal, Optional
import torch
from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
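mask_token_id: Token ID used for masking
mode: "random" masks random positions; "completion" appends mask tokens
completion_tokens: Number of mask tokens to append in completion mode
target_mask_ratio: Fixed mask ratio for random mode; if None, a ratio is
sampled uniformly from [0.1, 0.7]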
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
LOG.warning("No validation dataloader provided, cannot generate samples")
return []
unwrapped_model = model.module if hasattr(model, "module") else model
training = unwrapped_model.training
unwrapped_model.eval()
# Resolve device robustly (some modules don't expose `.device`)
device = getattr(unwrapped_model, "device", None)
if device is None:
try:
device = next(unwrapped_model.parameters()).device
except StopIteration:
device = torch.device("cpu")
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, device
)
LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for sample in sampled_sequences:
if isinstance(sample, dict):
original_sequence = sample.get("input_ids")
labels_seq = sample.get("labels")
attn_seq = sample.get("attention_mask")
else:
original_sequence = sample
labels_seq = None
attn_seq = None
generation_result = generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
mode=mode,
completion_tokens=completion_tokens,
target_mask_ratio=target_mask_ratio,
labels=labels_seq,
attention_mask=attn_seq,
)
generations.append(generation_result)
# Restore prior training state
if training:
unwrapped_model.train()
else:
unwrapped_model.eval()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[Any]:
"""Sample sequences from validation dataloader."""
sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
sample_count = 0
# Skip a random number of batches (we could be more clever about this)
skip_batches = torch.randint(0, 10, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
labels = batch.get("labels")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
if seq_len < 10:
continue
# Determine truncation length
max_total = min(seq_len, max_length)
if labels is not None:
labels_i = labels[i][:seq_len]
answer_mask = labels_i != -100
if not answer_mask.any():
# No answer tokens; skip for SFT masking
continue
first_ans_idx = int(
torch.nonzero(answer_mask, as_tuple=False)[0].item()
)
prompt_len = first_ans_idx
if prompt_len >= max_total:
# Prompt alone reaches cap; cannot include any answer
continue
remaining_answer = int(answer_mask[prompt_len:].sum().item())
allowed_answer = max_total - prompt_len
take_answer = min(remaining_answer, allowed_answer)
if take_answer <= 0:
continue
actual_length = prompt_len + take_answer
else:
actual_length = max_total
# Extract the (possibly truncated) sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
attn_seq = (
attention_mask[i][:actual_length].unsqueeze(0).to(device)
if attention_mask is not None
else None
)
if labels is not None:
labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(
{
"input_ids": sequence,
"labels": labels_seq,
"attention_mask": attn_seq,
}
)
else:
if attn_seq is not None:
sampled_sequences.append(
{"input_ids": sequence, "attention_mask": attn_seq}
)
else:
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
*,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Build masked sequence
if (
labels is not None
and labels.numel() > 0
and (labels == -100).any()
and (labels != -100).any()
):
# SFT case: completely mask all answer tokens (labels != -100)
total_tokens = original_sequence.size(1)
masked_indices = (labels != -100).to(dtype=torch.bool)
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
masked_tokens = int(masked_indices.sum().item())
mask_ratio = masked_tokens / max(int(total_tokens), 1)
elif mode == "completion" and completion_tokens > 0:
# Append mask tokens to the right for completion
total_tokens = original_sequence.size(1) + int(completion_tokens)
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, -int(completion_tokens) :] = True
append = torch.full(
(1, int(completion_tokens)), mask_token_id, device=original_sequence.device
)
masked_sequence = torch.cat([original_sequence, append], dim=1)
masked_tokens = int(completion_tokens)
mask_ratio = masked_tokens / total_tokens
else:
# Apply random masking with optional fixed ratio
total_tokens = original_sequence.size(1)
if target_mask_ratio is None:
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = (
torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
)
target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
attention_mask = create_bidirectional_attention_mask(
sequence, attention_mask, sample_packing=attention_mask is not None
)
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model,
sequence,
step,
num_diffusion_steps,
temperature,
mask_token_id,
attention_mask,
)
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
# Collect diagnostic info
final_ids = sequence[0].detach().cpu().tolist()
orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
if masked_indices is not None:
masked_positions = (
torch.where(masked_indices[0])[0].detach().cpu().tolist()
if masked_indices.ndim == 2
else []
)
else:
masked_positions = []
result = {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"generated_ids": final_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids_for_render,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
return result
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
    """Clean up masked text for display."""
    mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
    cleaned = masked_text.replace(mask_token_repr, "[MASK]")

    # Remove literal special token strings
    if hasattr(tokenizer, "special_tokens_map"):
        for token_value in tokenizer.special_tokens_map.values():
            if token_value and isinstance(token_value, str):
                cleaned = cleaned.replace(token_value, "")

    # Normalize whitespace but preserve newlines
    cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
    cleaned = re.sub(r"[ \t]+", " ", cleaned)
    cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
    return cleaned
def _diffusion_step(
    model: torch.nn.Module,
    sequence: torch.Tensor,
    step: int,
    num_diffusion_steps: int,
    temperature: float,
    mask_token_id: int,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Perform a single diffusion step with remasking."""
    # Only process if there are masked tokens remaining
    current_mask = sequence == mask_token_id
    if not current_mask.any():
        return sequence

    # Create or use provided attention mask
    if attention_mask is None:
        batch_size, seq_len = sequence.shape
        attention_mask = torch.ones(
            batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
        )

    # Forward pass
    outputs = model(input_ids=sequence, attention_mask=attention_mask)
    logits = outputs.logits

    # Only sample at currently masked positions
    if current_mask.any():
        masked_logits = logits[current_mask]

        # Apply temperature scaling
        if temperature > 0:
            scaled_logits = masked_logits / temperature
        else:
            scaled_logits = masked_logits

        # Suppress mask token in outputs
        scaled_logits[:, mask_token_id] = -float("inf")

        if temperature > 0:
            # Add Gumbel noise for sampling
            gumbel_noise = -torch.log(
                -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
            )
            gumbel_logits = scaled_logits + gumbel_noise
            predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
        else:
            predicted_tokens = torch.argmax(scaled_logits, dim=-1)

        # Calculate probabilities for confidence scoring
        probs = torch.softmax(scaled_logits, dim=-1)
        predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]

        # Determine how many tokens to unmask this step
        remaining_masked = current_mask.sum().item()
        if step == num_diffusion_steps - 1:
            num_to_unmask = remaining_masked
        else:
            unmask_ratio = 1.0 / (num_diffusion_steps - step)
            num_to_unmask = max(1, int(remaining_masked * unmask_ratio))

        # Select highest confidence predictions to unmask
        if num_to_unmask >= remaining_masked:
            sequence[current_mask] = predicted_tokens
        else:
            _, top_indices = predicted_token_probs.topk(num_to_unmask)
            mask_positions = torch.where(current_mask)[1]
            positions_to_unmask = mask_positions[top_indices]
            sequence[0, positions_to_unmask] = predicted_tokens[top_indices]

    return sequence
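
The per-step unmasking count used by `_diffusion_step` can be looked at in isolation. The following is a minimal sketch in plain Python, with no model or tensors involved; `unmask_schedule` is a hypothetical helper name that does not exist in the removed file, and it only mirrors the `num_to_unmask` arithmetic shown above.

```python
def unmask_schedule(total_masked: int, num_steps: int) -> list[int]:
    """Return how many tokens would be unmasked at each diffusion step.

    Hypothetical helper: mirrors the num_to_unmask arithmetic only.
    """
    remaining = total_masked
    counts = []
    for step in range(num_steps):
        if remaining == 0:
            # Nothing left to unmask; the real step returns early here.
            counts.append(0)
            continue
        if step == num_steps - 1:
            # Final step: unmask everything that is still masked.
            n = remaining
        else:
            # Spread the remaining tokens evenly over the remaining steps,
            # but always unmask at least one token.
            ratio = 1.0 / (num_steps - step)
            n = max(1, int(remaining * ratio))
        n = min(n, remaining)
        counts.append(n)
        remaining -= n
    return counts


print(unmask_schedule(10, 4))
print(unmask_schedule(2, 5))
```

Because of the `max(1, ...)` floor, a sequence with fewer masked tokens than steps is fully unmasked before the last step, and the later steps become no-ops.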

View File

@@ -1,41 +0,0 @@
"""Diffusion LM training plugin for Axolotl."""

from peft import PeftModel
from transformers import PreTrainedModel

from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from .trainer import DiffusionTrainer

LOG = get_logger(__name__)


class DiffusionPlugin(BasePlugin):
    """
    Plugin for diffusion language model training.

    This plugin enables diffusion-based training using the LLaDA approach, which uses
    random masking and bidirectional attention to train language models.
    """

    def __init__(self):
        super().__init__()
        self.cfg = None

    def get_input_args(self) -> str:
        """Returns the pydantic model for LLaDA plugin arguments."""
        return "axolotl.integrations.diffusion.DiffusionArgs"

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        """Perform actions after model is loaded."""
        self.cfg = cfg

    def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
        """Return custom trainer class for diffusion training."""
        return DiffusionTrainer

    def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
        """Configure trainer after creation."""
        trainer.set_config(cfg)

View File

@@ -1,301 +0,0 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer):
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cfg = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.cfg = config
self._cache_special_token_ids()
self._resolve_mask_token_id()
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
LOG.info(f"Diffusion: using mask_token_id={token_id}")
if getattr(config.diffusion, "generate_samples", True):
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def _resolve_mask_token_id(self) -> None:
"""Ensure mask_token_id is valid for the current tokenizer."""
from .utils import resolve_mask_token_id
tokenizer = getattr(self, "processing_class", None)
if tokenizer is None:
return
mid = resolve_mask_token_id(
tokenizer,
self.cfg,
allow_add=True,
model=getattr(self, "model", None),
)
try:
self.cfg.diffusion.mask_token_id = int(mid)
except Exception:
pass
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = int(self.cfg.diffusion.mask_token_id)
mask_value = torch.full_like(input_ids, mask_token_id)
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
return noisy_batch, masked_indices, p_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Short-circuit empty sequences
if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:
zero = torch.tensor(
0.0,
device=(input_ids.device if input_ids is not None else None),
requires_grad=True,
)
return zero, {}
# If an attention_mask is provided and all positions are padding for every
# sample in this batch, skip the step.
if attention_mask is not None:
if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():
zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return zero, {}
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.cfg.diffusion.eps
)
# Create bidirectional attention mask
bidirectional_mask = create_bidirectional_attention_mask(
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
)
# Forward pass
outputs = model(
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
logits = outputs.logits
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.cfg.diffusion.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
if labels is not None:
# For SFT data: normalize by answer token count per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
batch_size = input_ids.shape[0]
loss_per_sample = torch.zeros(batch_size, device=input_ids.device)
for i in range(batch_size):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
denom = answer_lengths[i].clamp(min=1.0)
loss_per_sample[i] = sample_loss / denom
loss = loss_per_sample.mean()
else:
# Non-SFT: when importance weighting is enabled, use unbiased estimator
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
# for stable scaling across varying mask ratios.
if self.cfg.diffusion.importance_weighting:
loss = weighted_loss.sum() / (
input_ids.shape[0] * input_ids.shape[1]
)
else:
loss = weighted_loss.mean()
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
avg_p_mask = (
p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
)
metrics = {
"loss": loss.item(),
"accuracy": accuracy.item(),
"mask_ratio": masked_indices.float().mean().item(),
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
"avg_p_mask": avg_p_mask,
"ce_loss": ce_loss.item(),
}
# If doing SFT training, log answer-specific metrics
if self.cfg.datasets is not None and labels is not None:
with torch.no_grad():
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
total_answer_tokens = answer_mask.sum().item() # type: ignore
total_tokens = labels.numel() # type: ignore
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
metrics["avg_answer_length"] = answer_lengths.mean().item()
if self.cfg.diffusion.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs
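
The masking probability and the importance-weighted loss in the removed trainer reduce to two short formulas. The sketch below restates them in plain Python for the non-SFT branch; `mask_probability` and `importance_weighted_loss` are hypothetical names chosen for illustration, and a single shared `p_mask` per sample is assumed, as in `_forward_process`.

```python
def mask_probability(t: float, eps: float = 1e-3) -> float:
    """Map a timestep t in [0, 1] to a masking probability in [eps, 1]."""
    return (1 - eps) * t + eps


def importance_weighted_loss(
    token_losses: list[float], p_mask: float, total_tokens: int
) -> float:
    """Sum of per-token losses divided by p_mask, normalized by all tokens.

    Mirrors `weighted_loss.sum() / (batch_size * seq_len)` from the trainer's
    non-SFT branch; only masked tokens contribute to `token_losses`.
    """
    return sum(loss / p_mask for loss in token_losses) / total_tokens


# Two masked tokens out of four, each masked with probability 0.5.
print(mask_probability(0.0))
print(importance_weighted_loss([1.0, 2.0], p_mask=0.5, total_tokens=4))
```

Dividing by `p_mask` up-weights tokens that were unlikely to be masked, which is why the trainer normalizes by the full token count, not the masked count, when importance weighting is on.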

View File

@@ -1,159 +0,0 @@
"""Shared utilities for diffusion integration."""
from __future__ import annotations
from typing import Any, Optional
import torch
from axolotl.utils.dict import DictDefault
def resolve_mask_token_id(
tokenizer: Any,
cfg: DictDefault,
*,
allow_add: bool,
model: Any | None = None,
default_token: str = "<|diffusion_mask|>",
) -> int:
"""Resolve mask token id. Training may add a new special token; inference won't."""
# Determine vocab size if available
vocab_size = None
if tokenizer is not None:
if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None:
try:
vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type]
except Exception:
vocab_size = None
elif hasattr(tokenizer, "__len__"):
try:
vocab_size = int(len(tokenizer))
except Exception:
vocab_size = None
# Use explicit id from config if provided
diffusion_cfg = getattr(cfg, "diffusion", None)
# Fallback to top-level attr names only if nested missing (shouldn't happen)
cfg_id = (
getattr(diffusion_cfg, "mask_token_id", None)
if diffusion_cfg is not None
else getattr(cfg, "diffusion_mask_token_id", None)
)
if isinstance(cfg_id, int) and cfg_id >= 0:
if vocab_size is None or cfg_id < vocab_size:
return int(cfg_id)
def _existing_special_token_id(token_str: str | None) -> int | None:
"""Attempt to resolve an existing special token string to a real ID."""
if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"):
return None
try:
token_id = tokenizer.convert_tokens_to_ids(token_str)
except Exception:
return None
if not isinstance(token_id, int) or token_id < 0:
return None
# Ensure it's registered as special and not UNK, and within vocab
unk_id = getattr(tokenizer, "unk_token_id", None)
specials = set(getattr(tokenizer, "all_special_tokens", []) or [])
addl = set(getattr(tokenizer, "additional_special_tokens", []) or [])
is_special = token_str in specials or token_str in addl
in_vocab = vocab_size is None or token_id < vocab_size
if (
(unk_id is not None and token_id == unk_id)
or not is_special
or not in_vocab
):
return None
return token_id
# Try mask token string if provided
token_str = (
getattr(diffusion_cfg, "mask_token_str", None)
if diffusion_cfg is not None
else getattr(cfg, "diffusion_mask_token_str", None)
)
for candidate in (token_str, default_token):
token_id = _existing_special_token_id(candidate)
if isinstance(token_id, int):
try:
if diffusion_cfg is None:
cfg.diffusion_mask_token_id = int(token_id) # legacy fallback
else:
diffusion_cfg.mask_token_id = int(token_id)
except Exception:
pass
return int(token_id)
# Optionally add and return a dedicated special token during training
if allow_add and hasattr(tokenizer, "add_special_tokens"):
token_to_add = token_str or default_token
try:
tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]})
# Resize embeddings if possible
if (
model is not None
and hasattr(tokenizer, "__len__")
and hasattr(model, "resize_token_embeddings")
):
try:
model.resize_token_embeddings(len(tokenizer))
except Exception:
pass
new_id = tokenizer.convert_tokens_to_ids(token_to_add)
if isinstance(new_id, int) and new_id >= 0:
try:
if diffusion_cfg is None:
cfg.diffusion_mask_token_id = int(new_id) # legacy fallback
else:
diffusion_cfg.mask_token_id = int(new_id)
except Exception:
pass
return int(new_id)
except Exception:
pass
# Fallback to unk or 0 (do not update cfg)
fallback = getattr(tokenizer, "unk_token_id", 0) or 0
return int(fallback)
def create_bidirectional_attention_mask(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    sample_packing: bool = False,
) -> torch.Tensor:
    """
    Create bidirectional attention mask to override default causal masking.
    Handles sample-packed sequences where different samples are identified
    by different attention mask values.

    Args:
        input_ids: Input token ids [batch_size, seq_len]
        attention_mask: Attention mask [batch_size, seq_len]
        sample_packing: Whether sample packing is enabled

    Returns:
        bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if attention_mask is None or not sample_packing:
        return torch.ones(
            batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
        )

    # Handle sample packing: tokens can only attend within their sample
    mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
    mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]

    # Tokens can attend to each other if they have the same non-zero sample ID
    bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)

    # Add head dimension: [batch_size, 1, seq_len, seq_len]
    return bidirectional_mask.unsqueeze(1)
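
The sample-packing rule in `create_bidirectional_attention_mask` is easy to check by hand on a single row. The sketch below applies the same comparison, `(mask_i == mask_j) & (mask_i > 0)`, to a plain Python list of per-token sample IDs; `packed_attention` is a hypothetical name, and the batch and head dimensions are left out.

```python
def packed_attention(sample_ids: list[int]) -> list[list[bool]]:
    """Token i may attend to token j only if both share a non-zero sample ID."""
    return [[(i == j) and (i > 0) for j in sample_ids] for i in sample_ids]


# Two packed samples (IDs 1 and 2) followed by one padding position (ID 0).
for row in packed_attention([1, 1, 2, 0]):
    print(row)
```

Padding positions (ID 0) attend to nothing and are attended to by nothing, so their rows and columns are entirely `False`.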

View File

@@ -14,7 +14,6 @@ from peft import (
PeftConfig,
PeftMixedModel,
PeftModel,
TaskType,
get_peft_model,
)
from transformers import PreTrainedModel
@@ -102,15 +101,6 @@ def load_lora(
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
# Determine the correct PEFT task type
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
task_type = TaskType.SEQ_CLS
elif "TokenClassification" in model_cls:
task_type = TaskType.TOKEN_CLS
else:
task_type = TaskType.CAUSAL_LM
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
@@ -122,7 +112,7 @@ def load_lora(
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
task_type=task_type,
task_type="CAUSAL_LM",
**lora_config_kwargs,
)

View File

@@ -224,21 +224,27 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._apply_activation_checkpointing()
use_reentrant = None
if (
self.cfg.gradient_checkpointing_kwargs
and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
):
use_reentrant = True
self._apply_activation_checkpointing(use_reentrant=use_reentrant)
self._resize_token_embeddings()
self._adjust_model_config()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self):
def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model)
ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
@@ -673,33 +679,6 @@ class ModelLoader:
return hf_ds_cfg
def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel:
"""
Load model with random initialization using from_config.
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
)
else:
model = loader(config=self.model_config)
return model
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
"""Load model from pretrained weights."""
loader = model_loader_class or self.auto_model_loader
kwargs = {
"config": self.model_config,
"trust_remote_code": self.cfg.trust_remote_code or False,
**self.model_kwargs,
}
return loader.from_pretrained(self.base_model, **kwargs)
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
@@ -714,8 +693,7 @@ class ModelLoader:
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
# _set_device_map
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
@@ -744,11 +722,6 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading
)
):
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with sharded quantized loading. "
"Loading from pretrained weights instead."
)
quant_storage = self.cfg.torch_dtype
quantization_config = getattr(
self.model_config, "quantization_config", None
@@ -764,12 +737,33 @@ class ModelLoader:
quantization_config=quantization_config,
)
skip_move_to_device = True
elif self.model_type == "MambaLMHeadModel":
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with MambaLMHeadModel. "
"Loading from pretrained weights instead."
elif (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(config=self.model_config)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
elif self.model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss()
@@ -782,27 +776,41 @@ class ModelLoader:
self.base_model,
**self.model_kwargs,
)
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# Please don't remove underscore binding without reading the fn docstring
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
if (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Use model type from transformers
model_loader_class = getattr(transformers, self.model_type)
else:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:
self.model = self._load_model_from_pretrained(model_loader_class)
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
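
The `use_reentrant` resolution added earlier in this file's diff, just before the call to `_apply_activation_checkpointing`, can be read as a small pure function. This is a sketch of that conditional only; `resolve_use_reentrant` is a hypothetical name, and how `ac_wrap_hf_model` treats a `None` value is not shown in this diff.

```python
from typing import Optional


def resolve_use_reentrant(
    gradient_checkpointing_kwargs: Optional[dict],
) -> Optional[bool]:
    """Return True only when kwargs are set and do not disable reentrant mode."""
    if gradient_checkpointing_kwargs and gradient_checkpointing_kwargs.get(
        "use_reentrant", True
    ):
        return True
    # Missing or empty kwargs, or an explicit False, all fall through to None.
    return None


print(resolve_use_reentrant(None))
print(resolve_use_reentrant({"use_reentrant": False}))
print(resolve_use_reentrant({"use_reentrant": True}))
```

Note that an explicit `use_reentrant: False` resolves to `None`, not `False`, so the wrapped function's own default applies in that case.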

View File

@@ -3,8 +3,8 @@
Applies pre- and post-model load patches for various fixes and optimizations.
"""
import importlib.util
import os
import importlib.util
from functools import cached_property
import addict
@@ -468,9 +468,8 @@ class PatchManager:
def _apply_patch_deepspeed_zero3(self):
try:
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
if self.cfg.activation_offloading is True and (
is_deepspeed_zero3_enabled()

View File

@@ -296,7 +296,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
)
tokenizer.chat_template = chat_template_string
elif getattr(tokenizer, "chat_template", None) is None:
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)

View File

@@ -160,11 +160,9 @@ def get_state_dict(self, model, unwrap=True):
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
@@ -180,38 +178,6 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def cast_lora_module(module):
base_layer_dtype = module.base_layer.weight.dtype
# Linear4Bit will keep its bias term in fp32. If the weight dtype is bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_B:
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_A:
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_B:
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_magnitude_vector:
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
@@ -227,37 +193,18 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
fully_shard(module, **fsdp2_kwargs)
module.set_reshard_after_forward(False)
module.set_reshard_after_backward(False)
# for active_adapter in module.active_adapters:
# for adapter_name in [
# "lora_A",
# "lora_B",
# "lora_embedding_A",
# "lora_embedding_B",
# "lora_magnitude_vector",
# ]:
# adapter_module = getattr(module, adapter_name, None)
# # print(adapter_module, adapter_name)
# # torch.distributed.breakpoint()
# if not adapter_module:
# continue
# fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
# # fsdp_adapter_module.unshard()
# fsdp_adapter_module.set_reshard_after_backward(False)
# fsdp_adapter_module.set_reshard_after_forward(False)
# torch.distributed.breakpoint()
# if module.lora_A:
# fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
# if module.lora_B:
# fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_A:
# fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_B:
# fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
# if module.lora_magnitude_vector:
# fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch
@@ -371,26 +318,16 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
model.tie_weights()
is_peft_model = isinstance(model, PeftModel)
# TODO - this doesn't actually do anything
for name, module in model.named_children():
if name == "experts":
# torch.distributed.breakpoint()
for expert in module.children():
# torch.distributed.breakpoint()
print(f"expert: {expert}")
for lora_module in expert.children():
print(f"lora {lora_module}")
# torch.distributed.breakpoint()
cast_lora_module(lora_module)
_process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
# torch.distributed.breakpoint()
cast_lora_module(module)
# torch.distributed.breakpoint()
if is_peft_model and isinstance(module, LoraLayer):
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
@@ -407,9 +344,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)
# for module in model.named_modules():
# if "Lora" in
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():


@@ -1,12 +1,11 @@
"""Flex attention monkey patch"""
import sys
from packaging import version
import torch
import transformers
from packaging import version
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)


@@ -1,6 +1,5 @@
import importlib
import importlib.util
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)


@@ -36,13 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"granite",
"granitemoe",
"hunyuan_v1_dense",
"hunyuan_v1_moe",
"gpt_oss",
"arcee",
"seed_oss",
]


@@ -30,7 +30,11 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -230,15 +234,16 @@ def save_trained_model(
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model
from axolotl.utils.quantization import convert_qat_model_for_ptq
convert_qat_model(
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
" with the same config which you used for training."
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
)
# Handle ReLoRA early return case
if cfg.relora:
@@ -332,7 +337,9 @@ def save_trained_model(
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import save_compressed_model
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,


@@ -17,8 +17,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
LOG = get_logger(__name__)


@@ -1,14 +1,14 @@
"""Init for `axolotl.utils.data` module."""
from axolotl.utils.data.streaming import (
encode_streaming,
wrap_streaming_dataset,
)
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
get_dataset_wrapper,
prepare_datasets,
)
from axolotl.utils.data.streaming import (
encode_streaming,
wrap_streaming_dataset,
)
from axolotl.utils.data.utils import md5
__all__ = [


@@ -16,6 +16,7 @@ from transformers import PreTrainedTokenizer, ProcessorMixin
from axolotl.prompters import Prompter
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
@@ -26,7 +27,6 @@ from axolotl.utils.data.shared import (
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,


@@ -6,6 +6,8 @@ from importlib.metadata import version
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import (
get_gpu_info,
)
from packaging.version import Version, parse


@@ -3,47 +3,30 @@ Utilities for quantization including QAT and PTQ using torchao.
"""
import torch
from packaging import version
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4",
}
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except:
pass
# int4 weight config imports will fail on machines with fbgemm-gpu installed
# without a CUDA runtime available so we do this safely
try:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
except:
pass
from axolotl.utils.schemas.enums import TorchIntDType
def get_quantization_config(
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None = None,
def get_ptq_config(
weight_dtype: TorchIntDType,
activation_dtype: TorchIntDType | None = None,
group_size: int | None = None,
) -> AOBaseConfig:
"""
@@ -62,101 +45,44 @@ def get_quantization_config(
or if the group size is not specified for int8 or int4 weight only quantization.
"""
if activation_dtype is None:
if weight_dtype == TorchAOQuantDType.int8:
raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
if weight_dtype == TorchAOQuantDType.int4:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
if group_size is not None:
return Int4WeightOnlyConfig(group_size=group_size, version=2)
else:
return Int4WeightOnlyConfig(version=2)
if (
activation_dtype == TorchAOQuantDType.int4
and weight_dtype == TorchAOQuantDType.int4
):
raise ValueError(
"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int8
):
raise ValueError(
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int4
):
if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
else:
return Int8DynamicActivationInt4WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
):
return Float8DynamicActivationFloat8WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.int4
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
return UIntXWeightOnlyConfig(
dtype=weight_dtype.value,
group_size=group_size,
set_inductor_config=False,
)
if weight_dtype == TorchIntDType.int8:
if group_size is None:
raise ValueError(
"group_size must be specified for int8 weight only quantization"
)
return Int8WeightOnlyConfig(
group_size=group_size,
)
if weight_dtype == TorchIntDType.int4:
if group_size is None:
raise ValueError(
"group_size must be specified for int4 weight only quantization"
)
return Int4WeightOnlyConfig(
group_size=group_size,
)
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
return Int4DynamicActivationInt4WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
return Int8DynamicActivationInt8WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
return Int8DynamicActivationInt4WeightConfig()
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
def quantize_model(
model,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def prepare_model_for_qat(
model,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
weight_dtype: TorchIntDType,
group_size: int,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool = False,
):
"""
@@ -174,40 +100,86 @@ def prepare_model_for_qat(
Raises:
ValueError: If the activation/weight dtype combination is invalid.
"""
base_config = get_quantization_config(
if activation_dtype:
activation_config = FakeQuantizeConfig(
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None if activation_dtype is None else activation_config,
weight_config=weight_config,
)
quantize_(model, linear_quantize_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None,
weight_config=weight_config,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def quantize_model_for_ptq(
model,
weight_dtype: TorchIntDType,
group_size: int | None = None,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function quantizes a model for post-training quantization (PTQ).
It quantizes the weights of the model's linear layers in place.
If `quantize_embedding` is True, it also quantizes the model's embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
qat_config = QATConfig(base_config)
quantize_(model, qat_config)
quantize_(model, linear_ptq_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_base_config = get_quantization_config(
embedding_quantize_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
embedding_qat_config = QATConfig(embedding_base_config)
quantize_(
model,
embedding_qat_config,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def convert_qat_model(
def convert_qat_model_for_ptq(
model,
quantize_embedding: bool = False,
*,
quantize_embedding: bool | None = None,
):
"""
This function converts a QAT model which has fake quantized layers back to the original model.
This function swaps the fake-quantized modules in a model that has been
trained with QAT back to the original modules, ready for PTQ.
Args:
model: The model to convert.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
config = QATConfig(step="convert")
quantize_(model, config)
if quantize_embedding:
quantize_(
model,
config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def filter_fn(m, _):
return isinstance(m, nn.Embedding) or _is_linear(m)
else:
filter_fn = _is_linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
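The dtype dispatch performed by `get_ptq_config` (the `TorchIntDType` side of this diff) can be read as a lookup from an activation/weight pair to a torchao config class. The following is a minimal sketch under stated assumptions: plain strings stand in for the enum members and for the config classes, and `select_ptq_config_name` is a hypothetical helper that is not part of axolotl or torchao.

```python
# Hypothetical illustration only: strings stand in for TorchIntDType members
# and for the torchao config classes named in the diff above.
WEIGHT_ONLY = {
    "int8": "Int8WeightOnlyConfig",
    "int4": "Int4WeightOnlyConfig",
}
DYNAMIC = {
    # (activation_dtype, weight_dtype) -> config name
    ("int4", "int4"): "Int4DynamicActivationInt4WeightConfig",
    ("int8", "int8"): "Int8DynamicActivationInt8WeightConfig",
    ("int8", "int4"): "Int8DynamicActivationInt4WeightConfig",
}


def select_ptq_config_name(weight_dtype, activation_dtype=None, group_size=None):
    """Return the name of the config the dispatch above would select."""
    if activation_dtype is None:
        if weight_dtype.startswith("uint"):
            # Unsigned weights map to UIntXWeightOnlyConfig; group_size may be None.
            return "UIntXWeightOnlyConfig"
        if weight_dtype in WEIGHT_ONLY:
            if group_size is None:
                raise ValueError(
                    f"group_size must be specified for {weight_dtype} weight only quantization"
                )
            return WEIGHT_ONLY[weight_dtype]
    key = (activation_dtype, weight_dtype)
    if key in DYNAMIC:
        return DYNAMIC[key]
    raise ValueError(
        f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
    )


print(select_ptq_config_name("int4", None, 4))
print(select_ptq_config_name("int4", "int8"))
```

Keeping the combinations in a table makes it easier to see at a glance which pairs are supported and which fall through to the `ValueError`.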


@@ -106,12 +106,6 @@ class AxolotlInputConfig(
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
},
)
reinit_weights: bool | None = Field(
default=None,
json_schema_extra={
"description": "Reinitialize model weights randomly instead of loading pretrained weights"
},
)
trainer_cls: str | None = Field(
default=None,


@@ -5,21 +5,18 @@ from enum import Enum
import torch
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
class TorchIntDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
def from_string(str):
if str == "int4":
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":
return TorchAOQuantDType.nvfp4
uint1 = getattr(torch, "uint1", None)
uint2 = getattr(torch, "uint2", None)
uint3 = getattr(torch, "uint3", None)
uint4 = getattr(torch, "uint4", None)
uint5 = getattr(torch, "uint5", None)
uint6 = getattr(torch, "uint6", None)
uint7 = getattr(torch, "uint7", None)
int4 = getattr(torch, "int4", None)
int8 = getattr(torch, "int8", None)
class RLType(str, Enum):
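The `getattr(torch, ..., None)` guard in `TorchIntDType` lets the class body evaluate even when a dtype attribute is missing. A small sketch of that behaviour, assuming a stand-in namespace instead of a real torch install (`fake_torch` and `GuardedDType` are hypothetical names used only for illustration):

```python
from enum import Enum
from types import SimpleNamespace

# Hypothetical stand-in for an older torch build that only exposes int8.
fake_torch = SimpleNamespace(int8="int8")


class GuardedDType(Enum):
    """Illustrative enum using the same getattr guard as the diff above."""

    uint4 = getattr(fake_torch, "uint4", None)  # missing attribute -> None
    int4 = getattr(fake_torch, "int4", None)  # missing attribute -> None
    int8 = getattr(fake_torch, "int8", None)  # present attribute -> "int8"


# The class body evaluates without AttributeError even though two dtypes are
# missing. Members that share the value None become aliases of the first one.
print(GuardedDType.int4 is GuardedDType.uint4)
print([member.name for member in GuardedDType])
```

One consequence worth keeping in mind: because Python enums treat members with equal values as aliases, every missing dtype collapses onto a single member, so comparisons between two unavailable dtypes evaluate as equal.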


@@ -6,23 +6,7 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchAOQuantDType
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
if v is None:
return None
if v == "int4":
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":
return TorchAOQuantDType.nvfp4
raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
)
from axolotl.utils.schemas.enums import TorchIntDType
class QATConfig(BaseModel):
@@ -30,13 +14,13 @@ class QATConfig(BaseModel):
QAT Config Schema
"""
activation_dtype: TorchAOQuantDType | None = Field(
activation_dtype: TorchIntDType | None = Field(
default=None,
description="Fake quantization layout to use for activation quantization.",
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
@@ -51,8 +35,12 @@ class QATConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
class PTQConfig(BaseModel):
@@ -60,13 +48,13 @@ class PTQConfig(BaseModel):
PTQ Config Schema
"""
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
)
activation_dtype: TorchAOQuantDType | None = Field(
activation_dtype: TorchIntDType | None = Field(
default=None,
description="Fake quantization layout to use for activation quantization.",
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=None, description="Whether to quantize the embedding layer."
@@ -78,5 +66,9 @@ class PTQConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
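The `validate_dtype` methods on the `TorchIntDType` side accept only the strings `"int4"` and `"int8"`. If one also wanted an explicit `None` to pass through and the allowed names to be configurable per schema, the lookup could be sketched as below. This is an assumption-laden illustration, not the project's validator: `StandInDType` replaces `TorchIntDType` with plain string values, and `parse_dtype` is a hypothetical helper.

```python
from enum import Enum


class StandInDType(Enum):
    """Hypothetical stand-in for TorchIntDType with string values."""

    uint4 = "uint4"
    int4 = "int4"
    int8 = "int8"


def parse_dtype(v, allowed=("int4", "int8")):
    """Map a config string to an enum member, passing None through unchanged."""
    if v is None:
        return None
    if isinstance(v, StandInDType):
        # Already an enum member, nothing to convert.
        return v
    if v in allowed:
        return StandInDType[v]
    raise ValueError(f"Invalid dtype: '{v}'. Must be one of: {list(allowed)}")


print(parse_dtype("int4"))
print(parse_dtype(None))
print(parse_dtype("uint4", allowed=("uint4", "int4", "int8")))
```

Passing `allowed` explicitly would let a PTQ schema admit the `uintX` names mentioned in its field description while a QAT schema keeps the narrower set.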


@@ -14,6 +14,7 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
LOG = get_logger(__name__)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}


@@ -1,139 +0,0 @@
"""E2E smoke test for diffusion training plugin."""
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
class TestDiffusion:
"""Test case for diffusion training plugin."""
def test_diffusion_smoke_test(self, temp_dir):
"""
Smoke test for diffusion training to ensure the plugin loads and trains without
error.
"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 3,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion": {
# sample generation
"generate_samples": True,
"generation_interval": 1,
"num_generation_samples": 1,
"generation_steps": 2,
"generation_max_length": 32,
"generation_temperature": 0.0,
# training-specific
"mask_token_id": 16,
"eps": 1e-3,
"importance_weighting": False,
},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_diffusion_sft_labels(self, temp_dir):
"""Test that diffusion training properly handles SFT data with labels."""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 256,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
"logging_steps": 1,
"eval_steps": 2,
# Diffusion-specific config
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
"diffusion": {
# sample generation
"generate_samples": True,
"generation_interval": 1,
"num_generation_samples": 1,
"generation_steps": 2,
"generation_max_length": 32,
"generation_temperature": 0.0,
# training-specific
"mask_token_id": 16,
"eps": 1e-3,
"importance_weighting": True,
},
# Ensure we have proper SFT labels
"train_on_inputs": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
# Verify that the dataset has labels
sample = dataset_meta.train_dataset[0]
assert "labels" in sample, "SFT dataset should have labels"
# Check that some labels are -100 (prompt tokens)
labels = sample["labels"]
if hasattr(labels, "tolist"):
labels = labels.tolist()
assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)


@@ -43,7 +43,7 @@ class TestQATLlama:
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int4",
"weight_dtype": "int8",
"group_size": 8,
},
"num_epochs": 1,
@@ -111,7 +111,7 @@ class TestQATLlama:
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int4",
"weight_dtype": "int8",
"group_size": 8,
},
"save_first_step": False,


@@ -5,40 +5,41 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
convert_qat_model,
get_quantization_config,
convert_qat_model_for_ptq,
get_ptq_config,
prepare_model_for_qat,
quantize_model,
quantize_model_for_ptq,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import (
require_torch_2_8_0,
requires_cuda_ge_8_9,
requires_sm_ge_100,
)
from tests.e2e.utils import require_torch_2_6_0
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-0.5B",
device_map="auto",
"HuggingFaceTB/SmolLM2-135M",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
@@ -47,56 +48,45 @@ def model():
dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype,
)
yield dummy_model
del dummy_model
return dummy_model
ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
(
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
TorchIntDType.uint4,
None,
Int8DynamicActivationInt4WeightConfig,
None,
UIntXWeightOnlyConfig,
{"dtype": torch.uint4, "group_size": None},
),
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
(
TorchIntDType.int4,
TorchIntDType.int4,
None,
Int4DynamicActivationInt4WeightConfig,
{},
),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
TorchIntDType.int8,
TorchIntDType.int8,
None,
Float8DynamicActivationFloat8WeightConfig,
),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.float8_e4m3fn,
None,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
{},
),
]
ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
(TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
8,
False,
None,
LinearActivationQuantizedTensor,
),
# (
# TorchAOQuantDType.int4,
# TorchAOQuantDType.float8_e4m3fn,
# None,
# False,
# None,
# Int4Tensor,
# ),
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
# Deprecated configs
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
(TorchIntDType.int8, None, 8, False, None),
(TorchIntDType.int4, None, 4, True, None),
(TorchIntDType.uint4, None, 8, False, None),
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
(TorchIntDType.int8, None, None, False, ValueError),
(TorchIntDType.int4, None, None, False, ValueError),
]
@@ -106,132 +96,44 @@ class TestQuantization:
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type",
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
ptq_config_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
@require_torch_2_6_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
):
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
from torchao.quantization.quant_api import Int4WeightOnlyConfig
config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
assert isinstance(config, Int4WeightOnlyConfig)
for param_name, param_value in expected_params.items():
if isinstance(param_value, (PerAxis, PerGroup)):
if isinstance(param_value, PerAxis):
assert isinstance(getattr(config, param_name), PerAxis)
assert getattr(config, param_name).axis == param_value.axis
else:
assert isinstance(getattr(config, param_name), PerGroup)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
ptq_test_cases,
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
expected_tensor_class,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, expected_tensor_class
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, expected_tensor_class)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
self,
model,
):
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
quantize_model(
model,
TorchAOQuantDType.float8_e4m3fn,
None,
TorchAOQuantDType.float8_e4m3fn,
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, Float8Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(
self,
model,
):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
)
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, NVFP4Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
)
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.int4, None, 8, False),
(TorchAOQuantDType.int4, None, 16, True),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
False,
),
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
],
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
)
@require_torch_2_8_0
@requires_cuda_ge_8_9
@pytest.mark.parametrize("group_size", [4, 8])
@pytest.mark.parametrize("quantize_embedding", [False, True])
@require_torch_2_6_0
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
prepare_model_for_qat(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
@@ -240,19 +142,17 @@ class TestQuantization:
model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value
)
if group_size:
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
if group_size:
assert child.weight_fake_quantizer.config.group_size == group_size
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
assert (
@@ -262,40 +162,49 @@ class TestQuantization:
else:
assert child.activation_fake_quantizer is None
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_convert_qat_model(self, model):
config = QATConfig(
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model(
model,
config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
ptq_test_cases,
)
@require_torch_2_6_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model_for_ptq(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, AffineQuantizedTensor
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
if activation_dtype:
assert isinstance(
child.weight, LinearActivationQuantizedTensor
), (
"Linear weight should be quantized with activation quantization"
)
else:
assert isinstance(child.weight, AffineQuantizedTensor), (
"Linear weight should be quantized without activation quantization"
)
class TestQuantizationCallback:
@@ -309,10 +218,10 @@ class TestQuantizationCallback:
global_step=0,
)
@require_torch_2_8_0
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
cfg = QATConfig(
weight_dtype="int4",
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
@@ -359,10 +268,10 @@ class TestQuantizationCallback:
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_8_0
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
cfg = QATConfig(
weight_dtype="int4",
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
@@ -395,3 +304,43 @@ class TestQuantizationCallback:
# quantization should be enabled from the get-go
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
class TestConvertQATModelForPTQ:
"""
Test convert_qat_model_for_ptq
"""
@require_torch_2_6_0
def test_convert_qat_model_for_ptq(self, model):
config = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model_for_ptq(
model,
quantize_embedding=config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)
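
The QAT tests above check that the fake-quantizer modules carry the expected dtype and group size. As a rough numerical illustration of what group-wise fake quantization does, here is a minimal pure-Python sketch. It is not torchao's implementation: the helper name `fake_quantize_groupwise` and the symmetric per-group scaling are assumptions made only for illustration.

```python
def fake_quantize_groupwise(values, group_size, num_bits=8):
    """Quantize-then-dequantize a flat list of floats, one scale per group.

    Hypothetical helper: symmetric scaling (no zero point) is an assumption,
    chosen because it is the simplest scheme to show.
    """
    qmax = 2 ** (num_bits - 1) - 1   # 127 for int8
    qmin = -qmax - 1                 # -128 for int8
    result = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        max_abs = max(abs(v) for v in group)
        # An all-zero group has no meaningful scale; fall back to 1.0.
        scale = max_abs / qmax if max_abs > 0 else 1.0
        for v in group:
            q = max(qmin, min(qmax, round(v / scale)))  # integer grid point
            result.append(q * scale)                    # back to float
    return result
```

The output stays in floating point, so training can continue, while every value is snapped to the grid that the integer dtype could represent. Smaller groups give each scale fewer outliers to absorb, which is why `group_size` matters in the assertions above.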


@@ -90,18 +90,6 @@ def require_torch_2_7_0(test_case):
    return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)


def require_torch_2_8_0(test_case):
    """
    Decorator marking a test that requires torch >= 2.8.0
    """

    def is_min_2_8_0():
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.8.0")

    return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)


def require_torch_lt_2_6_0(test_case):
    """
    Decorator marking a test that requires torch < 2.6.0
@@ -140,24 +128,6 @@ def require_llmcompressor(test_case):
    )(test_case)


def requires_sm_ge_100(test_case):
    is_sm_ge_100 = (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (10, 0)
    )
    return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)


def requires_cuda_ge_8_9(test_case):
    is_cuda_ge_8_9 = (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (8, 9)
    )
    return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)


def is_hopper():
    compute_capability = torch.cuda.get_device_capability()
    return compute_capability == (9, 0)
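
The `require_torch_*` helpers above repeat the same pattern once per version. A minimal sketch of a parameterized variant follows. The names `parse_release` and `require_min_version` are hypothetical and not part of this repository, and the version string is passed in explicitly so the example runs without torch or `packaging` installed.

```python
import itertools
import unittest


def parse_release(version_text):
    """Turn '2.8.0+cu128' into (2, 8, 0).

    Simplification (assumption): local and pre-release suffixes are dropped,
    so '2.8.0a0' compares equal to '2.8.0'.
    """
    release = version_text.split("+")[0]
    numbers = []
    for piece in release.split(".")[:3]:
        leading = "".join(itertools.takewhile(str.isdigit, piece))
        numbers.append(int(leading) if leading else 0)
    while len(numbers) < 3:  # pad '2.8' to (2, 8, 0)
        numbers.append(0)
    return tuple(numbers)


def require_min_version(installed, minimum):
    """Return a decorator that skips the test unless installed >= minimum."""
    ok = parse_release(installed) >= parse_release(minimum)
    return unittest.skipUnless(ok, f"test requires version>={minimum}")
```

In a real suite one might call it as `@require_min_version(torch.__version__, "2.8.0")`. Keeping the message derived from the same argument avoids the docstring and skip reason drifting apart.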


@@ -1,274 +0,0 @@
"""Tests for diffusion trainer integration."""
# pylint: disable=redefined-outer-name,protected-access

from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.diffusion import DiffusionTrainer
from axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask
from axolotl.utils.dict import DictDefault


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    tokenizer.pad_token_id = 0
    return tokenizer


@pytest.fixture
def diffusion_config():
    """Create a diffusion config."""
    return DictDefault(
        {
            "diffusion": {
                "mask_token_id": 32000,
                "eps": 1e-3,
                "importance_weighting": False,
            },
            "sample_packing": False,
        }
    )


@pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
    """Create a diffusion trainer instance for testing methods directly."""
    # Create a minimal trainer instance just for testing methods
    trainer = object.__new__(DiffusionTrainer)  # Bypass __init__
    trainer.cfg = diffusion_config
    trainer._special_token_ids = {0, 1, 2}  # pad, bos, eos
    trainer.processing_class = mock_tokenizer
    trainer.store_metrics = Mock()  # Mock metrics storage
    return trainer


class TestDiffusionTrainer:
    """Test the DiffusionTrainer class."""

    def test_forward_process_basic(self, diffusion_trainer_instance):
        """Test basic forward process without labels."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)

        noisy_batch, masked_indices, p_mask = (
            diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
        )

        # Check shapes
        assert noisy_batch.shape == input_ids.shape
        assert masked_indices.shape == input_ids.shape
        assert p_mask.shape == input_ids.shape

        # Check that special tokens are not masked
        special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
        assert not masked_indices[special_token_positions].any()

        # Check that mask token is applied
        mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
        masked_positions = masked_indices
        if masked_positions.any():
            assert (noisy_batch[masked_positions] == mask_token_id).all()

    def test_forward_process_with_labels(self, diffusion_trainer_instance):
        """Test forward process with SFT labels."""
        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

        noisy_batch, masked_indices, p_mask = (
            diffusion_trainer_instance._forward_process(
                input_ids, labels=labels, eps=0.1
            )
        )

        # Check shapes
        assert noisy_batch.shape == input_ids.shape
        assert masked_indices.shape == input_ids.shape
        assert p_mask.shape == input_ids.shape

        # Check that only answer tokens can be masked (where labels != -100)
        non_answer_mask = labels == -100

        # No masking should occur on non-answer tokens
        assert not masked_indices[non_answer_mask].any()

        # p_mask should be the same for all positions (sampled timestep),
        # but masking is only applied to answer tokens
        assert p_mask.shape == input_ids.shape

        # Verify that masked_indices respects the answer mask
        assert not masked_indices[non_answer_mask].any()

    def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
        """Test forward process with attention mask."""
        input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
        attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)

        _, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
            input_ids, attention_mask=attention_mask, eps=0.1
        )

        # Check that padding tokens are not masked
        padding_positions = attention_mask == 0
        assert not masked_indices[padding_positions].any()
        assert (p_mask[padding_positions] == 0).all()

    def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
        """Test bidirectional attention mask without sample packing."""
        input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)

        mask = create_bidirectional_attention_mask(input_ids)

        # Should be all-to-all attention
        expected_shape = (1, 1, 4, 4)
        assert mask.shape == expected_shape
        assert mask.all()

    def test_bidirectional_attention_mask_with_packing(
        self, diffusion_trainer_instance
    ):
        """Test bidirectional attention mask with sample packing."""
        diffusion_trainer_instance.cfg.sample_packing = True
        input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
        # Sample IDs: first sample (1), second sample (2)
        attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)

        mask = create_bidirectional_attention_mask(
            input_ids, attention_mask, sample_packing=True
        )

        # Check that tokens within same sample can attend to each other
        # but not across samples
        assert mask[0, 0, 0, 1].item()  # First sample tokens can attend to each other
        assert mask[0, 0, 1, 2].item()
        assert not mask[0, 0, 0, 3].item()  # Can't attend across samples
        assert not mask[0, 0, 2, 4].item()
        assert mask[0, 0, 3, 4].item()  # Second sample tokens can attend to each other

    def test_compute_loss_basic(self, diffusion_trainer_instance):
        """Test basic loss computation."""
        # Mock model that returns logits
        mock_model = Mock()
        mock_outputs = Mock()
        vocab_size = 1000
        seq_len = 5
        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
        mock_model.return_value = mock_outputs
        mock_model.training = True

        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)

        loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
            mock_model, input_ids
        )

        # Check that loss is computed
        assert isinstance(loss, torch.Tensor)
        assert loss.requires_grad
        assert outputs == mock_outputs

        # Check that metrics were stored
        diffusion_trainer_instance.store_metrics.assert_called_once()

    def test_compute_loss_sft(self, diffusion_trainer_instance):
        """Test loss computation with SFT labels."""
        # Mock model
        mock_model = Mock()
        mock_outputs = Mock()
        vocab_size = 1000
        seq_len = 5
        mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
        mock_model.return_value = mock_outputs
        mock_model.training = True
        diffusion_trainer_instance.cfg.datasets = Mock()

        input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
        labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
            mock_model, input_ids, labels=labels
        )

        # Check that loss is computed
        assert isinstance(loss, torch.Tensor)
        assert loss.requires_grad

        # Check that SFT metrics were added
        call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
        assert "answer_ratio" in call_args
        assert "avg_answer_length" in call_args

    def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
        """Test loss computation when no tokens are masked."""
        # Mock model
        mock_model = Mock()
        mock_outputs = Mock()
        vocab_size = 1000
        seq_len = 3
        mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
        mock_model.return_value = mock_outputs
        mock_model.training = True

        # Only special tokens (which won't be masked)
        input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)

        loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
            mock_model, input_ids
        )

        # Loss should be zero when no tokens are masked
        assert loss.item() == 0.0
        assert loss.requires_grad

    def test_cache_special_token_ids(self, mock_tokenizer):
        """Test caching of special token IDs."""
        trainer = object.__new__(DiffusionTrainer)
        trainer.processing_class = mock_tokenizer
        trainer._cache_special_token_ids()

        assert trainer._special_token_ids == {0, 1, 2}

    def test_cache_special_token_ids_no_tokenizer(self):
        """Test caching when no tokenizer is available."""
        trainer = object.__new__(DiffusionTrainer)
        trainer.processing_class = None
        trainer._cache_special_token_ids()

        assert trainer._special_token_ids == set()

    def test_main_compute_loss_interface(self, diffusion_trainer_instance):
        """Test the main compute_loss interface."""
        # Mock model
        mock_model = Mock()
        mock_outputs = Mock()
        mock_outputs.logits = torch.randn(1, 5, 1000)
        mock_model.return_value = mock_outputs
        mock_model.training = True

        inputs = {
            "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
            "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
        }

        # Test without return_outputs
        loss = diffusion_trainer_instance.compute_loss(mock_model, inputs)
        assert isinstance(loss, torch.Tensor)

        # Test with return_outputs
        loss, outputs = diffusion_trainer_instance.compute_loss(
            mock_model, inputs, return_outputs=True
        )
        assert isinstance(loss, torch.Tensor)
        assert outputs == mock_outputs

    def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
        """Test that missing input_ids raises ValueError."""
        mock_model = Mock()
        inputs = {"attention_mask": torch.tensor([[1, 1, 1]])}

        with pytest.raises(ValueError, match="input_ids is required"):
            diffusion_trainer_instance.compute_loss(mock_model, inputs)
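
The sample-packing test above expects tokens to attend freely within a packed sample and never across samples. A minimal pure-Python sketch of that rule is shown here. It is not the repository's `create_bidirectional_attention_mask`: the helper name `packed_bidirectional_mask` is hypothetical, it works on a single unbatched sequence, and treating sample ID 0 as padding is an assumption.

```python
def packed_bidirectional_mask(sample_ids):
    """Build a square boolean mask from per-token sample IDs.

    mask[i][j] is True when token i may attend to token j, i.e. both tokens
    carry the same non-zero sample ID. ID 0 is assumed to mean padding.
    """
    n = len(sample_ids)
    return [
        [sample_ids[i] != 0 and sample_ids[i] == sample_ids[j] for j in range(n)]
        for i in range(n)
    ]


# Same packing layout as the test: two samples of three tokens each.
mask = packed_bidirectional_mask([1, 1, 1, 2, 2, 2])
```

The result is block-diagonal: one fully connected block per sample. Unlike a causal mask, each block is symmetric, which is what makes the attention bidirectional.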


@@ -1,92 +0,0 @@
"""Tests for diffusion generation callback dataloader selection and triggering."""
from types import SimpleNamespace
from unittest.mock import Mock

import pytest

from axolotl.integrations.diffusion import DiffusionGenerationCallback


class DummyTrainer:
    """Minimal trainer double with required attributes/methods for the callback."""

    def __init__(self, use_eval: bool):
        # Config used by callback
        self.cfg = SimpleNamespace(
            diffusion=SimpleNamespace(
                generation_interval=1,
                num_generation_samples=1,
                generation_max_length=32,
                generation_steps=4,
                generation_temperature=0.0,
                mask_token_id=16,
            ),
            use_wandb=False,
        )

        # Model/tokenizer are passed through to generate_samples; not used here
        self.model = Mock()
        self.processing_class = Mock()

        # Datasets and loaders
        self.eval_dataset = object() if use_eval else None
        self._train_loader = object()
        self._eval_loader = object()

        # State for world process check
        self.state = SimpleNamespace(is_world_process_zero=True)

        # Track which loader was requested
        self.requested: list[str] = []

    def get_train_dataloader(self):
        self.requested.append("train")
        return self._train_loader

    def get_eval_dataloader(self):
        self.requested.append("eval")
        return self._eval_loader


@pytest.mark.parametrize("use_eval", [False, True])
def test_callback_uses_correct_dataloader(monkeypatch, use_eval):
    trainer = DummyTrainer(use_eval=use_eval)
    callback = DiffusionGenerationCallback(trainer)

    captured = {}

    # Patch generate_samples in the callback module's namespace
    def fake_generate_samples(**kwargs):
        captured["dataloader"] = kwargs.get("dataloader")
        # Return one dummy sample to exercise logging path
        return [
            {
                "original": "o",
                "masked": "m",
                "generated": "g",
                "mask_ratio": 0.5,
                "masked_tokens": 1,
                "total_tokens": 2,
            }
        ]

    monkeypatch.setattr(
        "axolotl.integrations.diffusion.callbacks.generate_samples",
        fake_generate_samples,
    )

    # Trigger at step 1 (interval=1)
    args = SimpleNamespace()
    state = SimpleNamespace(global_step=1)
    control = SimpleNamespace()
    callback.on_step_end(args=args, state=state, control=control)

    # Assert the expected dataloader path was used
    if use_eval:
        assert trainer.requested[0] == "eval"
        assert captured["dataloader"] is trainer._eval_loader
    else:
        assert trainer.requested[0] == "train"
        assert captured["dataloader"] is trainer._train_loader
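
The behavior this test pins down is: prefer the eval dataloader when an eval dataset exists, otherwise fall back to the train dataloader. A minimal sketch of that selection rule follows. The helper name `select_generation_dataloader` is hypothetical, and the callback's actual code is not shown in this diff, so this only mirrors what the assertions require.

```python
from types import SimpleNamespace


def select_generation_dataloader(trainer):
    """Pick the dataloader used to draw generation samples.

    Assumes the trainer exposes `eval_dataset`, `get_eval_dataloader()` and
    `get_train_dataloader()`, as the DummyTrainer in the test does.
    """
    if getattr(trainer, "eval_dataset", None) is not None:
        return trainer.get_eval_dataloader()
    return trainer.get_train_dataloader()


# Tiny stand-in trainer with no eval dataset, so the train loader is chosen.
train_only = SimpleNamespace(
    eval_dataset=None,
    get_train_dataloader=lambda: "train-loader",
    get_eval_dataloader=lambda: "eval-loader",
)
chosen = select_generation_dataloader(train_only)
```

Checking `is not None` rather than truthiness matters here: a dataset object may define `__len__`, and an empty-looking but present eval set should still route to the eval loader.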


@@ -5,12 +5,12 @@ from unittest.mock import Mock, patch
from datasets import IterableDataset
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.config import validate_config
class TestStreamingConfig(unittest.TestCase):