Compare commits
25 Commits
reentrant-
...
fsdp2_fp32
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d0562dedd | ||
|
|
7fa8ac40cd | ||
|
|
f9748c4dc5 | ||
|
|
33975ce4bc | ||
|
|
e8b962d47f | ||
|
|
856ff12171 | ||
|
|
6bc959342b | ||
|
|
b3b92687c4 | ||
|
|
55d1be2ae6 | ||
|
|
08d831c3d5 | ||
|
|
7be8740c5c | ||
|
|
c51d6b06c3 | ||
|
|
09959fac70 | ||
|
|
4065bc14c6 | ||
|
|
e5c427f6de | ||
|
|
86d6ee7c05 | ||
|
|
d4cff1b7bb | ||
|
|
1ef6c196f7 | ||
|
|
58d67bf98d | ||
|
|
0401a15888 | ||
|
|
fcfc13d710 | ||
|
|
9406c0c488 | ||
|
|
1b53c49e1a | ||
|
|
b71482cec5 | ||
|
|
79103b01ca |
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
|
||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -304,7 +304,7 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
gpu_type: "B200"
|
||||
axolotl_extras:
|
||||
axolotl_extras: fbgemm-gpu
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -190,3 +190,6 @@ out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
# scm auto-versioning
|
||||
src/axolotl/_version.py
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
cff-version: 1.2.0
|
||||
type: software
|
||||
title: "Axolotl: Post-Training for AI Models"
|
||||
title: "Axolotl: Open Source LLM Post-Training"
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "Axolotl maintainers and contributors"
|
||||
|
||||
16
README.md
16
README.md
@@ -5,6 +5,9 @@
|
||||
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
|
||||
</picture>
|
||||
</p>
|
||||
<p align="center">
|
||||
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
@@ -50,20 +53,21 @@
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
|
||||
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
|
||||
|
||||
## 🚀 Quick Start
|
||||
## 🚀 Quick Start - LLM Fine-tuning in Minutes
|
||||
|
||||
**Requirements**:
|
||||
|
||||
@@ -160,7 +164,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
|
||||
|
||||
```bibtex
|
||||
@software{axolotl,
|
||||
title = {Axolotl: Post-Training for AI Models},
|
||||
title = {Axolotl: Open Source LLM Post-Training},
|
||||
author = {{Axolotl maintainers and contributors}},
|
||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
||||
license = {Apache-2.0},
|
||||
|
||||
@@ -267,6 +267,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
|
||||
@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
|
||||
@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
|
||||
@@ -140,3 +140,7 @@ description: Frequently asked questions
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -13,6 +13,7 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
@@ -41,7 +42,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
# (optional) if doing lora, only finetune the Language model,
|
||||
# leave the vision model and vision tower frozen
|
||||
@@ -94,10 +94,22 @@ chat_template: llava
|
||||
|
||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
|
||||
133
docs/optimizations.qmd
Normal file
133
docs/optimizations.qmd
Normal file
@@ -0,0 +1,133 @@
|
||||
---
|
||||
title: Optimizations Guide
|
||||
description: A guide to the performance and memory optimizations available in Axolotl.
|
||||
---
|
||||
|
||||
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
|
||||
|
||||
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
|
||||
|
||||
## Speed Optimizations
|
||||
|
||||
These optimizations focus on increasing training throughput and reducing total training time.
|
||||
|
||||
### Sample Packing
|
||||
|
||||
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
|
||||
|
||||
- **Config:** `sample_packing: true`
|
||||
- **Learn more:** [Sample Packing](multipack.qmd)
|
||||
|
||||
### Attention Implementations
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
|
||||
|
||||
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
|
||||
|
||||
## Memory Optimizations
|
||||
|
||||
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
|
||||
|
||||
### Parameter Efficient Finetuning (LoRA & QLoRA)
|
||||
|
||||
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
|
||||
|
||||
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
|
||||
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
|
||||
|
||||
### Gradient Checkpointing & Activation Offloading
|
||||
|
||||
These techniques save VRAM by changing how activations are handled.
|
||||
|
||||
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
|
||||
|
||||
### Liger Kernels
|
||||
|
||||
Provides efficient Triton kernels to improve training speed and reduce memory usage.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
|
||||
### RoPE Scaling
|
||||
|
||||
Extends a model's context window by interpolating its Rotary Position Embeddings.
|
||||
|
||||
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
|
||||
|
||||
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
|
||||
|
||||
### Artic Long Sequence Training (ALST)
|
||||
|
||||
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
|
||||
|
||||
- TiledMLP to reduce memory usage in MLP layers.
|
||||
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
|
||||
- Activation Offloading to CPU.
|
||||
|
||||
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
|
||||
|
||||
## Large Models (Distributed Training)
|
||||
|
||||
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
|
||||
|
||||
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
|
||||
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
|
||||
|
||||
### N-D Parallelism (Beta)
|
||||
|
||||
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
|
||||
|
||||
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Techniques to reduce the precision of model weights for memory savings.
|
||||
|
||||
### 4-bit Training (QLoRA)
|
||||
|
||||
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
|
||||
|
||||
### FP8 Training
|
||||
|
||||
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
|
||||
|
||||
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
|
||||
|
||||
### Quantization Aware Training (QAT)
|
||||
|
||||
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
|
||||
|
||||
- **Learn more:** [QAT Documentation](qat.qmd)
|
||||
|
||||
### GPTQ
|
||||
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
12
docs/qat.qmd
12
docs/qat.qmd
@@ -23,10 +23,18 @@ To enable QAT in axolotl, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
qat:
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
```
|
||||
|
||||
We support the following quantization schemas:
|
||||
|
||||
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
|
||||
- `Int8DynamicActivationInt4Weight`
|
||||
- `Float8DynamicActivationFloat8Weight`
|
||||
- `Float8DynamicActivationInt4Weight`
|
||||
- `NVFP4`
|
||||
|
||||
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
|
||||
|
||||
@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
|
||||
```yaml
|
||||
base_model: # The path to the model to quantize.
|
||||
quantization:
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
|
||||
|
||||
@@ -39,9 +39,8 @@ you used to train the model:
|
||||
# qat.yml
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
weight_dtype: int8
|
||||
weight_dtype: int4
|
||||
group_size: 256
|
||||
quantize_embedding: true
|
||||
|
||||
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
|
||||
```
|
||||
@@ -51,3 +50,11 @@ axolotl quantize qat.yml
|
||||
```
|
||||
|
||||
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
|
||||
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
|
||||
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
|
||||
|
||||
:::
|
||||
|
||||
@@ -7,3 +7,24 @@ techniques. It is a combination of:
|
||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
||||
|
||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
tiled_mlp: true
|
||||
|
||||
# See Sequence Parallelism docs
|
||||
# https://docs.axolotl.ai/docs/sequence_parallelism.html
|
||||
context_parallel_size: int
|
||||
|
||||
plugins:
|
||||
# See Cut Cross Entropy docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# or Liger Kernel docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
# ...
|
||||
|
||||
```
|
||||
|
||||
110
examples/apertus/README.md
Normal file
110
examples/apertus/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
# Finetune Swiss-AI's Apertus with Axolotl
|
||||
|
||||
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. (Optional, highly recommended) Install XIELU CUDA
|
||||
|
||||
```bash
|
||||
## Recommended for reduced VRAM and faster speeds
|
||||
|
||||
# Point to CUDA toolkit directory
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/apertus/apertus-8b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.7 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
|
||||
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### XIELU Installation Issues
|
||||
|
||||
#### `ModuleNotFoundError: No module named 'torch'`
|
||||
|
||||
Please check these one by one:
|
||||
- Running in correct environment
|
||||
- Env has PyTorch installed
|
||||
- CUDA toolkit is at `CUDA_HOME`
|
||||
|
||||
If those didn't help, please try the below solutions:
|
||||
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/nickjbrowning/XIELU
|
||||
cd xielu
|
||||
git checkout 59d6031
|
||||
|
||||
cd xielu
|
||||
nano CMakeLists.txt # or vi depending on your preference
|
||||
```
|
||||
|
||||
```diff
|
||||
execute_process(
|
||||
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
|
||||
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
|
||||
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: swiss-ai/Apertus-8B-Instruct-2509
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -19,6 +19,9 @@ cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -176,8 +176,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from axolotl.utils.dict import DictDefault\n",
|
||||
"from axolotl.cli.config import load_cfg\n",
|
||||
"from axolotl.utils.dict import DictDefault\n",
|
||||
"\n",
|
||||
"# Axolotl provides full control and transparency over model and training configuration\n",
|
||||
"config = DictDefault(\n",
|
||||
@@ -251,10 +251,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from axolotl.utils import patch_optimized_env\n",
|
||||
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||
"\n",
|
||||
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"patch_optimized_env()"
|
||||
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"set_pytorch_cuda_alloc_conf()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -20,7 +20,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
```bash
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/devstral/devstral-small-qlora.yml
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
|
||||
pip3 install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + vision + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -106,6 +106,16 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to
|
||||
|
||||
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
|
||||
|
||||
### Thinking and chat_template masking conflict
|
||||
|
||||
OpenAI’s Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl’s `chat_template` masking.
|
||||
|
||||
If your dataset has `thinking` content mid-turn, there are two paths we recommend:
|
||||
|
||||
- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).
|
||||
|
||||
- Adjust your dataset to only have `thinking` content in the last turn.
|
||||
|
||||
### TIPS
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
85
examples/hunyuan/README.md
Normal file
85
examples/hunyuan/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Finetune HunYuan with Axolotl
|
||||
|
||||
Tencent released a family of opensource models called HunYuan with varying parameter scales of 0.5B, 1.8B, 4B, and 7B scale for both Pre-trained and Instruct variants. The models can be found at [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as HunYuan is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 4.7 GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Dataset
|
||||
|
||||
HunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance on fine-tuning their Instruct models, your dataset should be adjusted to match their pattern.
|
||||
|
||||
```python
|
||||
# fast think pattern
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "/no_think What color is the sun?" },
|
||||
{"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
|
||||
]
|
||||
|
||||
# slow think pattern
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "/no_think What color is the sun?" },
|
||||
{"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
|
||||
]
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Tencent team recommends
|
||||
|
||||
```json
|
||||
|
||||
{
|
||||
"do_sample": true,
|
||||
"top_k": 20,
|
||||
"top_p": 0.8,
|
||||
"repetition_penalty": 1.05,
|
||||
"temperature": 0.7
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/hunyuan/hunyuan-v1-dense-qlora.yaml
Normal file
64
examples/hunyuan/hunyuan-v1-dense-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: tencent/Hunyuan-0.5B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -15,20 +15,18 @@ liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
split: train[:95%]
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
|
||||
|
||||
sample_packing: true
|
||||
|
||||
sequence_len: 512
|
||||
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
sample_packing: false
|
||||
sequence_len: 8192
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
@@ -67,7 +65,7 @@ fsdp:
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
@@ -76,6 +74,6 @@ fsdp_config:
|
||||
fsdp_activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
64
examples/llama-3/3b-qat-nvfp4.yaml
Normal file
64
examples/llama-3/3b-qat-nvfp4.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: meta-llama/Llama-3.2-3B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
split: train[:95%]
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
dataset_prepared_path: ./outputs/dataset_prepared
|
||||
|
||||
sequence_len: 8192
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 64
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
|
||||
cosine_constant_lr_ratio: 0
|
||||
cosine_min_lr_ratio: 1.0
|
||||
learning_rate: 2e-5
|
||||
save_only_model: true
|
||||
bf16: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
56
examples/llama-3/diffusion/pretrain-1b.yaml
Normal file
56
examples/llama-3/diffusion/pretrain-1b.yaml
Normal file
@@ -0,0 +1,56 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
pretraining_dataset:
|
||||
- path: wikitext
|
||||
name: wikitext-103-raw-v1
|
||||
type: completion
|
||||
field: text
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.diffusion.DiffusionPlugin
|
||||
|
||||
diffusion:
|
||||
noise_schedule: cosine
|
||||
min_mask_ratio: 0.15
|
||||
max_mask_ratio: 0.85
|
||||
num_diffusion_steps: 128
|
||||
eps: 5e-4
|
||||
importance_weighting: true
|
||||
mask_token_id: 128002
|
||||
generate_samples: true
|
||||
generation_interval: 250
|
||||
|
||||
output_dir: ./outputs/model-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: true
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 4
|
||||
max_steps: 10000
|
||||
warmup_ratio: 0.1
|
||||
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-4
|
||||
sdp_attention: true
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
save_strategy: steps
|
||||
save_steps: 1000
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
59
examples/llama-3/diffusion/sft-1b.yaml
Normal file
59
examples/llama-3/diffusion/sft-1b.yaml
Normal file
@@ -0,0 +1,59 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
val_set_size: 0.05
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.diffusion.DiffusionPlugin
|
||||
|
||||
diffusion:
|
||||
noise_schedule: cosine
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
mask_token_id: 128002
|
||||
generate_samples: true
|
||||
generation_interval: 250
|
||||
|
||||
output_dir: ./outputs/model-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
warmup_steps: 0.1
|
||||
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-5
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
sdp_attention: true
|
||||
|
||||
logging_steps: 1
|
||||
save_strategy: best
|
||||
eval_strategy: epoch
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -12,15 +12,6 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -46,7 +46,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -45,7 +45,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -18,7 +18,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
```bash
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||
@@ -30,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Thinking
|
||||
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
Example format:
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
### Vision
|
||||
|
||||
Example config: `./magistral-small-think-qlora.yaml`.
|
||||
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
|
||||
|
||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
Limitations:
|
||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
||||
|
||||
### TIPS
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
@@ -83,5 +77,5 @@ In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
|
||||
73
examples/magistral/think/README.md
Normal file
73
examples/magistral/think/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Magistral Small Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train magistral-small-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 19.1 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
60
examples/magistral/vision/README.md
Normal file
60
examples/magistral/vision/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Magistral Small Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train magistral-small-vision-24B-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 17GiB VRAM.
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- `max_tokens: 131072` for inference
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -0,0 +1,64 @@
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
@@ -48,8 +51,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
@@ -12,15 +12,6 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -45,8 +45,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
64
examples/qwen3-next/README.md
Normal file
64
examples/qwen3-next/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Finetune Qwen3-Next with Axolotl
|
||||
|
||||
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
- linear_attn.out_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
54
examples/seed-oss/README.md
Normal file
54
examples/seed-oss/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Finetune ByteDance's Seed-OSS with Axolotl
|
||||
|
||||
[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) are a series of 36B parameter open source models trained by ByteDance's Seed Team.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 27.7 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [ByteDance Seed Website](https://seed.bytedance.com/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
56
examples/seed-oss/seed-oss-36b-qlora.yaml
Normal file
56
examples/seed-oss/seed-oss-36b-qlora.yaml
Normal file
@@ -0,0 +1,56 @@
|
||||
base_model: ByteDance-Seed/Seed-OSS-36B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -22,9 +22,19 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
# audio
|
||||
pip3 install librosa==0.11.0
|
||||
pip3 install 'mistral_common[audio]==1.8.3'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -32,7 +32,7 @@ line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "C90", "B"]
|
||||
select = ["E", "F", "W", "C90", "B", "I"]
|
||||
ignore = [
|
||||
"E203", # Whitespace before ':'
|
||||
"E501", # Line too long
|
||||
|
||||
@@ -15,10 +15,10 @@ huggingface_hub>=0.33.0
|
||||
peft>=0.17.0
|
||||
transformers==4.56.1
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.10.0
|
||||
accelerate==1.10.1
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.21.0
|
||||
trl==0.23.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trackio
|
||||
@@ -64,10 +64,10 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.12.0
|
||||
torchao==0.13.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
mistral-common==1.8.5
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
|
||||
)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -124,7 +124,6 @@ extras_require = {
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.5",
|
||||
@@ -162,6 +161,7 @@ extras_require = {
|
||||
"llmcompressor": [
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
||||
}
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
|
||||
@@ -4,5 +4,7 @@ import os
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -115,6 +115,7 @@ class QuantizeCliArgs:
|
||||
quantize_embedding: Optional[bool] = field(default=None)
|
||||
group_size: Optional[int] = field(default=None)
|
||||
output_dir: Optional[str] = field(default=None)
|
||||
hub_model_id: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -23,7 +23,8 @@ from axolotl.utils.config import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -227,8 +228,11 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
|
||||
# have to wait for cfg.output to be resolved. We could call this earlier if we write
|
||||
# to a temporary file, and then move it later.
|
||||
prepare_debug_log(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
normalize_cfg_datasets(cfg)
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -241,7 +245,6 @@ def load_cfg(
|
||||
for k, v in cfg.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
"config:\n%s",
|
||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||
|
||||
@@ -14,6 +14,11 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.cli.utils.diffusion import (
|
||||
diffusion_inference,
|
||||
launch_diffusion_gradio_ui,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -29,6 +34,7 @@ def get_multi_line_input() -> str:
|
||||
Possibly multi-line, possibly empty stdin input as a string.
|
||||
"""
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
print("=" * 80)
|
||||
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
@@ -43,9 +49,9 @@ def do_inference(
|
||||
cli_args: InferenceCliArgs,
|
||||
):
|
||||
"""
|
||||
Runs inference on the command line in a loop. User input is accepted, a chat template
|
||||
is (optionally) applied, and the model specified in the `axolotl` config is used to
|
||||
generate completions according to a default generation config.
|
||||
Runs inference on the command line in a loop. User input is accepted, a chat
|
||||
template is (optionally) applied, and the model specified in the `axolotl` config is
|
||||
used to generate completions according to a default generation config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -64,16 +70,28 @@ def do_inference(
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||
)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
# Detect diffusion mode
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
is_diffusion = any(
|
||||
plugin.__class__.__name__ == "DiffusionPlugin"
|
||||
for plugin in plugin_manager.plugins.values()
|
||||
)
|
||||
|
||||
if is_diffusion:
|
||||
print("=" * 80)
|
||||
print("Commands:")
|
||||
print(":complete N -> completion mode with N tokens (default 64)")
|
||||
print(":mask R -> random masking with ratio R (0.0–1.0)")
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
@@ -103,9 +121,19 @@ def do_inference(
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
print("=" * 80)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
if is_diffusion:
|
||||
diffusion_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
)
|
||||
continue
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
@@ -128,7 +156,7 @@ def do_inference(
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print("=" * 80)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@@ -161,13 +189,30 @@ def do_inference_gradio(
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||
)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
# Detect diffusion mode
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
is_diffusion = any(
|
||||
plugin.__class__.__name__ == "DiffusionPlugin"
|
||||
for plugin in plugin_manager.plugins.values()
|
||||
)
|
||||
|
||||
if is_diffusion:
|
||||
launch_diffusion_gradio_ui(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompter_module=prompter_module,
|
||||
chat_template_str=chat_template_str,
|
||||
)
|
||||
return
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
|
||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
@@ -44,7 +44,7 @@ def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||
from axolotl.utils.quantization import (
|
||||
TorchAOQuantDType,
|
||||
get_quantization_config,
|
||||
quantization_config_to_str,
|
||||
quantize_model,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -43,13 +48,13 @@ def do_quantize(
|
||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||
)
|
||||
|
||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
||||
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||
if weight_dtype := cli_args.get("weight_dtype"):
|
||||
weight_dtype = TorchIntDType[weight_dtype]
|
||||
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
|
||||
else:
|
||||
weight_dtype = quantize_cfg.weight_dtype
|
||||
if activation_dtype := cli_args.get("activation_dtype"):
|
||||
activation_dtype = TorchIntDType[activation_dtype]
|
||||
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
|
||||
else:
|
||||
activation_dtype = quantize_cfg.activation_dtype
|
||||
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
||||
@@ -57,10 +62,15 @@ def do_quantize(
|
||||
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
||||
)
|
||||
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
||||
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
|
||||
|
||||
LOG.info(f"Loading model from {model_path}...")
|
||||
LOG.info(f"Loading model from {model_path}.")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Quantizing model with configuration: \n"
|
||||
@@ -70,11 +80,21 @@ def do_quantize(
|
||||
f"\tquantize_embedding: {quantize_embedding}"
|
||||
)
|
||||
|
||||
quantize_model_for_ptq(
|
||||
quantize_model(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
|
||||
quantization_config = get_quantization_config(
|
||||
weight_dtype, activation_dtype, group_size
|
||||
)
|
||||
|
||||
ao_config = TorchAoConfig(
|
||||
quant_type=quantization_config,
|
||||
include_input_output_embeddings=quantize_embedding,
|
||||
)
|
||||
model.config.quantization_config = ao_config
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||
model.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
@@ -86,4 +106,14 @@ def do_quantize(
|
||||
progressbar=True,
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||
|
||||
if hub_model_id:
|
||||
hub_model_id = (
|
||||
hub_model_id.rstrip("-")
|
||||
+ f"-{quantization_config_to_str[type(quantization_config)]}"
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||
|
||||
@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
@@ -92,6 +92,7 @@ def ray_train_func(kwargs: dict):
|
||||
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
|
||||
# also renormalize the config now that TorchTrainer has spawned distributed workers
|
||||
cfg = DictDefault(kwargs["cfg"])
|
||||
prepare_optim_env(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
|
||||
|
||||
374
src/axolotl/cli/utils/diffusion.py
Normal file
374
src/axolotl/cli/utils/diffusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Helpers for diffusion-mode inference in CLI and Gradio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gradio as gr
|
||||
from colorama import Fore, Style
|
||||
|
||||
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def diffusion_inference(
|
||||
model,
|
||||
tokenizer,
|
||||
cfg,
|
||||
prompt: str,
|
||||
chat_template_str: str | None = None,
|
||||
):
|
||||
"""Diffusion inference helper method."""
|
||||
mode = "random"
|
||||
completion_tokens = 0
|
||||
target_mask_ratio = None
|
||||
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
|
||||
|
||||
if cleaned:
|
||||
prompt = cleaned
|
||||
|
||||
info = run_diffusion(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
mode=mode,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
masked_text = info["masked_text"]
|
||||
mask_ratio = info["mask_ratio"]
|
||||
generated_ids = info["generated_ids"]
|
||||
masked_positions = info["masked_positions"]
|
||||
orig_ids = info["orig_ids"]
|
||||
|
||||
# Display with masked preview and colored diff
|
||||
if masked_text is not None and mask_ratio is not None:
|
||||
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
|
||||
if generated_ids is not None:
|
||||
# Compute per-token style
|
||||
styles: list[str] = []
|
||||
for i, tid in enumerate(generated_ids):
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
styles.append("green") # correct fill
|
||||
elif i < len(orig_ids):
|
||||
styles.append("red") # incorrect fill
|
||||
else:
|
||||
styles.append("normal") # appended
|
||||
else:
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
styles.append("dim" if same else "normal")
|
||||
|
||||
# Group contiguous spans by style
|
||||
styled_spans: list[tuple[str, int, int]] = []
|
||||
if generated_ids:
|
||||
current_style = styles[0]
|
||||
start = 0
|
||||
for i in range(1, len(generated_ids)):
|
||||
s = styles[i]
|
||||
if s != current_style:
|
||||
styled_spans.append((current_style, start, i))
|
||||
current_style, start = s, i
|
||||
styled_spans.append((current_style, start, len(generated_ids)))
|
||||
|
||||
out_parts = []
|
||||
for style_name, a, b in styled_spans:
|
||||
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
|
||||
if style_name == "green":
|
||||
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
|
||||
elif style_name == "red":
|
||||
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
if style_name == "dim":
|
||||
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
out_parts.append(chunk_text)
|
||||
print("Generated:\n" + "".join(out_parts))
|
||||
else:
|
||||
print("Generated:\n(no output)")
|
||||
|
||||
|
||||
def _parse_commands(text: str):
|
||||
"""
|
||||
Parse leading diffusion commands.
|
||||
|
||||
Supported at start of input (can be chained):
|
||||
:complete N -> completion mode with N tokens (default 64)
|
||||
:mask R -> random masking with ratio R in [0, 1]
|
||||
"""
|
||||
tokens = text.strip().split()
|
||||
i = 0
|
||||
mode = "random"
|
||||
completion_tokens = 0
|
||||
target_mask_ratio = None
|
||||
consumed = 0
|
||||
while i < len(tokens) and tokens[i].startswith(":"):
|
||||
cmd = tokens[i]
|
||||
i += 1
|
||||
consumed = i
|
||||
if cmd == ":complete":
|
||||
mode = "completion"
|
||||
if i < len(tokens):
|
||||
try:
|
||||
completion_tokens = int(tokens[i])
|
||||
i += 1
|
||||
consumed = i
|
||||
except Exception:
|
||||
completion_tokens = 64
|
||||
else:
|
||||
completion_tokens = 64
|
||||
elif cmd == ":mask":
|
||||
mode = "random"
|
||||
if i < len(tokens):
|
||||
try:
|
||||
target_mask_ratio = float(tokens[i])
|
||||
i += 1
|
||||
consumed = i
|
||||
except Exception:
|
||||
target_mask_ratio = None
|
||||
else:
|
||||
i -= 1
|
||||
consumed = i
|
||||
break
|
||||
|
||||
cleaned = " ".join(tokens[consumed:])
|
||||
|
||||
return mode, completion_tokens, target_mask_ratio, cleaned
|
||||
|
||||
|
||||
def run_diffusion(
|
||||
*,
|
||||
model,
|
||||
tokenizer,
|
||||
cfg: DictDefault,
|
||||
prompt: str,
|
||||
chat_template_str: str | None,
|
||||
mode: str = "random",
|
||||
target_mask_ratio: float | None = None,
|
||||
completion_tokens: int = 0,
|
||||
):
|
||||
"""Run a single diffusion generation and return a structured result dict."""
|
||||
if chat_template_str:
|
||||
batch = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
|
||||
|
||||
seq = batch["input_ids"].to(cfg.device)
|
||||
gen_mode = "completion" if mode == "completion" else "random"
|
||||
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
|
||||
|
||||
result = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
original_sequence=seq[:1],
|
||||
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
|
||||
temperature=cfg.diffusion.generation_temperature,
|
||||
mask_token_id=int(mask_token_id),
|
||||
mode=gen_mode, # type: ignore[arg-type]
|
||||
completion_tokens=comp_tokens,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
)
|
||||
|
||||
masked_text = result.get("masked") if isinstance(result, dict) else None
|
||||
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
|
||||
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
|
||||
masked_positions = (
|
||||
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
|
||||
)
|
||||
orig_ids = seq[0].detach().cpu().tolist()
|
||||
|
||||
return {
|
||||
"masked_text": masked_text,
|
||||
"mask_ratio": mask_ratio,
|
||||
"generated_ids": generated_ids,
|
||||
"masked_positions": masked_positions,
|
||||
"orig_ids": orig_ids,
|
||||
}
|
||||
|
||||
|
||||
def render_html(
|
||||
*,
|
||||
generated_ids: list[int] | None,
|
||||
orig_ids: list[int],
|
||||
masked_positions: set[int],
|
||||
tokenizer,
|
||||
) -> str:
|
||||
"""Render HTML visualizing diffusion outputs."""
|
||||
if not generated_ids:
|
||||
return "<pre>Generated:\n(no output)</pre>"
|
||||
|
||||
def _style_for(i: int, tid: int) -> str:
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
return "green"
|
||||
if i < len(orig_ids):
|
||||
return "red"
|
||||
return "normal"
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
return "dim" if same else "normal"
|
||||
|
||||
# Group contiguous spans by style to reduce HTML size
|
||||
spans: list[tuple[str, int, int]] = []
|
||||
if generated_ids:
|
||||
cur = _style_for(0, generated_ids[0])
|
||||
start = 0
|
||||
for i in range(1, len(generated_ids)):
|
||||
s = _style_for(i, generated_ids[i])
|
||||
if s != cur:
|
||||
spans.append((cur, start, i))
|
||||
cur, start = s, i
|
||||
spans.append((cur, start, len(generated_ids)))
|
||||
|
||||
html_parts = []
|
||||
for style_name, a, b in spans:
|
||||
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
|
||||
if style_name == "green":
|
||||
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
|
||||
elif style_name == "red":
|
||||
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
|
||||
elif style_name == "dim":
|
||||
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
|
||||
else:
|
||||
html_parts.append(txt)
|
||||
|
||||
legend = (
|
||||
'<div style="font-size:0.9em;margin-bottom:4px">'
|
||||
'<span style="color:#2e7d32">correct</span>, '
|
||||
'<span style="color:#c62828">incorrect</span>, '
|
||||
'<span style="opacity:0.6">unchanged</span>'
|
||||
"</div>"
|
||||
)
|
||||
|
||||
return (
|
||||
legend
|
||||
+ '<pre style="white-space:pre-wrap">Generated:\n'
|
||||
+ "".join(html_parts)
|
||||
+ "</pre>"
|
||||
)
|
||||
|
||||
|
||||
def launch_diffusion_gradio_ui(
|
||||
*,
|
||||
model,
|
||||
tokenizer,
|
||||
cfg: DictDefault,
|
||||
prompter_module=None,
|
||||
chat_template_str: str | None = None,
|
||||
):
|
||||
"""Build and launch a simple Gradio UI for diffusion inference."""
|
||||
with gr.Blocks(
|
||||
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Axolotl Diffusion Inference
|
||||
- Mode "Random" masks tokens at a target ratio and fills them.
|
||||
- Mode "Completion" appends N masked tokens at the end and fills them.
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
mode = gr.Radio(
|
||||
choices=["random", "completion"],
|
||||
value="random",
|
||||
label="Mode",
|
||||
)
|
||||
mask_ratio = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05,
|
||||
value=0.4,
|
||||
label="Mask ratio (random mode)",
|
||||
interactive=True,
|
||||
)
|
||||
completion_tokens = gr.Number(
|
||||
value=64,
|
||||
precision=0,
|
||||
label="Completion tokens (completion mode)",
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
instruction = gr.Textbox(label="Instruction", lines=6)
|
||||
run_btn = gr.Button("Generate")
|
||||
|
||||
masked_preview = gr.Textbox(label="Masked preview", lines=6)
|
||||
html_out = gr.HTML(label="Generated")
|
||||
|
||||
def _toggle_controls(selected_mode: str):
|
||||
return (
|
||||
gr.update(visible=(selected_mode == "random")),
|
||||
gr.update(visible=(selected_mode == "completion")),
|
||||
)
|
||||
|
||||
mode.change(
|
||||
_toggle_controls,
|
||||
inputs=[mode],
|
||||
outputs=[mask_ratio, completion_tokens],
|
||||
)
|
||||
|
||||
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
|
||||
if not instruction_text:
|
||||
return "", "<pre>Generated:\n(no output)</pre>"
|
||||
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(
|
||||
instruction=instruction_text.strip("\n")
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt = instruction_text.strip()
|
||||
|
||||
info = run_diffusion(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
prompt=prompt,
|
||||
chat_template_str=chat_template_str,
|
||||
mode=selected_mode,
|
||||
target_mask_ratio=mratio if selected_mode == "random" else None,
|
||||
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
|
||||
)
|
||||
|
||||
masked_text = info.get("masked_text")
|
||||
mask_ratio_val = info.get("mask_ratio")
|
||||
generated_ids = info.get("generated_ids")
|
||||
masked_positions = info.get("masked_positions") or set()
|
||||
orig_ids = info.get("orig_ids") or []
|
||||
|
||||
preview = (
|
||||
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
|
||||
if masked_text is not None and mask_ratio_val is not None
|
||||
else ""
|
||||
)
|
||||
html = render_html(
|
||||
generated_ids=generated_ids,
|
||||
orig_ids=orig_ids,
|
||||
masked_positions=masked_positions,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
return preview, html
|
||||
|
||||
run_btn.click(
|
||||
_gen,
|
||||
inputs=[instruction, mode, mask_ratio, completion_tokens],
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
)
|
||||
@@ -435,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing:
|
||||
elif self.cfg.gradient_checkpointing is not None:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
@@ -7,7 +7,11 @@ from pathlib import Path
|
||||
from typing import Type, Union
|
||||
|
||||
import transformers
|
||||
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
Trainer,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
@@ -23,15 +27,16 @@ from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
@@ -39,7 +44,6 @@ from axolotl.utils.collators import (
|
||||
MambaDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -391,10 +395,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters:
|
||||
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if (
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
and self.cfg.datasets is not None
|
||||
|
||||
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Mapping
|
||||
|
||||
def chat_message_transform_builder(
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
conversations_field: str = "messages",
|
||||
message_field_role: str | list[str] | None = None, # commonly "role"
|
||||
message_field_content: str | list[str] | None = None, # commonly "content"
|
||||
message_field_training: str | list[str] | None = None, # commonly "weight"
|
||||
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
The field name of the conversations. Defaults to "messages".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
The field name of the role.
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
The field name of the message content.
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
The field name of the train/weight.
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
|
||||
@@ -49,6 +49,13 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
"max": torch.max,
|
||||
"sum": torch.sum,
|
||||
}
|
||||
|
||||
|
||||
class AxolotlTrainer(
|
||||
PackingMixin,
|
||||
@@ -89,7 +96,9 @@ class AxolotlTrainer(
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
self._stored_metrics = defaultdict(
|
||||
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||
)
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
@@ -362,6 +371,11 @@ class AxolotlTrainer(
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
@override
|
||||
def evaluate(self, *args, **kwargs):
|
||||
LOG.info("Running evaluation step...")
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
@@ -585,9 +599,17 @@ class AxolotlTrainer(
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||
reduction_type = metric_data["reduction"]
|
||||
|
||||
fn = REDUCTION_FNS.get(reduction_type)
|
||||
if fn is None:
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
@@ -611,10 +633,27 @@ class AxolotlTrainer(
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
self,
|
||||
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
reduction: Literal["mean", "min", "max", "sum"] = "mean",
|
||||
) -> None:
|
||||
"""
|
||||
Store metrics with specified reduction type.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names to values, or metric names to (value,
|
||||
reduction_type) tuples.
|
||||
train_eval: Whether this is for training or evaluation.
|
||||
"""
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
if isinstance(value, tuple):
|
||||
value, _reduction = value # type: ignore[assignment]
|
||||
else:
|
||||
value, _reduction = value, reduction
|
||||
|
||||
self._stored_metrics[train_eval][key]["values"].append(value)
|
||||
self._stored_metrics[train_eval][key]["reduction"] = _reduction
|
||||
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
|
||||
@@ -27,7 +27,6 @@ class DPOStrategy:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
|
||||
@@ -142,7 +142,7 @@ class BasePlugin:
|
||||
model: The loaded model.
|
||||
"""
|
||||
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
|
||||
"""Returns a custom class for the trainer.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type
|
||||
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_input_args():
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -31,9 +31,11 @@ plugins:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- apertus
|
||||
- arcee
|
||||
- cohere
|
||||
- cohere2
|
||||
- deepseek_v3
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3
|
||||
@@ -42,9 +44,14 @@ plugins:
|
||||
- gemma3n_text
|
||||
- glm
|
||||
- glm4
|
||||
- glm4_moe
|
||||
- glm4v
|
||||
- glm4v_moe
|
||||
- gpt_oss
|
||||
- granite
|
||||
- granitemoe
|
||||
- granitemoeshared
|
||||
- granitemoehybrid
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- llama
|
||||
@@ -63,7 +70,11 @@ plugins:
|
||||
- qwen2_5_vl
|
||||
- qwen3
|
||||
- qwen3_moe
|
||||
- qwen3_vl
|
||||
- qwen3_vl_moe
|
||||
- qwen3_next
|
||||
- smollm3
|
||||
- seed_oss
|
||||
- voxtral
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
154
src/axolotl/integrations/diffusion/README.md
Normal file
154
src/axolotl/integrations/diffusion/README.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Diffusion LM Training Plugin for Axolotl
|
||||
|
||||
This plugin enables diffusion language model training using an approach inspired by
|
||||
LLaDA (Large Language Diffusion Models) within Axolotl.
|
||||
|
||||
## Overview
|
||||
|
||||
LLaDA is a diffusion-based approach to language model training that uses:
|
||||
- **Random token masking** during training instead of next-token prediction
|
||||
- **Bidirectional attention** to allow the model to attend to the full context
|
||||
- **Importance weighting** based on masking probabilities for stable training
|
||||
|
||||
This approach can lead to more robust language models with better understanding of
|
||||
bidirectional context.
|
||||
|
||||
## Installation
|
||||
|
||||
The plugin is included with Axolotl. See our
|
||||
[installation docs](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
## Quickstart
|
||||
|
||||
Train with an example config (Llama‑3.2 1B):
|
||||
- Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
|
||||
- SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
You can also modify your existing configs to enable / customize diffusion training.
|
||||
|
||||
Add the following to your Axolotl config:
|
||||
|
||||
```yaml
|
||||
# Enable diffusion LM training plugin
|
||||
plugins:
|
||||
- axolotl.integrations.diffusion.DiffusionPlugin
|
||||
```
|
||||
|
||||
And, configure the nested `diffusion` block (defaults shown):
|
||||
|
||||
```yaml
|
||||
diffusion:
|
||||
noise_schedule: linear # or "cosine"
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
|
||||
# Mask token (training auto-adds if missing, avoid pad/eos)
|
||||
mask_token_str: "<|diffusion_mask|>"
|
||||
# Or use an existing special token id (e.g., 128002 for Llama-3.x)
|
||||
# mask_token_id: 128002
|
||||
|
||||
# Sample generation during training (optional)
|
||||
generate_samples: true
|
||||
generation_interval: 100
|
||||
num_generation_samples: 3
|
||||
generation_steps: 128
|
||||
generation_temperature: 0.0
|
||||
generation_max_length: 100
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
Any models that support 4D attention masks should work out of the box. If not, please
|
||||
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
|
||||
[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
|
||||
|
||||
## How It Works
|
||||
|
||||
### Random Masking
|
||||
During training, tokens are randomly masked:
|
||||
- Sample timestep `t` uniformly from [0, 1]
|
||||
- Calculate masking probability: `p = (1 - eps) * t + eps`
|
||||
- Randomly mask tokens with probability `p`
|
||||
|
||||
### Diffusion Loss
|
||||
|
||||
Loss is computed only on masked tokens with (optional) importance weighting:
|
||||
|
||||
```python
|
||||
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
|
||||
```
|
||||
|
||||
## Sample Generation
|
||||
|
||||
When `diffusion.generate_samples: true`, the plugin generates samples during training:
|
||||
|
||||
```
|
||||
Sample 1:
|
||||
Original (45 tokens): The quick brown fox jumps over the lazy dog...
|
||||
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
|
||||
Generated: The quick brown fox jumps over the lazy dog...
|
||||
```
|
||||
|
||||
Samples are logged to console and wandb (if enabled).
|
||||
|
||||
## Inference
|
||||
|
||||
Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
|
||||
you trained with and run:
|
||||
|
||||
```
|
||||
axolotl inference path/to/your-config.yaml
|
||||
```
|
||||
|
||||
Optionally, pass `--gradio` to use a simple web interface.
|
||||
|
||||
Interactive controls (prefix the prompt with commands):
|
||||
- `:complete N` → completion mode with N new masked tokens appended (default 64)
|
||||
- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
|
||||
|
||||
Example session:
|
||||
|
||||
```
|
||||
================================================================================
|
||||
Commands:
|
||||
:complete N -> completion mode with N tokens (default 64)
|
||||
:mask R -> random masking with ratio R (0.0–1.0)
|
||||
================================================================================
|
||||
Give me an instruction (Ctrl + D to submit):
|
||||
|
||||
:mask 0.4 The quick brown fox jumps over the lazy dog
|
||||
|
||||
Masked (40.0%):
|
||||
The [MASK] brown [MASK] jumps over the [MASK] dog
|
||||
|
||||
Generated:
|
||||
The quick brown fox jumps over the loud dog
|
||||
```
|
||||
|
||||
## Metrics and Monitoring
|
||||
|
||||
The plugin adds (or modifies) several metrics to track diffusion training:
|
||||
|
||||
- `train/loss`: Weighted diffusion loss
|
||||
- `train/accuracy`: Accuracy on masked tokens
|
||||
- `train/mask_ratio`: Average fraction of tokens masked
|
||||
- `train/num_masked_tokens`: Number of tokens masked
|
||||
- `train/avg_p_mask`: Average masking probability
|
||||
- `train/ce_loss`: Unweighted cross-entropy loss
|
||||
- `train/importance_weight_avg`: Average importance weight
|
||||
|
||||
## Limitations
|
||||
|
||||
- No flash attention support
|
||||
- No RL training support
|
||||
|
||||
## References
|
||||
|
||||
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
|
||||
- [Axolotl Documentation](https://docs.axolotl.ai/)
|
||||
- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)
|
||||
19
src/axolotl/integrations/diffusion/__init__.py
Normal file
19
src/axolotl/integrations/diffusion/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Diffusion LM training plugin init."""
|
||||
|
||||
from .args import DiffusionArgs, DiffusionConfig
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
from .generation import generate
|
||||
from .plugin import DiffusionPlugin
|
||||
from .trainer import DiffusionTrainer
|
||||
from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
|
||||
|
||||
__all__ = [
|
||||
"DiffusionArgs",
|
||||
"DiffusionPlugin",
|
||||
"DiffusionTrainer",
|
||||
"generate",
|
||||
"resolve_mask_token_id",
|
||||
"create_bidirectional_attention_mask",
|
||||
"DiffusionGenerationCallback",
|
||||
"DiffusionConfig",
|
||||
]
|
||||
95
src/axolotl/integrations/diffusion/args.py
Normal file
95
src/axolotl/integrations/diffusion/args.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Config args for diffusion LM training (nested under `diffusion:`)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class DiffusionConfig(BaseModel):
|
||||
"""Nested diffusion configuration available under the `diffusion` key."""
|
||||
|
||||
# Noise schedule config
|
||||
noise_schedule: Literal["linear", "cosine"] = Field(
|
||||
default="linear", description="Type of noise schedule for diffusion training"
|
||||
)
|
||||
min_mask_ratio: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum masking ratio for diffusion noise schedule",
|
||||
)
|
||||
max_mask_ratio: float = Field(
|
||||
default=0.9,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Maximum masking ratio for diffusion noise schedule",
|
||||
)
|
||||
num_diffusion_steps: int = Field(
|
||||
default=128, ge=1, description="Number of diffusion timesteps"
|
||||
)
|
||||
eps: float = Field(
|
||||
default=1e-3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Epsilon value for minimum masking probability in forward process",
|
||||
)
|
||||
|
||||
# Training config
|
||||
importance_weighting: bool = Field(
|
||||
default=True,
|
||||
description="Apply importance weighting to loss based on masking probability",
|
||||
)
|
||||
mask_token_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Token ID to use for masking. Unset by default; can use one of the "
|
||||
"tokenizer's special tokens here."
|
||||
),
|
||||
)
|
||||
mask_token_str: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Token string to use as a mask. If `mask_token_id` is invalid or unset, "
|
||||
"this token will be ensured to exist as an additional special token and "
|
||||
"used. If absent, a default '<|diffusion_mask|>' will be added."
|
||||
),
|
||||
)
|
||||
|
||||
# Sample generation config
|
||||
generate_samples: bool = Field(
|
||||
default=True, description="Enable sample generation during training"
|
||||
)
|
||||
generation_interval: int = Field(
|
||||
default=100, ge=1, description="Generate samples every N steps"
|
||||
)
|
||||
num_generation_samples: int = Field(
|
||||
default=3, ge=1, description="Number of samples to generate each time"
|
||||
)
|
||||
generation_steps: int = Field(
|
||||
default=128, ge=1, description="Number of diffusion steps for generation"
|
||||
)
|
||||
generation_temperature: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
description="Temperature for generation sampling (0.0 = deterministic)",
|
||||
)
|
||||
generation_max_length: int = Field(
|
||||
default=100, ge=1, description="Maximum sequence length for generation"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_mask_ratios(self) -> "DiffusionConfig":
|
||||
if self.min_mask_ratio > self.max_mask_ratio:
|
||||
raise ValueError("min_mask_ratio must be ≤ max_mask_ratio")
|
||||
return self
|
||||
|
||||
|
||||
class DiffusionArgs(BaseModel):
|
||||
"""Plugin entry that exposes the nested `diffusion` block to the core config."""
|
||||
|
||||
diffusion: DiffusionConfig = Field(
|
||||
default_factory=DiffusionConfig,
|
||||
description="Diffusion training configuration. Only nested block is supported.",
|
||||
)
|
||||
174
src/axolotl/integrations/diffusion/callbacks.py
Normal file
174
src/axolotl/integrations/diffusion/callbacks.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Callbacks for diffusion training."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import wandb
|
||||
from colorama import Fore, Style
|
||||
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from .generation import generate_samples
|
||||
|
||||
# Simpler logger for more readable sample generation
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class DiffusionGenerationCallback(TrainerCallback):
|
||||
"""Callback for generating samples during diffusion training."""
|
||||
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate samples at specified intervals."""
|
||||
if (
|
||||
state.global_step > 0
|
||||
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
|
||||
):
|
||||
if not self.trainer.state.is_world_process_zero:
|
||||
return
|
||||
|
||||
# Use eval dataloader if available, otherwise use train dataloader
|
||||
dataloader = None
|
||||
try:
|
||||
if getattr(self.trainer, "eval_dataset", None) is not None:
|
||||
dataloader = self.trainer.get_eval_dataloader()
|
||||
except Exception:
|
||||
dataloader = None
|
||||
if dataloader is None:
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
|
||||
# Generate samples
|
||||
diffusion_cfg = self.trainer.cfg.diffusion
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
dataloader=dataloader,
|
||||
num_generation_samples=diffusion_cfg.num_generation_samples,
|
||||
max_length=diffusion_cfg.generation_max_length,
|
||||
num_diffusion_steps=diffusion_cfg.generation_steps,
|
||||
temperature=diffusion_cfg.generation_temperature,
|
||||
mask_token_id=diffusion_cfg.mask_token_id,
|
||||
)
|
||||
|
||||
# Log samples
|
||||
self._log_samples(samples, state.global_step)
|
||||
|
||||
def _log_samples(self, samples: list, step: int):
|
||||
"""Log generated samples."""
|
||||
if not samples:
|
||||
return
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("GENERATED SAMPLES")
|
||||
logger.info("=" * 60)
|
||||
|
||||
for i, sample_data in enumerate(samples, 1):
|
||||
original = sample_data["original"]
|
||||
masked = sample_data["masked"]
|
||||
generated = sample_data["generated"]
|
||||
mask_ratio = sample_data["mask_ratio"]
|
||||
masked_tokens = sample_data["masked_tokens"]
|
||||
total_tokens = sample_data["total_tokens"]
|
||||
|
||||
logger.info(f"\nSample {i}:")
|
||||
logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
|
||||
logger.info(
|
||||
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
|
||||
f"{mask_ratio:.1%}): {masked}"
|
||||
)
|
||||
|
||||
try:
|
||||
gen_ids = sample_data.get("generated_ids")
|
||||
orig_ids = sample_data.get("orig_ids")
|
||||
masked_positions = set(sample_data.get("masked_positions") or [])
|
||||
if isinstance(gen_ids, list) and isinstance(orig_ids, list):
|
||||
styles: list[str] = []
|
||||
for i, tid in enumerate(gen_ids):
|
||||
if i in masked_positions:
|
||||
if i < len(orig_ids) and tid == orig_ids[i]:
|
||||
styles.append("green")
|
||||
elif i < len(orig_ids):
|
||||
styles.append("red")
|
||||
else:
|
||||
styles.append("normal")
|
||||
else:
|
||||
same = i < len(orig_ids) and tid == orig_ids[i]
|
||||
styles.append("dim" if same else "normal")
|
||||
|
||||
spans: list[tuple[str, int, int]] = []
|
||||
if gen_ids:
|
||||
cur = styles[0]
|
||||
start = 0
|
||||
for i in range(1, len(gen_ids)):
|
||||
s = styles[i]
|
||||
if s != cur:
|
||||
spans.append((cur, start, i))
|
||||
cur, start = s, i
|
||||
spans.append((cur, start, len(gen_ids)))
|
||||
|
||||
parts = []
|
||||
for style_name, a, b in spans:
|
||||
chunk_text = self.trainer.processing_class.decode(
|
||||
gen_ids[a:b], skip_special_tokens=False
|
||||
)
|
||||
if style_name == "green":
|
||||
parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
|
||||
elif style_name == "red":
|
||||
parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
if style_name == "dim":
|
||||
parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
|
||||
else:
|
||||
parts.append(chunk_text)
|
||||
logger.info("\tGenerated:\n%s", "".join(parts))
|
||||
else:
|
||||
logger.info(f"\tGenerated: {generated}")
|
||||
except Exception:
|
||||
logger.info(f"\tGenerated: {generated}")
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.trainer.cfg.use_wandb:
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"generated_samples": wandb.Table(
|
||||
columns=[
|
||||
"step",
|
||||
"original",
|
||||
"masked",
|
||||
"generated",
|
||||
"mask_ratio",
|
||||
"masked_tokens",
|
||||
"total_tokens",
|
||||
],
|
||||
data=[
|
||||
[
|
||||
step,
|
||||
sample["original"],
|
||||
sample["masked"],
|
||||
sample["generated"],
|
||||
f"{sample['mask_ratio']:.1%}",
|
||||
sample["masked_tokens"],
|
||||
sample["total_tokens"],
|
||||
]
|
||||
for sample in samples
|
||||
],
|
||||
)
|
||||
},
|
||||
step=step,
|
||||
)
|
||||
409
src/axolotl/integrations/diffusion/generation.py
Normal file
409
src/axolotl/integrations/diffusion/generation.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""Sample generation utilities for diffusion training."""
|
||||
|
||||
import re
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def generate_samples(
|
||||
model: torch.nn.Module,
|
||||
tokenizer: Any,
|
||||
dataloader: Optional[Any] = None,
|
||||
num_generation_samples: int = 3,
|
||||
max_length: int = 100,
|
||||
num_diffusion_steps: int = 128,
|
||||
temperature: float = 0.0,
|
||||
mask_token_id: int = 32000,
|
||||
mode: Literal["random", "completion"] = "random",
|
||||
completion_tokens: int = 0,
|
||||
target_mask_ratio: Optional[float] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Generate text samples using the diffusion model by randomly masking sequences from
|
||||
the given dataset and running the reverse diffusion process.
|
||||
|
||||
Args:
|
||||
model: The wrapped or unwrapped model
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
dataloader: Validation dataloader (for sampling sequences)
|
||||
num_generation_samples: Number of samples to generate
|
||||
max_length: Maximum length of sequences to use
|
||||
num_diffusion_steps: Number of diffusion steps for generation
|
||||
temperature: Temperature for sampling (0.0 = deterministic)
|
||||
mask_token_id: Token ID used for masking
|
||||
|
||||
Returns:
|
||||
List of dictionaries with original text, masked text, and generated text
|
||||
"""
|
||||
if dataloader is None:
|
||||
LOG.warning("No validation dataloader provided, cannot generate samples")
|
||||
return []
|
||||
|
||||
unwrapped_model = model.module if hasattr(model, "module") else model
|
||||
training = unwrapped_model.training
|
||||
unwrapped_model.eval()
|
||||
|
||||
# Resolve device robustly (some modules don't expose `.device`)
|
||||
device = getattr(unwrapped_model, "device", None)
|
||||
if device is None:
|
||||
try:
|
||||
device = next(unwrapped_model.parameters()).device
|
||||
except StopIteration:
|
||||
device = torch.device("cpu")
|
||||
generations = []
|
||||
|
||||
# Sample sequences from validation dataset
|
||||
sampled_sequences = _sample_sequences_from_dataloader(
|
||||
dataloader, num_generation_samples, max_length, device
|
||||
)
|
||||
LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
|
||||
|
||||
# Generate samples using reverse diffusion process
|
||||
with torch.no_grad():
|
||||
for sample in sampled_sequences:
|
||||
if isinstance(sample, dict):
|
||||
original_sequence = sample.get("input_ids")
|
||||
labels_seq = sample.get("labels")
|
||||
attn_seq = sample.get("attention_mask")
|
||||
else:
|
||||
original_sequence = sample
|
||||
labels_seq = None
|
||||
attn_seq = None
|
||||
generation_result = generate(
|
||||
unwrapped_model,
|
||||
tokenizer,
|
||||
original_sequence,
|
||||
num_diffusion_steps,
|
||||
temperature,
|
||||
mask_token_id,
|
||||
mode=mode,
|
||||
completion_tokens=completion_tokens,
|
||||
target_mask_ratio=target_mask_ratio,
|
||||
labels=labels_seq,
|
||||
attention_mask=attn_seq,
|
||||
)
|
||||
generations.append(generation_result)
|
||||
|
||||
# Restore prior training state
|
||||
if training:
|
||||
unwrapped_model.train()
|
||||
else:
|
||||
unwrapped_model.eval()
|
||||
|
||||
return generations
|
||||
|
||||
|
||||
def _sample_sequences_from_dataloader(
|
||||
dataloader: Any, num_samples: int, max_length: int, device: torch.device
|
||||
) -> List[Any]:
|
||||
"""Sample sequences from validation dataloader."""
|
||||
sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
|
||||
sample_count = 0
|
||||
|
||||
# Skip a random number of batches (we could be more clever about this)
|
||||
skip_batches = torch.randint(0, 10, (1,)).item()
|
||||
batch_count = 0
|
||||
|
||||
for batch in dataloader:
|
||||
# Skip some batches for variety
|
||||
if batch_count < skip_batches:
|
||||
batch_count += 1
|
||||
continue
|
||||
|
||||
if sample_count >= num_samples:
|
||||
break
|
||||
|
||||
batch_count += 1
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch.get("attention_mask")
|
||||
labels = batch.get("labels")
|
||||
|
||||
# Randomly sample from sequences in this batch
|
||||
batch_indices = torch.randperm(input_ids.size(0)).tolist()
|
||||
|
||||
for i in batch_indices:
|
||||
if sample_count >= num_samples:
|
||||
break
|
||||
|
||||
# Get actual sequence length (non-padded)
|
||||
if attention_mask is not None:
|
||||
seq_len = attention_mask[i].sum().item()
|
||||
else:
|
||||
seq_len = input_ids.size(1)
|
||||
|
||||
if seq_len < 10:
|
||||
continue
|
||||
|
||||
# Determine truncation length
|
||||
max_total = min(seq_len, max_length)
|
||||
if labels is not None:
|
||||
labels_i = labels[i][:seq_len]
|
||||
answer_mask = labels_i != -100
|
||||
if not answer_mask.any():
|
||||
# No answer tokens; skip for SFT masking
|
||||
continue
|
||||
first_ans_idx = int(
|
||||
torch.nonzero(answer_mask, as_tuple=False)[0].item()
|
||||
)
|
||||
prompt_len = first_ans_idx
|
||||
if prompt_len >= max_total:
|
||||
# Prompt alone reaches cap; cannot include any answer
|
||||
continue
|
||||
remaining_answer = int(answer_mask[prompt_len:].sum().item())
|
||||
allowed_answer = max_total - prompt_len
|
||||
take_answer = min(remaining_answer, allowed_answer)
|
||||
if take_answer <= 0:
|
||||
continue
|
||||
actual_length = prompt_len + take_answer
|
||||
else:
|
||||
actual_length = max_total
|
||||
|
||||
# Extract the (possibly truncated) sequence
|
||||
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
|
||||
attn_seq = (
|
||||
attention_mask[i][:actual_length].unsqueeze(0).to(device)
|
||||
if attention_mask is not None
|
||||
else None
|
||||
)
|
||||
if labels is not None:
|
||||
labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
|
||||
sampled_sequences.append(
|
||||
{
|
||||
"input_ids": sequence,
|
||||
"labels": labels_seq,
|
||||
"attention_mask": attn_seq,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if attn_seq is not None:
|
||||
sampled_sequences.append(
|
||||
{"input_ids": sequence, "attention_mask": attn_seq}
|
||||
)
|
||||
else:
|
||||
sampled_sequences.append(sequence)
|
||||
sample_count += 1
|
||||
|
||||
return sampled_sequences
|
||||
|
||||
|
||||
def generate(
|
||||
model: torch.nn.Module,
|
||||
tokenizer: Any,
|
||||
original_sequence: torch.Tensor,
|
||||
num_diffusion_steps: int,
|
||||
temperature: float,
|
||||
mask_token_id: int,
|
||||
*,
|
||||
mode: Literal["random", "completion"] = "random",
|
||||
completion_tokens: int = 0,
|
||||
target_mask_ratio: Optional[float] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> dict:
|
||||
"""Generate a single sample using reverse diffusion."""
|
||||
# Get original text for comparison
|
||||
original_text = tokenizer.decode(
|
||||
original_sequence[0].cpu(), skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Build masked sequence
|
||||
if (
|
||||
labels is not None
|
||||
and labels.numel() > 0
|
||||
and (labels == -100).any()
|
||||
and (labels != -100).any()
|
||||
):
|
||||
# SFT case: completely mask all answer tokens (labels != -100)
|
||||
total_tokens = original_sequence.size(1)
|
||||
masked_indices = (labels != -100).to(dtype=torch.bool)
|
||||
masked_sequence = original_sequence.clone()
|
||||
masked_sequence[masked_indices] = mask_token_id
|
||||
masked_tokens = int(masked_indices.sum().item())
|
||||
mask_ratio = masked_tokens / max(int(total_tokens), 1)
|
||||
elif mode == "completion" and completion_tokens > 0:
|
||||
# Append mask tokens to the right for completion
|
||||
total_tokens = original_sequence.size(1) + int(completion_tokens)
|
||||
masked_indices = torch.zeros(
|
||||
1, total_tokens, dtype=torch.bool, device=original_sequence.device
|
||||
)
|
||||
masked_indices[0, -int(completion_tokens) :] = True
|
||||
|
||||
append = torch.full(
|
||||
(1, int(completion_tokens)), mask_token_id, device=original_sequence.device
|
||||
)
|
||||
masked_sequence = torch.cat([original_sequence, append], dim=1)
|
||||
masked_tokens = int(completion_tokens)
|
||||
mask_ratio = masked_tokens / total_tokens
|
||||
else:
|
||||
# Apply random masking with optional fixed ratio
|
||||
total_tokens = original_sequence.size(1)
|
||||
if target_mask_ratio is None:
|
||||
min_ratio, max_ratio = 0.1, 0.7
|
||||
target_mask_ratio = (
|
||||
torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
|
||||
)
|
||||
target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
|
||||
|
||||
# Create random mask indices
|
||||
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
|
||||
masked_indices = torch.zeros(
|
||||
1, total_tokens, dtype=torch.bool, device=original_sequence.device
|
||||
)
|
||||
masked_indices[0, mask_positions] = True
|
||||
|
||||
# Create masked sequence
|
||||
masked_sequence = original_sequence.clone()
|
||||
masked_sequence[masked_indices] = mask_token_id
|
||||
|
||||
# Calculate actual mask ratio
|
||||
masked_tokens = masked_indices.sum().item()
|
||||
mask_ratio = masked_tokens / total_tokens
|
||||
|
||||
# Get masked text for comparison
|
||||
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
|
||||
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
|
||||
|
||||
# Run reverse diffusion process
|
||||
sequence = masked_sequence.clone()
|
||||
attention_mask = create_bidirectional_attention_mask(
|
||||
sequence, attention_mask, sample_packing=attention_mask is not None
|
||||
)
|
||||
for step in range(num_diffusion_steps):
|
||||
sequence = _diffusion_step(
|
||||
model,
|
||||
sequence,
|
||||
step,
|
||||
num_diffusion_steps,
|
||||
temperature,
|
||||
mask_token_id,
|
||||
attention_mask,
|
||||
)
|
||||
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
|
||||
|
||||
# Collect diagnostic info
|
||||
final_ids = sequence[0].detach().cpu().tolist()
|
||||
orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
|
||||
if masked_indices is not None:
|
||||
masked_positions = (
|
||||
torch.where(masked_indices[0])[0].detach().cpu().tolist()
|
||||
if masked_indices.ndim == 2
|
||||
else []
|
||||
)
|
||||
else:
|
||||
masked_positions = []
|
||||
|
||||
result = {
|
||||
"original": original_text,
|
||||
"masked": masked_text,
|
||||
"generated": generated_text,
|
||||
"mask_ratio": mask_ratio,
|
||||
"masked_tokens": masked_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"generated_ids": final_ids,
|
||||
"masked_positions": masked_positions,
|
||||
"orig_ids": orig_ids_for_render,
|
||||
"formatted": (
|
||||
f"Original: '{original_text}' → Masked: '{masked_text}' "
|
||||
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
|
||||
),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
|
||||
"""Clean up masked text for display."""
|
||||
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
|
||||
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
|
||||
|
||||
# Remove literal special token strings
|
||||
if hasattr(tokenizer, "special_tokens_map"):
|
||||
for token_value in tokenizer.special_tokens_map.values():
|
||||
if token_value and isinstance(token_value, str):
|
||||
cleaned = cleaned.replace(token_value, "")
|
||||
|
||||
# Normalize whitespace but preserve newlines
|
||||
cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
|
||||
cleaned = re.sub(r"[ \t]+", " ", cleaned)
|
||||
cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
|
||||
return cleaned
|
||||
|
||||
|
||||
def _diffusion_step(
|
||||
model: torch.nn.Module,
|
||||
sequence: torch.Tensor,
|
||||
step: int,
|
||||
num_diffusion_steps: int,
|
||||
temperature: float,
|
||||
mask_token_id: int,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Perform a single diffusion step with remasking."""
|
||||
# Only process if there are masked tokens remaining
|
||||
current_mask = sequence == mask_token_id
|
||||
if not current_mask.any():
|
||||
return sequence
|
||||
|
||||
# Create or use provided attention mask
|
||||
if attention_mask is None:
|
||||
batch_size, seq_len = sequence.shape
|
||||
attention_mask = torch.ones(
|
||||
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(input_ids=sequence, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
|
||||
# Only sample at currently masked positions
|
||||
if current_mask.any():
|
||||
masked_logits = logits[current_mask]
|
||||
|
||||
# Apply temperature scaling
|
||||
if temperature > 0:
|
||||
scaled_logits = masked_logits / temperature
|
||||
else:
|
||||
scaled_logits = masked_logits
|
||||
|
||||
# Suppress mask token in outputs
|
||||
scaled_logits[:, mask_token_id] = -float("inf")
|
||||
|
||||
if temperature > 0:
|
||||
# Add Gumbel noise for sampling
|
||||
gumbel_noise = -torch.log(
|
||||
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
|
||||
)
|
||||
gumbel_logits = scaled_logits + gumbel_noise
|
||||
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
|
||||
else:
|
||||
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
|
||||
|
||||
# Calculate probabilities for confidence scoring
|
||||
probs = torch.softmax(scaled_logits, dim=-1)
|
||||
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
|
||||
|
||||
# Determine how many tokens to unmask this step
|
||||
remaining_masked = current_mask.sum().item()
|
||||
if step == num_diffusion_steps - 1:
|
||||
num_to_unmask = remaining_masked
|
||||
else:
|
||||
unmask_ratio = 1.0 / (num_diffusion_steps - step)
|
||||
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
|
||||
|
||||
# Select highest confidence predictions to unmask
|
||||
if num_to_unmask >= remaining_masked:
|
||||
sequence[current_mask] = predicted_tokens
|
||||
else:
|
||||
_, top_indices = predicted_token_probs.topk(num_to_unmask)
|
||||
mask_positions = torch.where(current_mask)[1]
|
||||
positions_to_unmask = mask_positions[top_indices]
|
||||
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
|
||||
|
||||
return sequence
|
||||
41
src/axolotl/integrations/diffusion/plugin.py
Normal file
41
src/axolotl/integrations/diffusion/plugin.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Diffusion LM training plugin for Axolotl."""
|
||||
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .trainer import DiffusionTrainer
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DiffusionPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for diffusion language model training.
|
||||
|
||||
This plugin enables diffusion-based training using the LLaDA approach, which uses
|
||||
random masking and bidirectional attention to train language models.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cfg = None
|
||||
|
||||
def get_input_args(self) -> str:
|
||||
"""Returns the pydantic model for LLaDA plugin arguments."""
|
||||
return "axolotl.integrations.diffusion.DiffusionArgs"
|
||||
|
||||
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Perform actions after model is loaded."""
|
||||
self.cfg = cfg
|
||||
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
|
||||
"""Return custom trainer class for diffusion training."""
|
||||
return DiffusionTrainer
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
||||
"""Configure trainer after creation."""
|
||||
trainer.set_config(cfg)
|
||||
301
src/axolotl/integrations/diffusion/trainer.py
Normal file
301
src/axolotl/integrations/diffusion/trainer.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Custom trainer for diffusion LM training."""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
from .utils import create_bidirectional_attention_mask
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DiffusionTrainer(AxolotlTrainer):
|
||||
"""Custom trainer for diffusion LM training that overrides loss computation."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cfg = None
|
||||
self._special_token_ids = None
|
||||
|
||||
def set_config(self, config: DictDefault):
|
||||
"""Set config for diffusion training."""
|
||||
self.cfg = config
|
||||
self._cache_special_token_ids()
|
||||
self._resolve_mask_token_id()
|
||||
|
||||
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
|
||||
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
||||
|
||||
if getattr(config.diffusion, "generate_samples", True):
|
||||
generation_callback = DiffusionGenerationCallback(self)
|
||||
self.add_callback(generation_callback)
|
||||
|
||||
def _resolve_mask_token_id(self) -> None:
|
||||
"""Ensure mask_token_id is valid for the current tokenizer."""
|
||||
from .utils import resolve_mask_token_id
|
||||
|
||||
tokenizer = getattr(self, "processing_class", None)
|
||||
if tokenizer is None:
|
||||
return
|
||||
|
||||
mid = resolve_mask_token_id(
|
||||
tokenizer,
|
||||
self.cfg,
|
||||
allow_add=True,
|
||||
model=getattr(self, "model", None),
|
||||
)
|
||||
try:
|
||||
self.cfg.diffusion.mask_token_id = int(mid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor],
|
||||
return_outputs: bool = False,
|
||||
num_items_in_batch: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""Override compute_loss to use diffusion loss."""
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask")
|
||||
labels = inputs.get("labels")
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("input_ids is required for diffusion training")
|
||||
|
||||
loss, outputs = self._compute_diffusion_loss(
|
||||
model, input_ids, attention_mask, labels
|
||||
)
|
||||
|
||||
if return_outputs:
|
||||
return loss, outputs
|
||||
return loss
|
||||
|
||||
def _cache_special_token_ids(self):
|
||||
"""Cache special token IDs to avoid repeated tokenizer access."""
|
||||
if self.processing_class is None:
|
||||
self._special_token_ids = set()
|
||||
return
|
||||
|
||||
tokenizer = self.processing_class
|
||||
special_tokens = set()
|
||||
|
||||
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
|
||||
special_tokens.add(tokenizer.bos_token_id)
|
||||
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
|
||||
special_tokens.add(tokenizer.eos_token_id)
|
||||
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
|
||||
special_tokens.add(tokenizer.pad_token_id)
|
||||
|
||||
self._special_token_ids = special_tokens
|
||||
|
||||
def _forward_process(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
eps: float = 1e-3,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward noising process. A timestep is sampled along the process, and tokens are
|
||||
masked with probability determined by the configured noise schedule.
|
||||
|
||||
Args:
|
||||
input_ids: Input token ids [batch_size, seq_len].
|
||||
attention_mask: Attention mask [batch_size, seq_len].
|
||||
labels: Labels for SFT training [batch_size, seq_len].
|
||||
eps: Small epsilon value for minimum masking probability.
|
||||
|
||||
Returns:
|
||||
noisy_batch: Input with some tokens masked.
|
||||
masked_indices: Boolean mask indicating which tokens were masked.
|
||||
p_mask: Masking probabilities for each token [batch_size, seq_len].
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
# Sample random timesteps for each sample in batch
|
||||
t = torch.rand(batch_size, device=device)
|
||||
p_mask = (1 - eps) * t + eps # [batch_size]
|
||||
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
|
||||
|
||||
# Don't mask padding tokens if attention_mask is provided
|
||||
if attention_mask is not None:
|
||||
valid_mask = attention_mask.bool()
|
||||
p_mask = p_mask * valid_mask.float()
|
||||
|
||||
# Create mask to exclude special tokens
|
||||
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
||||
if self._special_token_ids:
|
||||
for token_id in self._special_token_ids:
|
||||
special_token_mask |= input_ids == token_id
|
||||
|
||||
# Create random mask based on p_mask
|
||||
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
|
||||
masked_indices = masked_indices & ~special_token_mask
|
||||
if attention_mask is not None:
|
||||
masked_indices = masked_indices & attention_mask.bool()
|
||||
|
||||
# For SFT data, only mask answer tokens
|
||||
if labels is not None:
|
||||
answer_mask = labels != -100
|
||||
masked_indices = masked_indices & answer_mask
|
||||
|
||||
# Create masked input
|
||||
mask_token_id = int(self.cfg.diffusion.mask_token_id)
|
||||
mask_value = torch.full_like(input_ids, mask_token_id)
|
||||
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
||||
|
||||
return noisy_batch, masked_indices, p_mask
|
||||
|
||||
def _compute_diffusion_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | Any]:
|
||||
"""
|
||||
Compute diffusion loss.
|
||||
|
||||
Args:
|
||||
model: The model to compute loss for.
|
||||
input_ids: Ground truth token ids [batch_size, seq_len].
|
||||
attention_mask: Attention mask [batch_size, seq_len].
|
||||
labels: Labels for SFT training [batch_size, seq_len].
|
||||
|
||||
Returns:
|
||||
loss: Cross-entropy loss.
|
||||
metrics: Dictionary of metrics.
|
||||
"""
|
||||
# Short-circuit empty sequences
|
||||
if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:
|
||||
zero = torch.tensor(
|
||||
0.0,
|
||||
device=(input_ids.device if input_ids is not None else None),
|
||||
requires_grad=True,
|
||||
)
|
||||
return zero, {}
|
||||
|
||||
# If an attention_mask is provided and all positions are padding for every
|
||||
# sample in this batch, skip the step.
|
||||
if attention_mask is not None:
|
||||
if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():
|
||||
zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
|
||||
return zero, {}
|
||||
|
||||
# Apply forward process
|
||||
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||
input_ids, attention_mask, labels, self.cfg.diffusion.eps
|
||||
)
|
||||
|
||||
# Create bidirectional attention mask
|
||||
bidirectional_mask = create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(
|
||||
input_ids=noisy_batch.long(),
|
||||
attention_mask=bidirectional_mask,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
if masked_indices.sum() > 0:
|
||||
valid_indices = torch.where(masked_indices)
|
||||
batch_indices, seq_indices = valid_indices
|
||||
|
||||
masked_logits = logits[batch_indices, seq_indices]
|
||||
masked_targets = input_ids[batch_indices, seq_indices]
|
||||
masked_p_mask = p_mask[batch_indices, seq_indices]
|
||||
|
||||
# Compute cross-entropy loss without reduction
|
||||
token_loss = F.cross_entropy(
|
||||
masked_logits.float(), masked_targets, reduction="none"
|
||||
)
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
masked_p_mask = masked_p_mask.float()
|
||||
weighted_loss = token_loss / masked_p_mask
|
||||
else:
|
||||
weighted_loss = token_loss
|
||||
|
||||
if labels is not None:
|
||||
# For SFT data: normalize by answer token count per sample
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
|
||||
|
||||
# Get batch indices for masked tokens
|
||||
masked_batch_indices = batch_indices
|
||||
|
||||
# Sum losses per sample and divide by answer length
|
||||
batch_size = input_ids.shape[0]
|
||||
loss_per_sample = torch.zeros(batch_size, device=input_ids.device)
|
||||
for i in range(batch_size):
|
||||
sample_mask = masked_batch_indices == i
|
||||
if sample_mask.sum() > 0:
|
||||
sample_loss = weighted_loss[sample_mask].sum()
|
||||
denom = answer_lengths[i].clamp(min=1.0)
|
||||
loss_per_sample[i] = sample_loss / denom
|
||||
|
||||
loss = loss_per_sample.mean()
|
||||
else:
|
||||
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
||||
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
||||
# for stable scaling across varying mask ratios.
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
loss = weighted_loss.sum() / (
|
||||
input_ids.shape[0] * input_ids.shape[1]
|
||||
)
|
||||
else:
|
||||
loss = weighted_loss.mean()
|
||||
|
||||
ce_loss = token_loss.mean()
|
||||
|
||||
# Compute accuracy on masked tokens
|
||||
with torch.no_grad():
|
||||
pred_tokens = masked_logits.argmax(dim=-1)
|
||||
accuracy = (pred_tokens == masked_targets).float().mean()
|
||||
else:
|
||||
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
|
||||
accuracy = torch.tensor(0.0, device=input_ids.device)
|
||||
ce_loss = torch.tensor(0.0, device=input_ids.device)
|
||||
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
|
||||
|
||||
avg_p_mask = (
|
||||
p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
|
||||
)
|
||||
metrics = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": accuracy.item(),
|
||||
"mask_ratio": masked_indices.float().mean().item(),
|
||||
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
|
||||
"avg_p_mask": avg_p_mask,
|
||||
"ce_loss": ce_loss.item(),
|
||||
}
|
||||
|
||||
# If doing SFT training, log answer-specific metrics
|
||||
if self.cfg.datasets is not None:
|
||||
with torch.no_grad():
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
||||
total_answer_tokens = answer_mask.sum().item() # type: ignore
|
||||
total_tokens = labels.numel() # type: ignore
|
||||
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
||||
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||
|
||||
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||
self.store_metrics(metrics, train_eval=train_eval)
|
||||
|
||||
return loss, outputs
|
||||
159
src/axolotl/integrations/diffusion/utils.py
Normal file
159
src/axolotl/integrations/diffusion/utils.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Shared utilities for diffusion integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def resolve_mask_token_id(
|
||||
tokenizer: Any,
|
||||
cfg: DictDefault,
|
||||
*,
|
||||
allow_add: bool,
|
||||
model: Any | None = None,
|
||||
default_token: str = "<|diffusion_mask|>",
|
||||
) -> int:
|
||||
"""Resolve mask token id. Training may add a new special token; inference won't."""
|
||||
# Determine vocab size if available
|
||||
vocab_size = None
|
||||
if tokenizer is not None:
|
||||
if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None:
|
||||
try:
|
||||
vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
vocab_size = None
|
||||
elif hasattr(tokenizer, "__len__"):
|
||||
try:
|
||||
vocab_size = int(len(tokenizer))
|
||||
except Exception:
|
||||
vocab_size = None
|
||||
|
||||
# Use explicit id from config if provided
|
||||
diffusion_cfg = getattr(cfg, "diffusion", None)
|
||||
# Fallback to top-level attr names only if nested missing (shouldn't happen)
|
||||
cfg_id = (
|
||||
getattr(diffusion_cfg, "mask_token_id", None)
|
||||
if diffusion_cfg is not None
|
||||
else getattr(cfg, "diffusion_mask_token_id", None)
|
||||
)
|
||||
if isinstance(cfg_id, int) and cfg_id >= 0:
|
||||
if vocab_size is None or cfg_id < vocab_size:
|
||||
return int(cfg_id)
|
||||
|
||||
def _existing_special_token_id(token_str: str | None) -> int | None:
|
||||
"""Attempt to resolve an existing special token string to a real ID."""
|
||||
if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"):
|
||||
return None
|
||||
try:
|
||||
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not isinstance(token_id, int) or token_id < 0:
|
||||
return None
|
||||
|
||||
# Ensure it's registered as special and not UNK, and within vocab
|
||||
unk_id = getattr(tokenizer, "unk_token_id", None)
|
||||
specials = set(getattr(tokenizer, "all_special_tokens", []) or [])
|
||||
addl = set(getattr(tokenizer, "additional_special_tokens", []) or [])
|
||||
is_special = token_str in specials or token_str in addl
|
||||
in_vocab = vocab_size is None or token_id < vocab_size
|
||||
if (
|
||||
(unk_id is not None and token_id == unk_id)
|
||||
or not is_special
|
||||
or not in_vocab
|
||||
):
|
||||
return None
|
||||
return token_id
|
||||
|
||||
# Try mask token string if provided
|
||||
token_str = (
|
||||
getattr(diffusion_cfg, "mask_token_str", None)
|
||||
if diffusion_cfg is not None
|
||||
else getattr(cfg, "diffusion_mask_token_str", None)
|
||||
)
|
||||
for candidate in (token_str, default_token):
|
||||
token_id = _existing_special_token_id(candidate)
|
||||
if isinstance(token_id, int):
|
||||
try:
|
||||
if diffusion_cfg is None:
|
||||
cfg.diffusion_mask_token_id = int(token_id) # legacy fallback
|
||||
else:
|
||||
diffusion_cfg.mask_token_id = int(token_id)
|
||||
except Exception:
|
||||
pass
|
||||
return int(token_id)
|
||||
|
||||
# Optionally add and return a dedicated special token during training
|
||||
if allow_add and hasattr(tokenizer, "add_special_tokens"):
|
||||
token_to_add = token_str or default_token
|
||||
try:
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]})
|
||||
|
||||
# Resize embeddings if possible
|
||||
if (
|
||||
model is not None
|
||||
and hasattr(tokenizer, "__len__")
|
||||
and hasattr(model, "resize_token_embeddings")
|
||||
):
|
||||
try:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
except Exception:
|
||||
pass
|
||||
new_id = tokenizer.convert_tokens_to_ids(token_to_add)
|
||||
if isinstance(new_id, int) and new_id >= 0:
|
||||
try:
|
||||
if diffusion_cfg is None:
|
||||
cfg.diffusion_mask_token_id = int(new_id) # legacy fallback
|
||||
else:
|
||||
diffusion_cfg.mask_token_id = int(new_id)
|
||||
except Exception:
|
||||
pass
|
||||
return int(new_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to unk or 0 (do not update cfg)
|
||||
fallback = getattr(tokenizer, "unk_token_id", 0) or 0
|
||||
return int(fallback)
|
||||
|
||||
|
||||
def create_bidirectional_attention_mask(
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sample_packing: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create bidirectional attention mask to override default causal masking.
|
||||
Handles sample-packed sequences where different samples are identified
|
||||
by different attention mask values.
|
||||
|
||||
Args:
|
||||
input_ids: Input token ids [batch_size, seq_len]
|
||||
attention_mask: Attention mask [batch_size, seq_len]
|
||||
sample_packing: Whether sample packing is enabled
|
||||
|
||||
Returns:
|
||||
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
if attention_mask is None or not sample_packing:
|
||||
return torch.ones(
|
||||
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# Handle sample packing: tokens can only attend within their sample
|
||||
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
|
||||
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
|
||||
|
||||
# Tokens can attend to each other if they have the same non-zero sample ID
|
||||
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
|
||||
|
||||
# Add head dimension: [batch_size, 1, seq_len, seq_len]
|
||||
return bidirectional_mask.unsqueeze(1)
|
||||
@@ -14,6 +14,7 @@ from peft import (
|
||||
PeftConfig,
|
||||
PeftMixedModel,
|
||||
PeftModel,
|
||||
TaskType,
|
||||
get_peft_model,
|
||||
)
|
||||
from transformers import PreTrainedModel
|
||||
@@ -101,6 +102,15 @@ def load_lora(
|
||||
if cfg.peft_trainable_token_indices:
|
||||
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
|
||||
|
||||
# Determine the correct PEFT task type
|
||||
model_cls = type(model).__name__
|
||||
if "SequenceClassification" in model_cls:
|
||||
task_type = TaskType.SEQ_CLS
|
||||
elif "TokenClassification" in model_cls:
|
||||
task_type = TaskType.TOKEN_CLS
|
||||
else:
|
||||
task_type = TaskType.CAUSAL_LM
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
@@ -112,7 +122,7 @@ def load_lora(
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
task_type=task_type,
|
||||
**lora_config_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -673,6 +673,33 @@ class ModelLoader:
|
||||
|
||||
return hf_ds_cfg
|
||||
|
||||
def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel:
|
||||
"""
|
||||
Load model with random initialization using from_config.
|
||||
|
||||
Uses the selected loader when provided; otherwise falls back to the auto loader.
|
||||
"""
|
||||
loader = model_loader_class or self.auto_model_loader
|
||||
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
|
||||
model = loader.from_config(
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
)
|
||||
else:
|
||||
model = loader(config=self.model_config)
|
||||
|
||||
return model
|
||||
|
||||
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
|
||||
"""Load model from pretrained weights."""
|
||||
loader = model_loader_class or self.auto_model_loader
|
||||
kwargs = {
|
||||
"config": self.model_config,
|
||||
"trust_remote_code": self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
return loader.from_pretrained(self.base_model, **kwargs)
|
||||
|
||||
def _build_model(self) -> bool:
|
||||
"""Load model, with load strategy depending on config."""
|
||||
skip_move_to_device = False
|
||||
@@ -687,7 +714,8 @@ class ModelLoader:
|
||||
if self.is_fsdp_enabled:
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
|
||||
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
|
||||
# _set_device_map
|
||||
if (
|
||||
"device_map" in self.model_kwargs
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
@@ -716,6 +744,11 @@ class ModelLoader:
|
||||
or self.cfg.qlora_sharded_model_loading
|
||||
)
|
||||
):
|
||||
if self.cfg.reinit_weights:
|
||||
LOG.warning(
|
||||
"reinit_weights is not supported with sharded quantized loading. "
|
||||
"Loading from pretrained weights instead."
|
||||
)
|
||||
quant_storage = self.cfg.torch_dtype
|
||||
quantization_config = getattr(
|
||||
self.model_config, "quantization_config", None
|
||||
@@ -731,33 +764,12 @@ class ModelLoader:
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
self.model_config.model_type in ["llama", "llama4"]
|
||||
and not self.cfg.trust_remote_code
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
|
||||
# Load model with random initialization if specified
|
||||
if self.cfg.random_init_weights:
|
||||
# AutoModel classes support the from_config method
|
||||
if self.auto_model_loader in [
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
]:
|
||||
self.model = self.auto_model_loader.from_config(
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model = self.auto_model_loader(config=self.model_config)
|
||||
else:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
elif self.model_type == "MambaLMHeadModel":
|
||||
if self.cfg.reinit_weights:
|
||||
LOG.warning(
|
||||
"reinit_weights is not supported with MambaLMHeadModel. "
|
||||
"Loading from pretrained weights instead."
|
||||
)
|
||||
# FIXME this is janky at best and hacked together to make it work
|
||||
MambaLMHeadModel = fix_mamba_attn_for_loss()
|
||||
|
||||
@@ -770,41 +782,27 @@ class ModelLoader:
|
||||
self.base_model,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.model_type
|
||||
and self.model_type != "AutoModelForCausalLM"
|
||||
and not self.cfg.trust_remote_code
|
||||
):
|
||||
if self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
self.model = getattr(transformers, self.model_type).from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
elif self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
# Please don't remove underscore binding without reading the fn docstring
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_type
|
||||
and self.model_type != "AutoModelForCausalLM"
|
||||
and not self.cfg.trust_remote_code
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
# Use model type from transformers
|
||||
model_loader_class = getattr(transformers, self.model_type)
|
||||
else:
|
||||
# Use auto model loader (handles gptq and default cases)
|
||||
model_loader_class = self.auto_model_loader
|
||||
|
||||
if self.cfg.reinit_weights:
|
||||
self.model = self._load_model_from_config(model_loader_class)
|
||||
else:
|
||||
self.model = self._load_model_from_pretrained(model_loader_class)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
Applies pre- and post-model load patches for various fixes and optimizations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
import addict
|
||||
@@ -68,11 +68,12 @@ class PatchManager:
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_fsdp2_bnb_patches()
|
||||
self._apply_patch_deepspeed_zero3()
|
||||
self._apply_voxtral_patches()
|
||||
self._apply_apertus_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
@@ -83,6 +84,13 @@ class PatchManager:
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
@@ -168,6 +176,20 @@ class PatchManager:
|
||||
|
||||
patch_llama4_linearized_modeling()
|
||||
|
||||
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
|
||||
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||
patch_qwen3_next_modeling_packing,
|
||||
)
|
||||
|
||||
patch_qwen3_next_modeling_packing()
|
||||
|
||||
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
|
||||
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||
apply_mistral_tokenizer_image_patch,
|
||||
)
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
@@ -334,6 +356,13 @@ class PatchManager:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
if self.model_config.model_type in ("mistral3", "llava"):
|
||||
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
|
||||
apply_patch_is_packed_sequence,
|
||||
)
|
||||
|
||||
apply_patch_is_packed_sequence()
|
||||
|
||||
def _patch_loss_llama(self):
|
||||
"""Patch loss functions and other optimizations for LLaMA models."""
|
||||
if not self.cfg.is_llama_derived_model:
|
||||
@@ -468,9 +497,10 @@ class PatchManager:
|
||||
|
||||
def _apply_patch_deepspeed_zero3(self):
|
||||
try:
|
||||
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
|
||||
|
||||
if self.cfg.activation_offloading is True and (
|
||||
is_deepspeed_zero3_enabled()
|
||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||
@@ -478,3 +508,12 @@ class PatchManager:
|
||||
apply_deepspeed_patches()
|
||||
except ImportError as e:
|
||||
LOG.warning(f"DeepSpeed patches not applied: {e}")
|
||||
|
||||
def _apply_apertus_patches(self):
|
||||
"""Apply patches for Apertus model."""
|
||||
if self.cfg.model_config_type == "apertus":
|
||||
from axolotl.monkeypatch.models.apertus.activation import (
|
||||
patch_apertus_xielu_activation,
|
||||
)
|
||||
|
||||
patch_apertus_xielu_activation()
|
||||
|
||||
@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
if cfg.processor_type:
|
||||
processor_cls = getattr(transformers, cfg.processor_type)
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
from axolotl.utils.mistral import Mistral3Processor
|
||||
|
||||
return Mistral3Processor(
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
processor = processor_cls.from_pretrained(
|
||||
cfg.processor_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
|
||||
@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from transformers import tokenization_mistral_common
|
||||
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
# patch
|
||||
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
|
||||
@@ -296,7 +291,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
)
|
||||
|
||||
tokenizer.chat_template = chat_template_string
|
||||
else:
|
||||
elif getattr(tokenizer, "chat_template", None) is None:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
Common logging module for axolotl
|
||||
"""
|
||||
"""Common logging module for axolotl."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from logging import Formatter, Logger, LogRecord
|
||||
from logging.config import dictConfig
|
||||
from typing import Any, Dict
|
||||
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
|
||||
|
||||
class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||
"""
|
||||
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
|
||||
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
|
||||
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
|
||||
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
|
||||
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
|
||||
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||
|
||||
|
||||
class AxolotlLogger(Logger):
|
||||
"""A Logger that automatically rejects non-axolotl INFOs."""
|
||||
"""Logger that applies filtering to non-axolotl loggers."""
|
||||
|
||||
def __init__(self, name: str, level: int = logging.NOTSET):
|
||||
super().__init__(name, level)
|
||||
|
||||
# set global filter on the logger itself
|
||||
self.addFilter(AxolotlOrWarnErrorFilter())
|
||||
if not name.startswith("axolotl"):
|
||||
self.addFilter(AxolotlOrWarnErrorFilter())
|
||||
|
||||
|
||||
class ColorfulFormatter(Formatter):
|
||||
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
|
||||
|
||||
def format(self, record):
|
||||
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
|
||||
log_message = super().format(record)
|
||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||
|
||||
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
},
|
||||
"colorful": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
|
||||
},
|
||||
"concise": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
|
||||
},
|
||||
"concise_color": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
|
||||
},
|
||||
},
|
||||
"filters": {
|
||||
"ax_or_warn": {
|
||||
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
|
||||
},
|
||||
},
|
||||
"filters": {},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "simple",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
"formatter": "concise",
|
||||
"filters": ["ax_or_warn"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"color_console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "colorful",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
"formatter": "concise_color",
|
||||
"filters": ["ax_or_warn"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"ax_file_only": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||
},
|
||||
"root_file_only": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||
},
|
||||
},
|
||||
# log level will be superseded by the AxolotlLogger
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
|
||||
"handlers": ["console", "root_file_only"],
|
||||
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
|
||||
},
|
||||
"loggers": {
|
||||
"axolotl": {
|
||||
"handlers": ["color_console"],
|
||||
"handlers": ["color_console", "ax_file_only"],
|
||||
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
||||
"propagate": False,
|
||||
},
|
||||
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
def configure_logging():
|
||||
"""Configure with default logging"""
|
||||
init() # Initialize colorama
|
||||
|
||||
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||
logging.setLoggerClass(AxolotlLogger)
|
||||
|
||||
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
||||
# Route Python warnings through logging so they reach file handlers
|
||||
logging.captureWarnings(True)
|
||||
|
||||
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
||||
if "ACCELERATE_LOG_LEVEL" not in os.environ:
|
||||
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
|
||||
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
|
||||
"LOG_LEVEL", DEFAULT_LOG_LEVEL
|
||||
).upper()
|
||||
|
||||
@@ -160,9 +160,11 @@ def get_state_dict(self, model, unwrap=True):
|
||||
state_dict[param_name] = param.cpu()
|
||||
torch.distributed.barrier()
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
from torch.distributed.fsdp import FullStateDictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
from torch.distributed.fsdp import (
|
||||
FullStateDictConfig,
|
||||
FullyShardedDataParallel as FSDP,
|
||||
StateDictType,
|
||||
)
|
||||
|
||||
full_state_dict_config = FullStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
@@ -366,6 +368,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
# removing the call above leads to extra memory usage as explained in the comment above
|
||||
if hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
model = model.to(torch.float32)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Flex attention monkey patch"""
|
||||
|
||||
import sys
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from packaging import version
|
||||
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Monkeypatch for Apertus to dtype mismatch in XIELU act"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def patch_apertus_xielu_activation():
|
||||
try:
|
||||
from transformers.activations import XIELUActivation
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Cannot import XIELUActivation. "
|
||||
"Please make sure to update your transformers version >= 4.56.1."
|
||||
) from err
|
||||
|
||||
from transformers.activations import logger
|
||||
|
||||
# Store the original method
|
||||
old_fn = XIELUActivation._xielu_cuda
|
||||
|
||||
def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
|
||||
"""Firewall function to prevent torch.compile from seeing .item() calls"""
|
||||
original_shape = x.shape
|
||||
# CUDA kernel expects 3D tensors, reshape if needed
|
||||
while x.dim() < 3:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 3:
|
||||
x = x.view(-1, 1, x.size(-1))
|
||||
if original_shape != x.shape:
|
||||
logger.warning_once(
|
||||
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
|
||||
original_shape,
|
||||
x.shape,
|
||||
)
|
||||
result = self._xielu_cuda_obj.forward(
|
||||
x,
|
||||
self.alpha_p.to(x.dtype),
|
||||
self.alpha_n.to(x.dtype),
|
||||
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
|
||||
self._beta_scalar,
|
||||
self._eps_scalar,
|
||||
self.with_vector_loads,
|
||||
)
|
||||
return result.view(original_shape)
|
||||
|
||||
# Apply the patch
|
||||
XIELUActivation._xielu_cuda = _xielu_cuda_fixed
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original method"""
|
||||
XIELUActivation._xielu_cuda = old_fn
|
||||
|
||||
return unpatch
|
||||
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user