Compare commits

...

18 Commits

Author SHA1 Message Date
Dan Saunders
3299f182ba ungate lora with bias 2025-09-25 12:40:13 -04:00
Dan Saunders
2fc430d365 update lora optims doc 2025-09-25 12:24:25 -04:00
Dan Saunders
f9748c4dc5 Cp fix (#3182)
* patch transformers to allow CP + FA2

* nits

* only patch in CP > 1 case
2025-09-25 12:03:50 -04:00
miketung
33975ce4bc feat(qwen3-next): Adds targeting of shared expert and attention modules (#3183)
* Adds targetting of shared expert and attention modules in each layer

* Update VRAM usage

---------

Co-authored-by: Mike Tung <mike@diffbot.com>
2025-09-25 17:06:16 +07:00
陈华杰
e8b962d47f feat: support training with JSON string tool arguments (#3136)
* feat: support training with JSON string tool arguments; fix PyArrow data type inconsistent error

* feat: raise error for tool call arguments decode

* Add test_chat_templates_tool_call_string_arguments.py

Add test for string arguments

* fix: change to correct qwen3 tokenizer

* fix: update docs to clarify arguments json

* chore: lint

* fix: duplicate

* chore: revert

* feat: add error to faq

* fix: remove duplicate fixture

---------

Co-authored-by: caoqinping <caoqinping@lixiang.com>
Co-authored-by: gamersover-blog <1611885128@qq.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-09-25 12:06:21 +07:00
NanoCode012
856ff12171 feat(doc): add optimizations table of content to our improvements (#3175) [skip ci]
* chore: format

* feat: add usage for alst

* chore: wording

* feat: add optimizations doc

* Apply suggestion from @SalmanMohammadi

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/dataset-formats/index.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* feat: add alst, act offloading, nd parallelism, use relative links, and fix format

* chore: comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-09-24 16:13:49 -04:00
Dan Saunders
6bc959342b remove unused dep (#3180) 2025-09-24 13:18:44 -04:00
NanoCode012
b3b92687c4 chore: rename gemma3 270m config (#3174) 2025-09-24 13:48:38 +07:00
NanoCode012
55d1be2ae6 fix: unify default for conversations_field [skip-e2e] (#3070)
* fix: unify default for conversations_field

* fix: suggestion to remove defaults
2025-09-23 21:22:15 +07:00
NanoCode012
08d831c3d5 Feat: add qwen3-next (w packing+cce) (#3150)
* feat: upgrade cce for qwen3-next

* feat: add sample qwen3 config

* feat: add packing patch for chunk_gated_delta_rule

* feat: add qwen3 link

* fix: tuple name

* feat: add tested qwen3 config

* fix: improve log

* feat: add patch for fla without packing

* fix: remove fla patch for standard mode

* feat: enable packing

* feat: add qwen3-next tests

* chore: move tests
2025-09-23 11:31:15 +07:00
AlexHT Hung
7be8740c5c fix(rl): pass max_prompt_len to training args as max_prompt_length (#3113)
* pass max_prompt_len to training args as max_prompt_length

* Update rl.py

* refactor

* format

* fix: default for max_prompt_length

* fix: defaults for trainer

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-09-19 17:34:28 +07:00
NanoCode012
c51d6b06c3 feat: add apertus model and cce (#3144) [skip ci]
* feat: add apertus, glm4v, glm4v_moe cce

* fix: arcee docs

* feat: add apertus

* feat: added vram usage

* fix: add apertus note

* feat: update doc on apertus xielu

* fix: add monkeypatch for xielu activation issue

* fix: simplify env

* feat: pin commit

* feat: add packing

* chore: move patch calling

* Update examples/apertus/README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update examples/apertus/README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update examples/apertus/README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-09-19 17:34:04 +07:00
NanoCode012
09959fac70 Feat: add Magistral Small 2509 and native mistral3 tokenizer support (#3165)
* feat: update mistral common

* feat: add mistral3processor

* fix: loading

* fix: cast pixel_values to fp32

* fix: image tensor conversion

* feat: add FA2 support for pixtral based models

* fix: update mistral small 3.1 to use native tokenizer

* fix: install tips

* fix: improve info on sample dataset files

* chore: move mistral configs into subfolders

* fix: remove unneeded patch

* fix: indent

* feat: add integration tests

* chore: move

* feat: add magistral 2509 docs and example

* fix: convert tensor to bool

* feat: expand tests

* chore: move tests
2025-09-18 15:42:20 +07:00
Dan Saunders
4065bc14c6 Debug log, logging improvements (#3159)
* simplify logging

* remove comment

* progress on debug.log

* add debug-level logger for file log

* simplify

* case insensitivity; 3rd party logging improvements

* simplify

* fix

* tests

* lint

* nits

* nit

* Update tests/test_utils_tee.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cleanup / comments

* fix

* oops

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-09-17 13:27:03 -04:00
salman
e5c427f6de qat doc updates (#3162) [skip-ci] 2025-09-17 10:38:15 +01:00
Wing Lian
86d6ee7c05 upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checpointing properly
2025-09-16 14:53:01 -04:00
Wing Lian
d4cff1b7bb improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]
* improve setting of NCCL_P2P_DISABLE on runpod

* use recs from review
2025-09-16 14:52:45 -04:00
Wing Lian
1ef6c196f7 setup env vars for ray train for FSDP (#3130) [skip ci] 2025-09-16 14:52:29 -04:00
105 changed files with 2728 additions and 245 deletions

3
.gitignore vendored
View File

@@ -190,3 +190,6 @@ out/
# vim
*.swp
# scm auto-versioning
src/axolotl/_version.py

View File

@@ -14,7 +14,7 @@ repos:
rev: v0.12.12
hooks:
- id: ruff
args: [--fix, --select, I]
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1

View File

@@ -267,6 +267,7 @@ website:
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- docs/optimizations.qmd
- section: "Core Concepts"
contents:

View File

@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
::: {.callout-warning}
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
```
"arguments": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:
```yaml
chat_template: llama4

View File

@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
### Pre-training without streaming
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.

View File

@@ -140,3 +140,7 @@ description: Frequently asked questions
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
**Q: `Error parsing tool_calls arguments as JSON.`
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.

View File

@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
functions. Our goal was to leverage operator fusion and tensor re-use in order to
improve speed and reduce memory usage during the forward and backward passes of these
calculations.
We currently support several common model architectures, including (but not limited to):
@@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
- This may limit model expressivity
- Adapters that already include bias terms are supported.
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
without it in order to be as performant.
## Implementation details
@@ -131,6 +131,5 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Support for dropout
- Additional operator fusions

View File

@@ -13,6 +13,7 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
@@ -41,7 +42,6 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
# (optional) if doing lora, only finetune the Language model,
# leave the vision model and vision tower frozen
@@ -94,10 +94,22 @@ chat_template: llava
### Mistral-Small-3.1 {#sec-mistral-small-31}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
chat_template: mistral_v7_tekken
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Magistral-Small-2509
```
### Voxtral {#sec-voxtral}

133
docs/optimizations.qmd Normal file
View File

@@ -0,0 +1,133 @@
---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
## Speed Optimizations
These optimizations focus on increasing training throughput and reducing total training time.
### Sample Packing
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)
### Attention Implementations
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
*Note: You should only enable one attention backend.*
### LoRA Optimizations
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
## Memory Optimizations
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
### Parameter Efficient Finetuning (LoRA & QLoRA)
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
### Gradient Checkpointing & Activation Offloading
These techniques save VRAM by changing how activations are handled.
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
### Liger Kernels
Provides efficient Triton kernels to improve training speed and reduce memory usage.
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
## Long Context Models
Techniques to train models on sequences longer than their original context window.
### RoPE Scaling
Extends a model's context window by interpolating its Rotary Position Embeddings.
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
### Sequence Parallelism
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
### Artic Long Sequence Training (ALST)
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
- TiledMLP to reduce memory usage in MLP layers.
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
## Large Models (Distributed Training)
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
### N-D Parallelism (Beta)
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
## Quantization
Techniques to reduce the precision of model weights for memory savings.
### 4-bit Training (QLoRA)
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
### FP8 Training
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
### Quantization Aware Training (QAT)
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
- **Learn more:** [QAT Documentation](qat.qmd)
### GPTQ
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)

View File

@@ -23,10 +23,18 @@ To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
We support the following quantization schemas:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`
- `Float8DynamicActivationInt4Weight`
- `NVFP4`
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.

View File

@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
@@ -39,9 +39,8 @@ you used to train the model:
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
weight_dtype: int4
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```

View File

@@ -7,3 +7,24 @@ techniques. It is a combination of:
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
## Usage
```yaml
tiled_mlp: true
# See Sequence Parallelism docs
# https://docs.axolotl.ai/docs/sequence_parallelism.html
context_parallel_size: int
plugins:
# See Cut Cross Entropy docs
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# or Liger Kernel docs
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
- axolotl.integrations.liger.LigerPlugin
# ...
```

110
examples/apertus/README.md Normal file
View File

@@ -0,0 +1,110 @@
# Finetune Swiss-AI's Apertus with Axolotl
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. (Optional, highly recommended) Install XIELU CUDA
```bash
## Recommended for reduced VRAM and faster speeds
# Point to CUDA toolkit directory
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
3. Run the finetuning example:
```bash
axolotl train examples/apertus/apertus-8b-qlora.yaml
```
This config uses about 8.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### XIELU Installation Issues
#### `ModuleNotFoundError: No module named 'torch'`
Please check these one by one:
- Running in correct environment
- Env has PyTorch installed
- CUDA toolkit is at `CUDA_HOME`
If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
```bash
git clone https://github.com/nickjbrowning/XIELU
cd xielu
git checkout 59d6031
cd xielu
nano CMakeLists.txt # or vi depending on your preference
```
```diff
execute_process(
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
)
```
```bash
pip3 install . --no-build-isolation --no-deps
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,64 @@
base_model: swiss-ai/Apertus-8B-Instruct-2509
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -19,6 +19,9 @@ cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:

View File

@@ -9,10 +9,6 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
]
},
{
@@ -251,10 +251,10 @@
},
"outputs": [],
"source": [
"from axolotl.utils import patch_optimized_env\n",
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
"\n",
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"patch_optimized_env()"
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"set_pytorch_cuda_alloc_conf()"
]
},
{

View File

@@ -9,10 +9,6 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -9,10 +9,6 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -18,7 +18,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

View File

@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
pip3 install librosa==0.11.0
```
3. Run the finetuning example:
3. Download sample dataset files
```bash
# for text + vision + audio only
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
```
4. Run the finetuning example:
```bash
# text only

View File

@@ -12,15 +12,6 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -46,7 +46,6 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -45,7 +45,6 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -1,10 +1,10 @@
# Finetune Magistral Small with Axolotl
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
## Getting started
@@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
### Thinking
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
Example format:
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
],
}
```
### Vision
Example config: `./magistral-small-think-qlora.yaml`.
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
Limitations:
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
- This mode does not work with custom `train_detail` and `training` at the moment.
### TIPS
### Tips
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
@@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet.
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to Preference Tuning, RL, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,73 @@
# Magistral Small Thinking Fine-tuning
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started
Run the thinking model fine-tuning:
```bash
axolotl train magistral-small-think-qlora.yaml
```
This config uses about 19.1 GiB VRAM.
### Tips
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
## Dataset Format
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
Example format:
```json
{
"messages": [
{
"role": "system",
"content": [
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
]
},
{
"role": "user",
"content": [
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
]
},
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
},
{
"type": "text",
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
}
]
}
]
}
```
### Advanced Options
The `thinking` section supports an optional `closed` parameter:
```json
{
"type": "thinking",
"thinking": "Internal reasoning here...",
"closed": true // Default: true, controls adding the closing [/THINK] tag
}
```

View File

@@ -0,0 +1,60 @@
# Magistral Small Vision Fine-tuning
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
## Getting started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train magistral-small-vision-24B-qlora.yml
```
This config uses about 17GiB VRAM.
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
### Tips
Key differences from text-only model:
- `max_tokens: 131072` for inference
- Multi-modal dataset format required
- Sample packing not supported
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
],
}
```
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -0,0 +1,64 @@
base_model: mistralai/Magistral-Small-2509
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,6 +1,9 @@
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
load_in_8bit: true
# these 3 lines are needed for now to handle vision chat templates w images
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: mistral_v7_tekken
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
- path: Nanobit/text-vision-2k-test
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
@@ -48,8 +51,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
sdp_attention: true
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -12,15 +12,6 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -45,8 +45,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
sdp_attention: true
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -11,7 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -11,7 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -0,0 +1,64 @@
# Finetune Qwen3-Next with Axolotl
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
```bash
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 45.62 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,68 @@
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
- linear_attn.out_proj
- shared_expert.up_proj
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3'
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
3. Download sample dataset files
```bash
# for text + audio only
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
```
4. Run the finetuning example:
```bash
# text only

View File

@@ -32,7 +32,7 @@ line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
select = ["E", "F", "W", "C90", "B", "I"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long

View File

@@ -15,10 +15,10 @@ huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.0
accelerate==1.10.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio
@@ -70,4 +70,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
mistral-common==1.8.3
mistral-common==1.8.5

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
)

View File

@@ -124,7 +124,6 @@ extras_require = {
"ring-flash-attn": [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.5",

View File

@@ -4,5 +4,7 @@ import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -23,7 +23,8 @@ from axolotl.utils.config import (
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
@@ -227,8 +228,11 @@ def load_cfg(
},
)
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
# have to wait for cfg.output to be resolved. We could call this earlier if we write
# to a temporary file, and then move it later.
prepare_debug_log(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
@@ -241,7 +245,6 @@ def load_cfg(
for k, v in cfg.items()
if v is not None
}
LOG.info(
"config:\n%s",
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),

View File

@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
render_html,
run_diffusion,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config

View File

@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -44,7 +44,7 @@ def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
@cli.command()

View File

@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import prepare_optim_env
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -92,6 +92,7 @@ def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
prepare_optim_env(cfg)
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import gradio as gr
import torch
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id

View File

@@ -435,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
elif self.cfg.gradient_checkpointing is not None:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)

View File

@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))

View File

@@ -8,7 +8,7 @@ from typing import Any, Mapping
def chat_message_transform_builder(
train_on_inputs=False,
conversations_field: str = "conversations",
conversations_field: str = "messages",
message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: str | list[str] | None = None, # commonly "content"
message_field_training: str | list[str] | None = None, # commonly "weight"
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
The field name of the conversations. Defaults to "messages".
message_field_role (str | list[str], optional):
The field name of the role. Defaults to "role".
The field name of the role.
message_field_content (str | list[str], optional):
The field name of the message content. Defaults to "content".
The field name of the message content.
message_field_training (str | list[str], optional):
The field name of the train/weight. Defaults to "weight".
The field name of the train/weight.
Returns:
Callable:

View File

@@ -27,7 +27,6 @@ class DPOStrategy:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
```
## Usage
@@ -65,6 +65,7 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_next
- smollm3
- seed_oss
- voxtral

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
)

View File

@@ -68,11 +68,12 @@ class PatchManager:
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -83,6 +84,13 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -168,6 +176,20 @@ class PatchManager:
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
@@ -334,6 +356,13 @@ class PatchManager:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type in ("mistral3", "llava"):
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
apply_patch_is_packed_sequence,
)
apply_patch_is_packed_sequence()
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:
@@ -479,3 +508,12 @@ class PatchManager:
apply_deepspeed_patches()
except ImportError as e:
LOG.warning(f"DeepSpeed patches not applied: {e}")
def _apply_apertus_patches(self):
"""Apply patches for Apertus model."""
if self.cfg.model_config_type == "apertus":
from axolotl.monkeypatch.models.apertus.activation import (
patch_apertus_xielu_activation,
)
patch_apertus_xielu_activation()

View File

@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
tokenizer=tokenizer,
)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,

View File

@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from transformers import tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
# patch
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)

View File

@@ -1,10 +1,7 @@
"""
Common logging module for axolotl
"""
"""Common logging module for axolotl."""
import logging
import os
import sys
from logging import Formatter, Logger, LogRecord
from logging.config import dictConfig
from typing import Any, Dict
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
class AxolotlOrWarnErrorFilter(logging.Filter):
"""
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
"""
def __init__(self, **kwargs):
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
class AxolotlLogger(Logger):
"""A Logger that automatically rejects non-axolotl INFOs."""
"""Logger that applies filtering to non-axolotl loggers."""
def __init__(self, name: str, level: int = logging.NOTSET):
super().__init__(name, level)
# set global filter on the logger itself
self.addFilter(AxolotlOrWarnErrorFilter())
if not name.startswith("axolotl"):
self.addFilter(AxolotlOrWarnErrorFilter())
class ColorfulFormatter(Formatter):
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
},
"concise": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
},
"concise_color": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
},
},
"filters": {
"ax_or_warn": {
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
"formatter": "concise",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
"formatter": "concise_color",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"ax_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
"root_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
},
# log level will be superseded by the AxolotlLogger
"root": {
"handlers": ["console"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
"handlers": ["console", "root_file_only"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"handlers": ["color_console", "ax_file_only"],
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
"propagate": False,
},
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
logging.setLoggerClass(AxolotlLogger)
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
# Route Python warnings through logging so they reach file handlers
logging.captureWarnings(True)
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
if "ACCELERATE_LOG_LEVEL" not in os.environ:
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
"LOG_LEVEL", DEFAULT_LOG_LEVEL
).upper()

View File

@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
The optimizations require LoRA adapters with no dropout. The function will skip
patching if that condition isn't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
can_patch = lora_config.lora_dropout == 0
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model

View File

@@ -0,0 +1,52 @@
"""Monkeypatch for Apertus to dtype mismatch in XIELU act"""
from torch import Tensor
def patch_apertus_xielu_activation():
try:
from transformers.activations import XIELUActivation
except ImportError as err:
raise ImportError(
"Cannot import XIELUActivation. "
"Please make sure to update your transformers version >= 4.56.1."
) from err
from transformers.activations import logger
# Store the original method
old_fn = XIELUActivation._xielu_cuda
def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
"""Firewall function to prevent torch.compile from seeing .item() calls"""
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p.to(x.dtype),
self.alpha_n.to(x.dtype),
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)
# Apply the patch
XIELUActivation._xielu_cuda = _xielu_cuda_fixed
def unpatch():
"""Restore the original method"""
XIELUActivation._xielu_cuda = old_fn
return unpatch

View File

@@ -0,0 +1,85 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""
import importlib
import inspect
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
# Get original source
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
original_tensor_conversion = (
" pixel_values = torch.tensor(images)"
)
patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):
pixel_values = torch.tensor(np.array(images))
else:
pixel_values = torch.tensor(images)"""
# Apply the replacement
if original_tensor_conversion in original_source:
patched_source = original_source.replace(
original_tensor_conversion, patched_tensor_conversion
)
patched_source = patched_source.replace(
"def apply_chat_template(",
"def patched_apply_chat_template(",
1,
)
# Load necessary imports from the module
module_name = MistralCommonTokenizer.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
items_to_import = []
for item in dir(module):
if item in patched_source and not item.startswith("_"):
items_to_import.append(item)
# Execute imports in global scope
if items_to_import:
exec( # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Also need standard imports that might be used
exec("import numpy as np", globals()) # nosec B102
exec("import torch", globals()) # nosec B102
exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102
exec("from pathlib import Path", globals()) # nosec B102
# Import other dependencies that might be needed
try:
exec("from transformers.utils import is_torch_available", globals()) # nosec B102
exec(
"from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType",
globals(),
) # nosec B102
exec("from transformers.utils import logging", globals()) # nosec B102
exec("logger = logging.get_logger(__name__)", globals()) # nosec B102
except ImportError as e:
LOG.warning(f"Could not import some dependencies: {e}")
# Execute the patched source
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonTokenizer patching")

View File

@@ -0,0 +1,42 @@
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
import torch
def apply_patch_is_packed_sequence():
"""Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
from transformers import modeling_flash_attention_utils
def fixed_is_packed_sequence(position_ids, batch_size):
"""
Check the position ids whether packed sequences are indicated or not
1. Position ids exist
2. Flattened sequences only are supported
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
"""
if position_ids is None:
return False
if position_ids.ndim == 1:
position_ids = position_ids.unsqueeze(0) # [N] -> [1, N]
increasing_position_sequences = (
torch.arange(position_ids.shape[1], device=position_ids.device)
+ position_ids.min()
)
return (
batch_size == 1
and (increasing_position_sequences - position_ids).abs().sum().bool().item()
)
# Store original method
old_fn = modeling_flash_attention_utils._is_packed_sequence
# Apply the patch
modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence
def unpatch():
"""Restore the original method"""
modeling_flash_attention_utils._is_packed_sequence = old_fn
return unpatch

View File

@@ -0,0 +1 @@
"""Qwen3_Next model monkeypatches."""

View File

@@ -0,0 +1,317 @@
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_cu_seqlens(position_ids):
"""
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
"""
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
position_ids = position_ids.view(-1)
indices_q = (position_ids == 0).nonzero().view(-1)
cu_seq_lens_q = torch.cat(
(
indices_q.to(**tensor_kwargs),
torch.tensor(position_ids.size(), **tensor_kwargs),
)
)
return cu_seq_lens_q
def patch_qwen3_next_decoder_layer():
"""Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDecoderLayer,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_decoder_forward = Qwen3NextDecoderLayer.forward
def patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Token Mixer
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids,
)
elif self.layer_type == "full_attention":
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# For the MoE layers, we need to unpack
if isinstance(hidden_states, Tuple):
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
# Apply the patches
Qwen3NextDecoderLayer.forward = patched_decoder_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextDecoderLayer.forward = original_decoder_forward
return unpatch
def patch_qwen3_next_gateddelta_layer():
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDynamicCache,
Qwen3NextGatedDeltaNet,
apply_mask_to_padding_states,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward
def patched_gated_delta_net_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Qwen3NextDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = (
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
else:
if cache_params is not None:
conv_state = F.pad(
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
# Update cache
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = core_attn_out.reshape(
core_attn_out.shape[0], core_attn_out.shape[1], -1
)
output = self.out_proj(core_attn_out)
return output
# Apply the patches
Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward
return unpatch
def patch_qwen3_next_imports():
"""Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
try:
import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
except ImportError:
LOG.warning("Qwen3Next model not found, skipping import patch")
return
# Save original values for unpatch
original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
original_chunk_gated_delta_rule = getattr(
qwen3_modeling, "chunk_gated_delta_rule", None
)
original_fused_recurrent_gated_delta_rule = getattr(
qwen3_modeling, "fused_recurrent_gated_delta_rule", None
)
original_is_fast_path_available = getattr(
qwen3_modeling, "is_fast_path_available", False
)
try:
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
fused_recurrent_gated_delta_rule
)
# Force is_fast_path_available to be True
# fla has triton kernels for causal_conv1d
qwen3_modeling.is_fast_path_available = True
except ImportError:
qwen3_modeling.chunk_gated_delta_rule = None
qwen3_modeling.fused_recurrent_gated_delta_rule = None
qwen3_modeling.FusedRMSNormGated = None
def unpatch():
"""Restore the original import values"""
qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
original_fused_recurrent_gated_delta_rule
)
qwen3_modeling.is_fast_path_available = original_is_fast_path_available
return unpatch
def patch_qwen3_next_modeling_packing():
"""Apply all Qwen3Next model patches."""
patch_qwen3_next_imports()
patch_qwen3_next_decoder_layer()
patch_qwen3_next_gateddelta_layer()
LOG.info("Applied Qwen3Next patch for packing")

View File

@@ -11,6 +11,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"apertus",
"mllama_text_model",
"llama",
"llama4",
@@ -20,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen2_moe",
"qwen3",
"qwen3_moe",
"qwen3_next",
"falcon",
"phi",
"phi3",

View File

@@ -0,0 +1,68 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)

View File

@@ -41,7 +41,7 @@ def patch_evaluation_loop():
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
LOG.debug("Trainer.evaluation_loop already patched")
return
# Check if the patterns exist
@@ -84,7 +84,7 @@ def patch_evaluation_loop():
)
exec(evaluation_loop_source, globals())
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = axolotl_evaluation_loop
@@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate():
)
exec(maybe_log_source, globals())
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate

View File

@@ -11,6 +11,7 @@ from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
]
class Mistral3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Mistral3"""
def __init__(
self,
processor: Mistral3Processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
special_ids = (
processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids
)
self.image_token = special_ids.img
self.image_break_token = special_ids.img_break
self.image_end_token = special_ids.img_end
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.processor.tokenizer.pad_token_id] = -100
labels[labels == self.image_token] = -100
labels[labels == self.image_break_token] = -100
labels[labels == self.image_end_token] = -100
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -463,6 +494,11 @@ def get_processing_strategy(
**processing_kwargs,
)
if isinstance(processor, Mistral3Processor):
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(

View File

@@ -2,6 +2,7 @@
HF Chat Templates prompt strategy
"""
import json
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
@@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if val is not None:
transformed_message[key] = val
if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
for tool_call in transformed_message["tool_calls"]:
if "function" in tool_call and "arguments" in tool_call["function"]:
args = tool_call["function"]["arguments"]
if isinstance(args, str):
try:
tool_call["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool_calls arguments as JSON. "
f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
f"Arguments string: {args!r}, "
f"Error: {e}"
)
raise
return transformed_message
def _get_images(self, prompt):

View File

@@ -196,10 +196,11 @@ def execute_training(
)
)
LOG.info("Starting trainer...")
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
plugin_manager = PluginManager.get_instance()

View File

@@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf():
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value

View File

@@ -2,6 +2,8 @@
utils to get GPU info for the current environment
"""
import os
import subprocess # nosec B404
from importlib.metadata import version
from accelerate.utils.environment import (
@@ -14,6 +16,8 @@ from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
@@ -29,6 +33,39 @@ def check_cuda_p2p_ib_support():
return True
def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return True
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)

View File

@@ -2,7 +2,6 @@
import functools
import logging
import os
from axolotl.utils.distributed import is_main_process
@@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter):
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
if log_level is None:
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
logger = logging.getLogger(name)
if log_level is not None:
logger.setLevel(log_level.upper())
logger.root.setLevel(log_level.upper())
logger.setLevel(logging.DEBUG)
return MultiProcessAdapter(logger, extra={})

View File

@@ -1,5 +1,6 @@
"""Init for `axolotl.utils.mistral` module."""
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
__all__ = ["HFMistralTokenizer"]
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]

View File

@@ -0,0 +1,169 @@
"""Processor for Mistral3 multimodal models with image support"""
from typing import Any, Dict, Optional, Union
import torch
from transformers import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
class Mistral3ProcessorKwargs(ProcessingKwargs):
_defaults: Dict[str, Dict[str, Any]] = {
"text_kwargs": {
"padding": True,
},
"common_kwargs": {
"return_tensors": "pt",
"return_dict": True,
"tokenize": True,
},
}
class Mistral3Processor(ProcessorMixin):
"""
Processor for Mistral3 multimodal models that handles text and images.
Wraps HFMistralTokenizer and adds image processing capabilities.
"""
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"
def __init__(self, tokenizer: HFMistralTokenizer):
# Don't call super().__init__ to avoid the class validation issue
self.tokenizer = tokenizer
@property
def chat_template(self) -> None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
@property
def audio_tokenizer(self) -> None:
"""Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API."""
return None
def _merge_kwargs(
self, processor_kwargs_class: Any, **kwargs: Any
) -> Dict[str, Dict[str, Any]]:
"""Merge kwargs with defaults similar to ProcessorMixin"""
defaults = processor_kwargs_class._defaults
output_kwargs: Dict[str, Dict[str, Any]] = {}
for kwarg_type, default_values in defaults.items():
output_kwargs[kwarg_type] = {**default_values}
# Update with provided kwargs
for key, value in kwargs.items():
# Try to match key to appropriate kwarg type
if key in ["padding", "truncation", "max_length"]:
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
elif key in ["return_tensors", "return_dict", "tokenize"]:
output_kwargs.setdefault("common_kwargs", {}).update({key: value})
else:
# Add to text_kwargs by default
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
return output_kwargs
def apply_chat_template(
self,
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
**kwargs: Any,
) -> Union[BatchFeature, str, list[str]]:
"""
Apply chat template with image support for Mistral3.
Similar to VoxtralProcessor, this method extracts images from the conversation,
calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes
to the result.
"""
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
text_kwargs = output_kwargs["text_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", "pt")
if return_tensors != "pt":
raise ValueError(
f"{self.__class__.__name__} only supports `return_tensors='pt'`."
)
return_dict = common_kwargs.pop("return_dict", False)
tokenize = common_kwargs.pop("tokenize", False)
# Determine if batched
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple))
or hasattr(conversation[0], "content")
):
is_batched = True
conversations = conversation
else:
is_batched = False
conversations = [conversation] # type: ignore
# Call tokenizer's apply_chat_template
tokenizer_kwargs = {**text_kwargs, **common_kwargs}
tokenizer_kwargs["return_tensors"] = return_tensors
tokenizer_kwargs["tokenize"] = tokenize
tokenizer_kwargs["return_dict"] = return_dict
encoded_instruct_inputs = self.tokenizer.apply_chat_template(
conversations,
**tokenizer_kwargs,
)
if tokenize:
if return_dict:
# The tokenizer already handles pixel_values, we just need to add image_sizes
if hasattr(encoded_instruct_inputs, "items"):
data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore
elif hasattr(encoded_instruct_inputs, "data"):
data = encoded_instruct_inputs.data # type: ignore
else:
raise ValueError("Unknown data type")
if "pixel_values" in data:
pixel_values = data["pixel_values"]
# MistralTokenizer returns a Double, so we convert to fp32
data["pixel_values"] = pixel_values.to(dtype=torch.float32)
# Always batched: [B, C, H, W] -> image_sizes: [B, 2]
# Since tensor is homogeneous, all images have same H, W
batch_size = pixel_values.shape[0]
image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)
data["image_sizes"] = image_sizes
return BatchFeature(data=data, tensor_type=return_tensors)
if not is_batched:
return encoded_instruct_inputs[0]
return encoded_instruct_inputs
def __call__(
self,
text: Optional[
Union[
TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
]
],
**kwargs: Any,
) -> BatchFeature:
"""
Forward text processing to the tokenizer.
This method does not support images - use apply_chat_template instead.
"""
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
text_kwargs = output_kwargs["text_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
out = self.tokenizer(text, **text_kwargs)
return BatchFeature(
data=out, tensor_type=common_kwargs.pop("return_tensors", None)
)

View File

@@ -436,8 +436,8 @@ class AxolotlInputConfig(
},
)
min_sample_len: int | None = None
max_prompt_len: int = Field(
default=512,
max_prompt_len: int | None = Field(
default=None,
json_schema_extra={"description": "maximum prompt length for RL training"},
)
sample_packing: bool | None = Field(

View File

@@ -1378,6 +1378,21 @@ class ComplexValidationMixin:
return self
def hint_gradient_checkpointing_dpo_lora_ddp(self):
if (
(self.gradient_checkpointing is True or self.gradient_checkpointing is None)
and self.capabilities
and self.capabilities.get("n_gpu", 1) > 1
and self.adapter in ("lora", "qlora")
and self.rl == RLType.DPO
and not self.fsdp
and not self.deepspeed
):
LOG.warning(
"gradient_checkpointing with DPO + DDP + LoRA is not recommended."
)
return self
class DistributedValidationMixin:
"""validation for distributed training."""

166
src/axolotl/utils/tee.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Utilities for managing the debug log file and providing a file-only stream for logging
handlers.
"""
from __future__ import annotations
import io
import os
import sys
import threading
from pathlib import Path
from typing import TextIO, cast
_lock = threading.Lock()
_file_handle: io.TextIOWrapper | None = None
_log_path: str | None = None
_tee_installed: bool = False
_orig_stdout: TextIO | None = None
_orig_stderr: TextIO | None = None
class _FileOnlyWriter(io.TextIOBase):
"""A stream-like object that writes only to the tee file.
Before the file is prepared, writes are dropped (no-op).
"""
def write(self, s: str) -> int: # type: ignore[override]
with _lock:
if _file_handle is not None:
_file_handle.write(s)
return len(s)
return len(s)
def flush(self) -> None: # type: ignore[override]
with _lock:
if _file_handle is not None:
try:
_file_handle.flush()
except Exception:
pass
file_only_stream: io.TextIOBase = _FileOnlyWriter()
class _StreamTee(io.TextIOBase):
"""A minimal tee that mirrors writes to the debug log file.
Installed only after the debug log is prepared; no buffering.
"""
def __init__(self, stream: io.TextIOBase):
self._stream = stream
def write(self, s: str) -> int: # type: ignore[override]
with _lock:
n = self._stream.write(s)
if _file_handle is not None:
_file_handle.write(s)
return n
def flush(self) -> None: # type: ignore[override]
with _lock:
self._stream.flush()
if _file_handle is not None:
try:
_file_handle.flush()
except Exception:
pass
@property
def encoding(self): # type: ignore[override]
return getattr(self._stream, "encoding", None)
@property
def errors(self): # type: ignore[override]
return getattr(self._stream, "errors", None)
def isatty(self): # type: ignore[override]
return getattr(self._stream, "isatty", lambda: False)()
def fileno(self): # type: ignore[override]
if hasattr(self._stream, "fileno"):
return self._stream.fileno()
raise OSError("Underlying stream has no fileno")
def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
"""
Prepare the debug log.
Creates the output directory, handles append/truncate logic based on cfg, and opens
the debug log file for subsequent writes via file-only handlers.
"""
global _file_handle, _log_path, _tee_installed
with _lock:
# If already initialized, reuse existing path
if _log_path is not None:
return _log_path
output_dir = cfg.output_dir
os.makedirs(output_dir, exist_ok=True)
log_path = Path(output_dir) / filename
append = bool(
cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
)
if not append and log_path.exists():
log_path.unlink()
fh = open(log_path, "a", encoding="utf-8")
fh.flush()
_file_handle = fh
_log_path = str(log_path)
# Install a tee so stdout/stderr are mirrored to the debug file
# Allow disabling via env for testing or advanced usage.
tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in {
"0",
"false",
"no",
}
if tee_enabled and not _tee_installed:
# Save originals so we can restore later (e.g., tests)
global _orig_stdout, _orig_stderr
_orig_stdout = sys.stdout
_orig_stderr = sys.stderr
sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))
sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))
_tee_installed = True
return _log_path
def close_debug_log() -> None:
"""Flush and close the debug log and uninstall the stdout/stderr tee.
Safe to call even if not initialized.
"""
global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr
with _lock:
# Restore original stdout/stderr if we installed a tee
if _tee_installed:
if _orig_stdout is not None:
sys.stdout = _orig_stdout
if _orig_stderr is not None:
sys.stderr = _orig_stderr
_tee_installed = False
_orig_stdout = None
_orig_stderr = None
# Close the file handle if open
if _file_handle is not None:
try:
_file_handle.flush()
_file_handle.close()
except Exception:
pass
finally:
_file_handle = None
_log_path = None

View File

@@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
if checkpoints:
last_checkpoint = str(checkpoints[-1])
if not update:
LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
return last_checkpoint
if (
@@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
):
cfg.resume_from_checkpoint = last_checkpoint
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
"Using auto-resume functionality to resume from checkpoint at "
f"{cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint

View File

@@ -655,15 +655,6 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.sample_packing:
# multipack parallel packing sampler defaults to using fork
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(
cfg,
train_dataset,

View File

@@ -199,7 +199,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"gradient_checkpointing": False,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
@@ -278,7 +278,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"gradient_checkpointing": False,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,

View File

@@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation):
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
def test_kernel_patch_requires_zero_dropout():
"""Kernel patching should be skipped when dropout is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
}
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
# Verify no patches applied when dropout is non-zero
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_patch_with_bias_enabled():
"""Kernel patching should succeed when LoRA bias is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
}
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify patches applied when bias support is enabled
assert layer.forward.__func__ is apply_lora_mlp_swiglu
def test_kernel_config_options():

View File

@@ -0,0 +1,35 @@
"""Integration tests for MistralCommonTokenizer patches."""
import pytest
class TestMistralTokenizerPatchIntegration:
"""Test MistralCommonTokenizer patch integration."""
@pytest.mark.integration
def test_mistral_tokenizer_image_patch(self):
"""Test that MistralCommonTokenizer image patch can be applied."""
try:
from transformers.tokenization_mistral_common import MistralCommonTokenizer
except ImportError:
pytest.skip("MistralCommonTokenizer not available")
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
# Store original method
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
# Apply patch
apply_mistral_tokenizer_image_patch()
# Verify patch was applied
assert (
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
), "apply_chat_template was not patched"
# Verify the method is still callable
assert callable(MistralCommonTokenizer.apply_chat_template), (
"Patched method is not callable"
)

View File

@@ -0,0 +1,77 @@
"""Integration tests for Pixtral Flash Attention patches."""
import pytest
import torch
class TestPixtralFlashAttentionPatchIntegration:
"""Test Pixtral Flash Attention patch integration."""
@pytest.mark.integration
def test_pixtral_flash_attention_patch(self):
"""Test that Pixtral Flash Attention patch can be applied and works correctly."""
try:
from transformers import modeling_flash_attention_utils
except ImportError:
pytest.skip("Flash Attention utils not available")
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
apply_patch_is_packed_sequence,
)
# Store original method
original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence
# Apply patch and get unpatch function
unpatch_fn = apply_patch_is_packed_sequence()
# Verify patch was applied
assert (
modeling_flash_attention_utils._is_packed_sequence
!= original_is_packed_sequence
), "_is_packed_sequence was not patched"
# Test the patched function with 1D position_ids
patched_fn = modeling_flash_attention_utils._is_packed_sequence
# Test 1D position_ids 1 sequence
position_ids_1d = torch.tensor([0, 1, 2, 3])
result = patched_fn(position_ids_1d, batch_size=1)
assert isinstance(result, bool), "Function should return a boolean"
assert result is False, "1D sequential position_ids should not be packed"
# Test 1D packed 2 sequences
position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2])
result = patched_fn(position_ids_1d_packed, batch_size=1)
assert isinstance(result, bool), "Function should return a boolean"
assert result is True, "1D packed position_ids should be detected as packed"
# Test 2D packed 2 sequences
position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]])
result = patched_fn(position_ids_2d_packed, batch_size=1)
assert isinstance(result, bool), "Function should return a boolean"
assert result is True, "2D packed position_ids should be detected as packed"
# Test 2D 1 sequence
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]])
result = patched_fn(position_ids_2d_normal, batch_size=1)
assert isinstance(result, bool), "Function should return a boolean"
assert result is False, "2D sequential position_ids should not be packed"
# Test 2D batch size 2
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
result = patched_fn(position_ids_2d_normal, batch_size=2)
assert isinstance(result, bool), "Function should return a boolean"
assert result is False, "2D position_ids batch 2 should not be packed"
# Test None case
result = patched_fn(None, batch_size=1)
assert isinstance(result, bool), "Function should return a boolean"
assert result is False, "None position_ids should return False"
# Test unpatch function
unpatch_fn()
assert (
modeling_flash_attention_utils._is_packed_sequence
== original_is_packed_sequence
), "unpatch function did not restore original method"

View File

@@ -0,0 +1,111 @@
"""Integration tests for Qwen3 Next modeling patches."""
import pytest
import torch
# Skip entire module if qwen3_next not available
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")
class TestQwen3NextModelingPatchIntegration:
"""Test Qwen3 Next modeling patch integration."""
@pytest.mark.integration
def test_qwen3_next_decoder_layer_patch(self):
"""Test that Qwen3Next decoder layer patch can be applied."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_decoder_layer,
)
# Store original method
original_forward = qwen3_next.Qwen3NextDecoderLayer.forward
# Apply patch and get unpatch function
unpatch_fn = patch_qwen3_next_decoder_layer()
# Verify patch was applied
assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, (
"decoder layer forward method was not patched"
)
# Verify the method is still callable
assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
"Patched method is not callable"
)
# Test unpatch function
if unpatch_fn:
unpatch_fn()
assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, (
"unpatch function did not restore original method"
)
@pytest.mark.integration
def test_qwen3_next_gateddelta_layer_patch(self):
"""Test that Qwen3Next GatedDeltaNet patch can be applied."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_gateddelta_layer,
)
# Store original method
original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward
# Apply patch and get unpatch function
unpatch_fn = patch_qwen3_next_gateddelta_layer()
# Verify patch was applied
assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, (
"GatedDeltaNet forward method was not patched"
)
# Verify the method is still callable
assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
"Patched method is not callable"
)
# Test unpatch function
if unpatch_fn:
unpatch_fn()
assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, (
"unpatch function did not restore original method"
)
@pytest.mark.integration
def test_qwen3_next_imports_patch(self):
"""Test that Qwen3Next imports patch can be applied without errors."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_imports,
)
# Apply patch - should not raise any exceptions even if modules unavailable
unpatch_fn = patch_qwen3_next_imports()
# Test that unpatch function is returned (or None if skipped)
assert unpatch_fn is None or callable(unpatch_fn), (
"patch_qwen3_next_imports should return None or callable unpatch function"
)
@pytest.mark.integration
def test_qwen3_next_modeling_packing_patch(self):
"""Test that all Qwen3Next modeling patches can be applied together."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
# This should not raise any exceptions
patch_qwen3_next_modeling_packing()
@pytest.mark.integration
def test_get_cu_seqlens_utility():
"""Test the get_cu_seqlens utility function."""
from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens
# Test with simple position_ids
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
cu_seqlens = get_cu_seqlens(position_ids)
assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"
# Should return tensor with start positions and total length
expected = torch.tensor([0, 3, 5], dtype=torch.int32)
assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"

View File

@@ -0,0 +1,66 @@
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
GUARD_PATTERN,
PATCHED_GUARD,
patch_prepare_context_parallel_inputs,
)
@pytest.fixture
def restore_trainer_prepare_method():
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
original_method = getattr(
Trainer,
"_original_prepare_context_parallel_inputs",
Trainer._prepare_context_parallel_inputs,
)
patched_attr_present = hasattr(
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
)
yield
Trainer._prepare_context_parallel_inputs = original_method
if patched_attr_present:
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
Trainer._prepare_context_parallel_inputs = (
Trainer._original_prepare_context_parallel_inputs
)
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
patch_prepare_context_parallel_inputs()
patched_method = Trainer._prepare_context_parallel_inputs
assert patched_method is not None
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source
def test_patch_is_idempotent(restore_trainer_prepare_method):
"""Calling the patch twice should leave the same patched function in place."""
patch_prepare_context_parallel_inputs()
first_patched = Trainer._prepare_context_parallel_inputs
patch_prepare_context_parallel_inputs()
second_patched = Trainer._prepare_context_parallel_inputs
assert first_patched is second_patched

View File

@@ -0,0 +1,43 @@
"""Integration tests for Voxtral modeling patches."""
import pytest
class TestVoxtralModelingPatchIntegration:
"""Test Voxtral modeling patch integration."""
@pytest.mark.integration
def test_voxtral_conditional_generation_patch(self):
"""Test that Voxtral conditional generation patch can be applied."""
try:
from transformers.models.voxtral.modeling_voxtral import (
VoxtralForConditionalGeneration,
)
except ImportError:
pytest.skip("VoxtralForConditionalGeneration not available")
from axolotl.monkeypatch.models.voxtral.modeling import (
patch_voxtral_conditional_generation_forward,
)
# Store original method
original_forward = VoxtralForConditionalGeneration.forward
# Apply patch and get unpatch function
unpatch_fn = patch_voxtral_conditional_generation_forward()
# Verify patch was applied
assert VoxtralForConditionalGeneration.forward != original_forward, (
"forward method was not patched"
)
# Verify the method is still callable
assert callable(VoxtralForConditionalGeneration.forward), (
"Patched method is not callable"
)
# Test unpatch function
unpatch_fn()
assert VoxtralForConditionalGeneration.forward == original_forward, (
"unpatch function did not restore original method"
)

View File

@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
return tokenizer
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

Some files were not shown because too many files have changed in this diff Show More