Compare commits

...

6 Commits

Author SHA1 Message Date
Dan Saunders
3299f182ba ungate lora with bias 2025-09-25 12:40:13 -04:00
Dan Saunders
2fc430d365 update lora optims doc 2025-09-25 12:24:25 -04:00
Dan Saunders
f9748c4dc5 Cp fix (#3182)
* patch transformers to allow CP + FA2

* nits

* only patch in CP > 1 case
2025-09-25 12:03:50 -04:00
miketung
33975ce4bc feat(qwen3-next): Adds targeting of shared expert and attention modules (#3183)
* Adds targetting of shared expert and attention modules in each layer

* Update VRAM usage

---------

Co-authored-by: Mike Tung <mike@diffbot.com>
2025-09-25 17:06:16 +07:00
陈华杰
e8b962d47f feat: support training with JSON string tool arguments (#3136)
* feat: support training with JSON string tool arguments; fix PyArrow data type inconsistent error

* feat: raise error for tool call arguments decode

* Add test_chat_templates_tool_call_string_arguments.py

Add test for string arguments

* fix: change to correct qwen3 tokenizer

* fix: update docs to clarify arguments json

* chore: lint

* fix: duplicate

* chore: revert

* feat: add error to faq

* fix: remove duplicate fixture

---------

Co-authored-by: caoqinping <caoqinping@lixiang.com>
Co-authored-by: gamersover-blog <1611885128@qq.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-09-25 12:06:21 +07:00
NanoCode012
856ff12171 feat(doc): add optimizations table of content to our improvements (#3175) [skip ci]
* chore: format

* feat: add usage for alst

* chore: wording

* feat: add optimizations doc

* Apply suggestion from @SalmanMohammadi

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/dataset-formats/index.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* feat: add alst, act offloading, nd parallelism, use relative links, and fix format

* chore: comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-09-24 16:13:49 -04:00
20 changed files with 617 additions and 72 deletions

View File

@@ -267,6 +267,7 @@ website:
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- docs/optimizations.qmd
- section: "Core Concepts"
contents:

View File

@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
::: {.callout-warning}
If you have tool arguments with the same name but different dtypes (like `"time": string` and `"time": number`), save `arguments` as a JSON string to prevent `datasets` from running into type-casting issues.
```
"arguments": "{\"...\": \"...\"}"
```
:::
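As a minimal sketch (the function names and fields here are illustrative), mixed-type arguments can be pre-serialized before building the dataset so the column stays a plain string:

```python
import json

# Two tool calls whose "time" argument has different types (str vs int).
# Stored as dicts, `datasets`/PyArrow would try to unify their schemas and
# fail; serializing each to a JSON string avoids the cast entirely.
tool_calls = [
    {"function": {"name": "func1", "arguments": {"time": "2025-08-01"}}},
    {"function": {"name": "func2", "arguments": {"time": 1690876800}}},
]
for call in tool_calls:
    call["function"]["arguments"] = json.dumps(call["function"]["arguments"])

print(tool_calls[0]["function"]["arguments"])  # {"time": "2025-08-01"}
```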
Example config for Llama4:
```yaml
chat_template: llama4

View File

@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
### Pre-training without streaming
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.

View File

@@ -140,3 +140,7 @@ description: Frequently asked questions
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`**
> A: This may happen due to edge cases in the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
**Q: `Error parsing tool_calls arguments as JSON.`**
> A: A tool call's string `arguments` could not be parsed into a dict. Please check your dataset and the error message for more details.
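A minimal reproduction of the underlying failure (the argument string here is illustrative): a tool call whose `arguments` string is not valid JSON raises `json.JSONDecodeError` when the strategy tries to decode it back into a dict.

```python
import json

# Invalid JSON: the value `celsius` is not quoted
bad_arguments = '{"location": "Boston, MA", "unit": celsius}'

try:
    json.loads(bad_arguments)
except json.JSONDecodeError as e:
    print(f"Error parsing tool_calls arguments as JSON: {e}")
```

Fixing the offending row in the dataset (so every `arguments` string round-trips through `json.loads`) resolves the error.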

View File

@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
functions. Our goal was to leverage operator fusion and tensor re-use in order to
improve speed and reduce memory usage during the forward and backward passes of these
calculations.
We currently support several common model architectures, including (but not limited to):
@@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
- This may limit model expressivity
- Adapters that already include bias terms are supported.
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
without it in order to be as performant.
## Implementation details
@@ -131,6 +131,5 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Support for dropout
- Additional operator fusions

docs/optimizations.qmd (new file, 133 lines)
View File

@@ -0,0 +1,133 @@
---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
## Speed Optimizations
These optimizations focus on increasing training throughput and reducing total training time.
### Sample Packing
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention implementations](#attention-implementations) below.
- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)
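A minimal config sketch combining packing with one of the attention backends from the next section (other required fields omitted):

```yaml
sample_packing: true
# Packing requires an optimized attention implementation, e.g.:
flash_attention: true
```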
### Attention Implementations
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
*Note: You should only enable one attention backend.*
### LoRA Optimizations
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
## Memory Optimizations
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
### Parameter Efficient Finetuning (LoRA & QLoRA)
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
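As a rough QLoRA sketch (the rank, alpha, and target modules are illustrative; see the config reference for the full set of options):

```yaml
adapter: qlora
load_in_4bit: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
```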
### Gradient Checkpointing & Activation Offloading
These techniques save VRAM by changing how activations are handled.
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
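A sketch of enabling both together (per the FAQ, `offload_activations: legacy` selects the naive implementation if the default CUDA-streams version errors):

```yaml
gradient_checkpointing: true
# Offload checkpointed activations to CPU to trade I/O for VRAM
offload_activations: true
```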
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
### Liger Kernels
Provides efficient Triton kernels to improve training speed and reduce memory usage.
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
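Both CCE and Liger are enabled as plugins; a sketch (see each integration's docs for accompanying options):

```yaml
plugins:
  # Cut Cross Entropy
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  # or Liger kernels
  - axolotl.integrations.liger.LigerPlugin
```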
## Long Context Models
Techniques to train models on sequences longer than their original context window.
### RoPE Scaling
Extends a model's context window by interpolating its Rotary Position Embeddings.
- **Config:** Pass the `rope_scaling` settings under `overrides_of_model_config`. To learn how to set RoPE scaling for a given model, check that model's config.
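For example (the exact `rope_scaling` fields are model-specific; the `linear` type and factor here are illustrative):

```yaml
overrides_of_model_config:
  rope_scaling:
    type: linear
    factor: 2.0
```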
### Sequence Parallelism
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
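A minimal sketch, splitting each sequence across 4 GPUs:

```yaml
context_parallel_size: 4
```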
### Arctic Long Sequence Training (ALST)
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
- TiledMLP to reduce memory usage in MLP layers.
- Tiled loss functions (like [CCE](#cut-cross-entropy-cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
## Large Models (Distributed Training)
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
### N-D Parallelism (Beta)
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
## Quantization
Techniques to reduce the precision of model weights for memory savings.
### 4-bit Training (QLoRA)
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Parameter Efficient Finetuning](#parameter-efficient-finetuning-lora-qlora) for details.
### FP8 Training
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
### Quantization Aware Training (QAT)
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
- **Learn more:** [QAT Documentation](qat.qmd)
### GPTQ
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)

View File

@@ -30,6 +30,7 @@ qat:
```
We support the following quantization schemes:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`

View File

@@ -7,3 +7,24 @@ techniques. It is a combination of:
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
## Usage
```yaml
tiled_mlp: true
# See Sequence Parallelism docs
# https://docs.axolotl.ai/docs/sequence_parallelism.html
context_parallel_size: int
plugins:
# See Cut Cross Entropy docs
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# or Liger Kernel docs
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
- axolotl.integrations.liger.LigerPlugin
# ...
```

View File

@@ -38,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 41.7 GiB VRAM.
This config uses about 45.62 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -27,6 +27,14 @@ lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
- linear_attn.out_proj
- shared_expert.up_proj
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj

View File

@@ -84,6 +84,13 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)

View File

@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
The optimizations require LoRA adapters with no dropout. The function will skip
patching if that condition isn't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
can_patch = lora_config.lora_dropout == 0
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model

View File

@@ -0,0 +1,68 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)

View File

@@ -2,6 +2,7 @@
HF Chat Templates prompt strategy
"""
import json
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
@@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if val is not None:
transformed_message[key] = val
if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
for tool_call in transformed_message["tool_calls"]:
if "function" in tool_call and "arguments" in tool_call["function"]:
args = tool_call["function"]["arguments"]
if isinstance(args, str):
try:
tool_call["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool_calls arguments as JSON. "
f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
f"Arguments string: {args!r}, "
f"Error: {e}"
)
raise
return transformed_message
def _get_images(self, prompt):

View File

@@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation):
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
def test_kernel_patch_requires_zero_dropout():
"""Kernel patching should be skipped when dropout is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
}
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
# Verify no patches applied when dropout is non-zero
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_patch_with_bias_enabled():
"""Kernel patching should succeed when LoRA bias is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
}
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify patches applied when bias support is enabled
assert layer.forward.__func__ is apply_lora_mlp_swiglu
def test_kernel_config_options():

View File

@@ -0,0 +1,66 @@
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
GUARD_PATTERN,
PATCHED_GUARD,
patch_prepare_context_parallel_inputs,
)
@pytest.fixture
def restore_trainer_prepare_method():
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
original_method = getattr(
Trainer,
"_original_prepare_context_parallel_inputs",
Trainer._prepare_context_parallel_inputs,
)
patched_attr_present = hasattr(
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
)
yield
Trainer._prepare_context_parallel_inputs = original_method
if patched_attr_present:
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
Trainer._prepare_context_parallel_inputs = (
Trainer._original_prepare_context_parallel_inputs
)
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
patch_prepare_context_parallel_inputs()
patched_method = Trainer._prepare_context_parallel_inputs
assert patched_method is not None
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source
def test_patch_is_idempotent(restore_trainer_prepare_method):
"""Calling the patch twice should leave the same patched function in place."""
patch_prepare_context_parallel_inputs()
first_patched = Trainer._prepare_context_parallel_inputs
patch_prepare_context_parallel_inputs()
second_patched = Trainer._prepare_context_parallel_inputs
assert first_patched is second_patched

View File

@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
return tokenizer
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not 
none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

View File

@@ -6,7 +6,6 @@ import json
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import StrategyLoader
from axolotl.utils.dict import DictDefault
@@ -23,15 +22,6 @@ def fixture_messages_w_tools():
return Dataset.from_list(rows)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
cfg = DictDefault(

View File

@@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import (
load,
@@ -56,15 +55,6 @@ def messages_w_reasoning_fixture():
)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
class TestSplitThinking:
"""
test class to make sure datasets with reasoning content conforms to the chat_template strategy

View File

@@ -0,0 +1,214 @@
"""
Tests for handling json tool content
"""
import json
import pytest
from datasets import Dataset
from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="qwen3_instruct_prompt_strategy")
def qwen3_instruct_chat_template_strategy(qwen3_tokenizer):
strategy = load(
qwen3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "qwen3",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
"field_messages": "messages",
}
),
)
return strategy
class TestQwen3IdenticalConversationArgs:
"""
Test Qwen3 tools is identical between JSON and dict
"""
@pytest.fixture(name="conversation_dict_args_dataset")
def fixture_conversation_dict_args_dataset(self):
"""
Provides a dataset with conversation where arguments is a dict.
"""
user_content = "What is the weather in Boston?"
function_name = "get_current_weather"
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
data = [
{
"messages": [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": function_name,
"arguments": arguments_dict, # dict格式
}
}
],
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="conversation_str_args_dataset")
def fixture_conversation_str_args_dataset(self):
"""
Provides a dataset with conversation where arguments is a JSON string.
"""
user_content = "What is the weather in Boston?"
function_name = "get_current_weather"
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
arguments_str = json.dumps(arguments_dict)
data = [
{
"messages": [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": function_name,
"arguments": arguments_str, # str格式
}
}
],
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="conversation_mixed_time_types_dataset")
def fixture_conversation_mixed_time_types_dataset(self):
"""
Provides a dataset where 'time' field has different types in different tool calls.
"""
data = [
{
"messages": [
{
"role": "user",
"content": "Get weather information at different times",
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "func1",
"arguments": json.dumps(
{"time": "2025-08-01"}
), # string type
}
},
{
"function": {
"name": "func2",
"arguments": json.dumps(
{"time": 1690876800}
), # number type
}
},
],
},
],
}
]
return Dataset.from_list(data)
def test_dict_and_str_args_produce_identical_output(
self,
conversation_dict_args_dataset,
conversation_str_args_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that after tokenization and decoding, the outputs for both
dict and string `arguments` are exactly the same.
"""
processed_dict_args = conversation_dict_args_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
processed_str_args = conversation_str_args_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
decoded_prompt_from_dict = qwen3_tokenizer.decode(
processed_dict_args[0]["input_ids"]
)
decoded_prompt_from_str = qwen3_tokenizer.decode(
processed_str_args[0]["input_ids"]
)
assert decoded_prompt_from_dict == decoded_prompt_from_str, (
f"Dict format output:\n{decoded_prompt_from_dict}\n"
f"String format output:\n{decoded_prompt_from_str}"
)
assert (
processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"]
), "The tokenized input_ids should be identical for dict and str arguments"
def test_str_args_with_mixed_time_types_no_error(
self,
conversation_mixed_time_types_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that when 'time' field has different types (string vs number)
in different tool calls, str format arguments don't cause errors.
"""
processed = conversation_mixed_time_types_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
assert len(processed) == 1
assert "input_ids" in processed[0]
assert len(processed[0]["input_ids"]) > 0
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
assert "2025-08-01" in decoded, "String time value should be present"
assert "1690876800" in decoded, "Number time value should be present"