Compare commits
6 Commits
cp-sdpa
...
lora-fsdp2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3299f182ba | ||
|
|
2fc430d365 | ||
|
|
f9748c4dc5 | ||
|
|
33975ce4bc | ||
|
|
e8b962d47f | ||
|
|
856ff12171 |
@@ -267,6 +267,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
|
||||
@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
|
||||
@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
|
||||
@@ -140,3 +140,7 @@ description: Frequently asked questions
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
|
||||
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
|
||||
functions. Our goal was to leverage operator fusion and tensor re-use in order to
|
||||
improve speed and reduce memory usage during the forward and backward passes of these
|
||||
calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
@@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
|
||||
- Targeted LoRA adapters cannot use Dropout
|
||||
- This may limit model expressivity / cause overfitting
|
||||
- Targeted LoRA adapters cannot have bias terms
|
||||
- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
|
||||
- This may limit model expressivity
|
||||
- Adapters that already include bias terms are supported.
|
||||
|
||||
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
|
||||
be re-finetuned without these features in order to be useful.
|
||||
Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
|
||||
without it in order to be as performant.
|
||||
|
||||
## Implementation details
|
||||
|
||||
@@ -131,6 +131,5 @@ computation path.
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Support for dropout
|
||||
- Additional operator fusions
|
||||
|
||||
133
docs/optimizations.qmd
Normal file
133
docs/optimizations.qmd
Normal file
@@ -0,0 +1,133 @@
|
||||
---
|
||||
title: Optimizations Guide
|
||||
description: A guide to the performance and memory optimizations available in Axolotl.
|
||||
---
|
||||
|
||||
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
|
||||
|
||||
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
|
||||
|
||||
## Speed Optimizations
|
||||
|
||||
These optimizations focus on increasing training throughput and reducing total training time.
|
||||
|
||||
### Sample Packing
|
||||
|
||||
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
|
||||
|
||||
- **Config:** `sample_packing: true`
|
||||
- **Learn more:** [Sample Packing](multipack.qmd)
|
||||
|
||||
### Attention Implementations
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
|
||||
|
||||
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
|
||||
|
||||
## Memory Optimizations
|
||||
|
||||
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
|
||||
|
||||
### Parameter Efficient Finetuning (LoRA & QLoRA)
|
||||
|
||||
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
|
||||
|
||||
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
|
||||
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
|
||||
|
||||
### Gradient Checkpointing & Activation Offloading
|
||||
|
||||
These techniques save VRAM by changing how activations are handled.
|
||||
|
||||
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
|
||||
|
||||
### Liger Kernels
|
||||
|
||||
Provides efficient Triton kernels to improve training speed and reduce memory usage.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
|
||||
### RoPE Scaling
|
||||
|
||||
Extends a model's context window by interpolating its Rotary Position Embeddings.
|
||||
|
||||
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
|
||||
|
||||
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
|
||||
|
||||
### Artic Long Sequence Training (ALST)
|
||||
|
||||
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
|
||||
|
||||
- TiledMLP to reduce memory usage in MLP layers.
|
||||
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
|
||||
- Activation Offloading to CPU.
|
||||
|
||||
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
|
||||
|
||||
## Large Models (Distributed Training)
|
||||
|
||||
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
|
||||
|
||||
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
|
||||
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
|
||||
|
||||
### N-D Parallelism (Beta)
|
||||
|
||||
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
|
||||
|
||||
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Techniques to reduce the precision of model weights for memory savings.
|
||||
|
||||
### 4-bit Training (QLoRA)
|
||||
|
||||
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
|
||||
|
||||
### FP8 Training
|
||||
|
||||
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
|
||||
|
||||
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
|
||||
|
||||
### Quantization Aware Training (QAT)
|
||||
|
||||
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
|
||||
|
||||
- **Learn more:** [QAT Documentation](qat.qmd)
|
||||
|
||||
### GPTQ
|
||||
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
@@ -30,6 +30,7 @@ qat:
|
||||
```
|
||||
|
||||
We support the following quantization schemas:
|
||||
|
||||
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
|
||||
- `Int8DynamicActivationInt4Weight`
|
||||
- `Float8DynamicActivationFloat8Weight`
|
||||
|
||||
@@ -7,3 +7,24 @@ techniques. It is a combination of:
|
||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
||||
|
||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
tiled_mlp: true
|
||||
|
||||
# See Sequence Parallelism docs
|
||||
# https://docs.axolotl.ai/docs/sequence_parallelism.html
|
||||
context_parallel_size: int
|
||||
|
||||
plugins:
|
||||
# See Cut Cross Entropy docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# or Liger Kernel docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
# ...
|
||||
|
||||
```
|
||||
|
||||
@@ -38,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 41.7 GiB VRAM.
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
|
||||
@@ -27,6 +27,14 @@ lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
- linear_attn.out_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
|
||||
@@ -84,6 +84,13 @@ class PatchManager:
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
|
||||
@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
|
||||
AssertionError: If multiple adapters are active (currently unsupported).
|
||||
|
||||
Note:
|
||||
The optimizations require LoRA adapters with no dropout and no bias terms. The
|
||||
function will skip patching if these conditions aren't met.
|
||||
The optimizations require LoRA adapters with no dropout. The function will skip
|
||||
patching if that condition isn't met.
|
||||
"""
|
||||
if not isinstance(model, PeftModelForCausalLM):
|
||||
raise TypeError("Model must be a PeftModelForCausalLM")
|
||||
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
|
||||
lora_config = model.model.peft_config[active_adapter]
|
||||
|
||||
# Only patch if conditions are met
|
||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||
can_patch = lora_config.lora_dropout == 0
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||
LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
|
||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||
return model
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||
PATCHED_GUARD = (
|
||||
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_context_parallel_inputs() -> None:
|
||||
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
|
||||
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
|
||||
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
|
||||
return
|
||||
|
||||
try:
|
||||
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
|
||||
except OSError as exc: # pragma: no cover - occurs when source is unavailable
|
||||
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
|
||||
return
|
||||
|
||||
if GUARD_PATTERN not in original_source:
|
||||
LOG.warning(
|
||||
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
|
||||
"skipping FlashAttention context parallelism patch"
|
||||
)
|
||||
return
|
||||
|
||||
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
|
||||
patched_source, _ = detab_code(patched_source)
|
||||
patched_source = patched_source.replace(
|
||||
"def _prepare_context_parallel_inputs(",
|
||||
"def axolotl_prepare_context_parallel_inputs(",
|
||||
1,
|
||||
)
|
||||
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# import symbols referenced in the method so exec can succeed
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
|
||||
exec(patched_source, globals())
|
||||
|
||||
Trainer._original_prepare_context_parallel_inputs = (
|
||||
Trainer._prepare_context_parallel_inputs
|
||||
)
|
||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
||||
LOG.debug(
|
||||
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
||||
)
|
||||
@@ -2,6 +2,7 @@
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
||||
|
||||
@@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if val is not None:
|
||||
transformed_message[key] = val
|
||||
|
||||
if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
|
||||
for tool_call in transformed_message["tool_calls"]:
|
||||
if "function" in tool_call and "arguments" in tool_call["function"]:
|
||||
args = tool_call["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
tool_call["function"]["arguments"] = json.loads(args)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(
|
||||
f"Error parsing tool_calls arguments as JSON. "
|
||||
f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
|
||||
f"Arguments string: {args!r}, "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _get_images(self, prompt):
|
||||
|
||||
@@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation):
|
||||
assert layer.mlp.forward.__func__ is expected_activation
|
||||
|
||||
|
||||
def test_kernel_patch_conditions():
|
||||
"""Test various conditions that should prevent kernel patching."""
|
||||
test_configs = [
|
||||
# Dropout prevents patching
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
"r": 8,
|
||||
"lora_alpha": 16,
|
||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
},
|
||||
# Bias prevents patching
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
"r": 8,
|
||||
"lora_alpha": 16,
|
||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||
"lora_dropout": 0,
|
||||
"bias": "lora_only",
|
||||
},
|
||||
]
|
||||
def test_kernel_patch_requires_zero_dropout():
|
||||
"""Kernel patching should be skipped when dropout is enabled."""
|
||||
config = {
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
"r": 8,
|
||||
"lora_alpha": 16,
|
||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
}
|
||||
|
||||
for config in test_configs:
|
||||
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
peft_config = get_peft_config(config)
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
peft_config = get_peft_config(config)
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
|
||||
# Should not patch
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
layer = patched_model.model.model.layers[0].mlp
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
layer = patched_model.model.model.layers[0].mlp
|
||||
|
||||
# Verify no patches applied
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||
# Verify no patches applied when dropout is non-zero
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||
|
||||
|
||||
def test_kernel_patch_with_bias_enabled():
|
||||
"""Kernel patching should succeed when LoRA bias is enabled."""
|
||||
config = {
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
"r": 8,
|
||||
"lora_alpha": 16,
|
||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||
"lora_dropout": 0,
|
||||
"bias": "lora_only",
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||
peft_config = get_peft_config(config)
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
layer = patched_model.model.model.layers[0].mlp
|
||||
|
||||
# Verify patches applied when bias support is enabled
|
||||
assert layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||
|
||||
|
||||
def test_kernel_config_options():
|
||||
|
||||
66
tests/monkeypatch/test_trainer_context_parallel_patch.py
Normal file
66
tests/monkeypatch/test_trainer_context_parallel_patch.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for the HF Trainer context parallel patch."""
|
||||
|
||||
import pytest
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
GUARD_PATTERN,
|
||||
PATCHED_GUARD,
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restore_trainer_prepare_method():
|
||||
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
|
||||
original_method = getattr(
|
||||
Trainer,
|
||||
"_original_prepare_context_parallel_inputs",
|
||||
Trainer._prepare_context_parallel_inputs,
|
||||
)
|
||||
patched_attr_present = hasattr(
|
||||
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
Trainer._prepare_context_parallel_inputs = original_method
|
||||
if patched_attr_present:
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
|
||||
|
||||
|
||||
def test_patch_attention_guard(restore_trainer_prepare_method):
|
||||
"""Patch should swap the guard to allow sdpa or flash attention."""
|
||||
# Ensure we start from the unpatched method
|
||||
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
|
||||
Trainer._prepare_context_parallel_inputs = (
|
||||
Trainer._original_prepare_context_parallel_inputs
|
||||
)
|
||||
delattr(Trainer, "_original_prepare_context_parallel_inputs")
|
||||
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
|
||||
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
patched_method = Trainer._prepare_context_parallel_inputs
|
||||
assert patched_method is not None
|
||||
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
|
||||
|
||||
source = Trainer._axolotl_prepare_context_parallel_inputs_source
|
||||
assert GUARD_PATTERN not in source
|
||||
assert PATCHED_GUARD in source
|
||||
|
||||
|
||||
def test_patch_is_idempotent(restore_trainer_prepare_method):
|
||||
"""Calling the patch twice should leave the same patched function in place."""
|
||||
patch_prepare_context_parallel_inputs()
|
||||
first_patched = Trainer._prepare_context_parallel_inputs
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
second_patched = Trainer._prepare_context_parallel_inputs
|
||||
|
||||
assert first_patched is second_patched
|
||||
@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
||||
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
||||
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
||||
|
||||
@@ -6,7 +6,6 @@ import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -23,15 +22,6 @@ def fixture_messages_w_tools():
|
||||
return Dataset.from_list(rows)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_prompt_strategy")
|
||||
def qwen3_chat_template_strategy(qwen3_tokenizer):
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
load,
|
||||
@@ -56,15 +55,6 @@ def messages_w_reasoning_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestSplitThinking:
|
||||
"""
|
||||
test class to make sure datasets with reasoning content conforms to the chat_template strategy
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Tests for handling json tool content
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
load,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_instruct_prompt_strategy")
|
||||
def qwen3_instruct_chat_template_strategy(qwen3_tokenizer):
|
||||
strategy = load(
|
||||
qwen3_tokenizer,
|
||||
DictDefault(
|
||||
{
|
||||
"train_on_inputs": False,
|
||||
"sequence_len": 512,
|
||||
}
|
||||
),
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "qwen3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
"field_messages": "messages",
|
||||
}
|
||||
),
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class TestQwen3IdenticalConversationArgs:
|
||||
"""
|
||||
Test Qwen3 tools is identical between JSON and dict
|
||||
"""
|
||||
|
||||
@pytest.fixture(name="conversation_dict_args_dataset")
|
||||
def fixture_conversation_dict_args_dataset(self):
|
||||
"""
|
||||
Provides a dataset with conversation where arguments is a dict.
|
||||
"""
|
||||
user_content = "What is the weather in Boston?"
|
||||
function_name = "get_current_weather"
|
||||
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
|
||||
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments_dict, # dict格式
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
@pytest.fixture(name="conversation_str_args_dataset")
|
||||
def fixture_conversation_str_args_dataset(self):
|
||||
"""
|
||||
Provides a dataset with conversation where arguments is a JSON string.
|
||||
"""
|
||||
user_content = "What is the weather in Boston?"
|
||||
function_name = "get_current_weather"
|
||||
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
|
||||
arguments_str = json.dumps(arguments_dict)
|
||||
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments_str, # str格式
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
@pytest.fixture(name="conversation_mixed_time_types_dataset")
|
||||
def fixture_conversation_mixed_time_types_dataset(self):
|
||||
"""
|
||||
Provides a dataset where 'time' field has different types in different tool calls.
|
||||
"""
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Get weather information at different times",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "func1",
|
||||
"arguments": json.dumps(
|
||||
{"time": "2025-08-01"}
|
||||
), # string type
|
||||
}
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "func2",
|
||||
"arguments": json.dumps(
|
||||
{"time": 1690876800}
|
||||
), # number type
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
def test_dict_and_str_args_produce_identical_output(
|
||||
self,
|
||||
conversation_dict_args_dataset,
|
||||
conversation_str_args_dataset,
|
||||
qwen3_instruct_prompt_strategy,
|
||||
qwen3_tokenizer,
|
||||
):
|
||||
"""
|
||||
Tests that after tokenization and decoding, the outputs for both
|
||||
dict and string `arguments` are exactly the same.
|
||||
"""
|
||||
processed_dict_args = conversation_dict_args_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
processed_str_args = conversation_str_args_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
decoded_prompt_from_dict = qwen3_tokenizer.decode(
|
||||
processed_dict_args[0]["input_ids"]
|
||||
)
|
||||
|
||||
decoded_prompt_from_str = qwen3_tokenizer.decode(
|
||||
processed_str_args[0]["input_ids"]
|
||||
)
|
||||
|
||||
assert decoded_prompt_from_dict == decoded_prompt_from_str, (
|
||||
f"Dict format output:\n{decoded_prompt_from_dict}\n"
|
||||
f"String format output:\n{decoded_prompt_from_str}"
|
||||
)
|
||||
|
||||
assert (
|
||||
processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"]
|
||||
), "The tokenized input_ids should be identical for dict and str arguments"
|
||||
|
||||
def test_str_args_with_mixed_time_types_no_error(
|
||||
self,
|
||||
conversation_mixed_time_types_dataset,
|
||||
qwen3_instruct_prompt_strategy,
|
||||
qwen3_tokenizer,
|
||||
):
|
||||
"""
|
||||
Tests that when 'time' field has different types (string vs number)
|
||||
in different tool calls, str format arguments don't cause errors.
|
||||
"""
|
||||
processed = conversation_mixed_time_types_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
assert len(processed) == 1
|
||||
assert "input_ids" in processed[0]
|
||||
assert len(processed[0]["input_ids"]) > 0
|
||||
|
||||
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
|
||||
assert "2025-08-01" in decoded, "String time value should be present"
|
||||
assert "1690876800" in decoded, "Number time value should be present"
|
||||
Reference in New Issue
Block a user