Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
b5198d8734 granite chat multipack support and example 2025-08-02 20:57:00 -04:00
Wing Lian
4ab6a1bd7e add support for granite chat templates 2025-08-02 11:29:03 -04:00
89 changed files with 1951 additions and 2598 deletions

View File

@@ -54,7 +54,7 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
@@ -64,16 +64,9 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: nightly
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-nightly"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
@@ -129,13 +122,6 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -143,13 +129,6 @@ jobs:
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -24,13 +24,12 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -27,7 +27,7 @@ repos:
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.17.0
hooks:
- id: mypy
additional_dependencies:

View File

@@ -185,6 +185,7 @@ datasets:
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |

View File

@@ -296,6 +296,7 @@
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -540,6 +541,7 @@ xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}

View File

@@ -25,28 +25,17 @@
## 🎉 Latest Updates
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
</details>
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.

View File

@@ -274,7 +274,6 @@ website:
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- section: "Advanced Features"
contents:

View File

@@ -212,11 +212,10 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
Example config for Llama4:
```yaml
chat_template: llama4
datasets:
- path: Nanobit/text-tools-2k-test
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```

View File

@@ -1,6 +1,4 @@
---
title: "N-D Parallelism (Beta)"
---
# N-D Parallelism
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
@@ -73,10 +71,6 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size
## Examples
::: {.callout-tip}
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
:::
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
@@ -101,7 +95,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`

View File

@@ -1,129 +0,0 @@
---
title: Optimizers
description: Configuring optimizers
---
## Overview
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
Here is a list of optimizers supported by transformers as of `v4.54.0`:
- `adamw_torch`
- `adamw_torch_fused`
- `adamw_torch_xla`
- `adamw_torch_npu_fused`
- `adamw_apex_fused`
- `adafactor`
- `adamw_anyprecision`
- `adamw_torch_4bit`
- `adamw_torch_8bit`
- `ademamix`
- `sgd`
- `adagrad`
- `adamw_bnb_8bit`
- `adamw_8bit` # alias for adamw_bnb_8bit
- `ademamix_8bit`
- `lion_8bit`
- `lion_32bit`
- `paged_adamw_32bit`
- `paged_adamw_8bit`
- `paged_ademamix_32bit`
- `paged_ademamix_8bit`
- `paged_lion_32bit`
- `paged_lion_8bit`
- `rmsprop`
- `rmsprop_bnb`
- `rmsprop_bnb_8bit`
- `rmsprop_bnb_32bit`
- `galore_adamw`
- `galore_adamw_8bit`
- `galore_adafactor`
- `galore_adamw_layerwise`
- `galore_adamw_8bit_layerwise`
- `galore_adafactor_layerwise`
- `lomo`
- `adalomo`
- `grokadamw`
- `schedule_free_radam`
- `schedule_free_adamw`
- `schedule_free_sgd`
- `apollo_adamw`
- `apollo_adamw_layerwise`
- `stable_adamw`
## Custom Optimizers
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
### optimi_adamw
```yaml
optimizer: optimi_adamw
```
### ao_adamw_4bit
Deprecated: Please use `adamw_torch_4bit`.
### ao_adamw_8bit
Deprecated: Please use `adamw_torch_8bit`.
### ao_adamw_fp8
```yaml
optimizer: ao_adamw_fp8
```
### adopt_adamw
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
```yaml
optimizer: adopt_adamw
```
### came_pytorch
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
```yaml
optimizer: came_pytorch
# optional args (defaults below)
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1e-30
adam_epsilon2: 1e-16
```
### muon
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
```yaml
optimizer: muon
```
### dion
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
Note: Implementation written for PyTorch 2.7+ for DTensor
```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```

View File

@@ -1,53 +0,0 @@
# Finetune ArceeAI's AFM with Axolotl
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/arcee/afm-4.5b-qlora.yaml
```
This config uses about 7.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: arcee-ai/AFM-4.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -47,6 +47,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
]
},
{

View File

@@ -10,14 +10,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:

View File

@@ -1,52 +0,0 @@
# ND Parallelism Examples
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
## Quick Start
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the command below:
```bash
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
```
## Example Configurations
### Single Node (8 GPUs)
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
- Uses all 3 parallelism dimensions on a single node
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
```yaml
dp_shard_size: 2 # FSDP across 2 GPUs
tensor_parallel_size: 2 # TP across 2 GPUs
context_parallel_size: 2 # CP across 2 GPUs
# Total: 2 × 2 × 2 = 8 GPUs
```
### Multi-Node
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
```yaml
dp_shard_size: 4 # FSDP within each 4-GPU group
tensor_parallel_size: 2 # TP within each node
dp_replicate_size: 2 # DDP across 2 groups
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
```
## Learn More
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,47 +0,0 @@
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 4
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2
dataset_prepared_path: last_run_prepared
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

View File

@@ -1,46 +0,0 @@
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2
dataset_prepared_path: last_run_prepared
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:

View File

@@ -4,14 +4,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:

View File

@@ -1,74 +0,0 @@
# Finetune OpenAI's GPT-OSS with Axolotl
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
### Tool use
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
Here is an example dataset config:
```yaml
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
```
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,67 +0,0 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
use_kernels: false
dp_shard_size: 16 # requires 2x8xH100 nodes
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
cpu_ram_efficient_loading: true

View File

@@ -1,58 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -1,68 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`

View File

@@ -1,64 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

View File

@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -45,6 +45,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -49,6 +49,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -8,14 +8,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:

View File

@@ -27,6 +27,7 @@ sequence_len: 2048
sample_packing: true
eval_sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,66 +0,0 @@
# SLURM Multi-Node Training
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
## Prerequisites
- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
## Usage
### Standard SLURM Clusters
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:
```bash
export HF_TOKEN="your-huggingface-token"
# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```
4. Submit the job:
```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```
Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
5. (Optional) Run other slurm commands:
```bash
# check job info
scontrol show job axolotl-cli
# check job queue
squeue
# check cluster status
sinfo
```
### RunPod Instant Clusters
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
1. **Deploy a SLURM Cluster**:
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
- Click "Create a Cluster"
- Choose your GPU type, node count, and region
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
- Deploy the cluster
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
## Additional Resources
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)

View File

@@ -1,20 +0,0 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#
# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128
export TORCH_DIST_INIT_BARRIER=0
srun axolotl preprocess train.yaml
srun axolotl train train.yaml --launcher torchrun -- \
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"

View File

@@ -6,14 +6,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Please install the below.

View File

@@ -1,9 +1,8 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.1
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
bitsandbytes==0.46.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -13,21 +12,19 @@ liger-kernel==0.6.1
packaging==23.2
huggingface_hub>=0.33.0
peft==0.17.0
transformers==4.55.0
peft==0.16.0
transformers==4.54.1
tokenizers>=0.21.1
accelerate==1.10.0
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.20.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
gradio==5.23.3
modal==1.0.2
pydantic==2.10.6
@@ -69,6 +66,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
axolotl-contribs-mit==0.0.3
mistral-common==1.8.3

View File

@@ -44,13 +44,8 @@ add_keys_to_authorized() {
chmod 700 -R ~/.ssh
}
# Set SSH port
if [ ! -z "$SSH_PORT" ]; then
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
fi
if [[ $PUBLIC_KEY ]]; then
# runpod, prime intellect
# runpod
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
@@ -81,13 +76,5 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
if [[ -f "$SLURM_INIT" ]]; then
echo "[entrypoint] running $SLURM_INIT..."
bash "$SLURM_INIT"
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
)

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.12.0"
__version__ = "0.12.0.dev"

View File

@@ -2,10 +2,12 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@@ -40,17 +42,13 @@ def do_vllm_serve(
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = 1
data_parallel_size = 1
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
@@ -83,3 +81,63 @@ def do_vllm_serve(
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -13,5 +13,4 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssDecoderLayer",
}

View File

@@ -24,10 +24,12 @@ from pathlib import Path
from typing import Any
import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
@@ -38,7 +40,6 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
@@ -266,24 +267,27 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
optimizer_kwargs["device_mesh"] = device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
@@ -429,12 +433,30 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
**self.cfg.accelerator_config
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:
@@ -494,20 +516,10 @@ class TrainerBuilderBase(abc.ABC):
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

View File

@@ -43,7 +43,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -137,18 +136,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
return trainer_cls
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return AxolotlTrainer
def build(self, total_num_steps):
@@ -363,7 +350,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
elif self.cfg.pad_to_sequence_len is None:
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple

View File

@@ -15,7 +15,6 @@ from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -73,16 +72,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):

View File

@@ -10,11 +10,8 @@ from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
import datasets
import safetensors
import torch
from accelerate.state import AcceleratorState
from datasets import Dataset
from peft import PeftModel
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -22,10 +19,8 @@ from torch.utils.data import (
Sampler,
SequentialSampler,
)
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -520,18 +515,7 @@ class AxolotlTrainer(
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
res = super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
@@ -540,6 +524,8 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
@@ -581,10 +567,10 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
logs["memory/max_memory_active"] = active
logs["memory/max_memory_allocated"] = allocated
logs["memory/device_memory_reserved"] = reserved
except (ValueError, FileNotFoundError):
pass
del self._stored_metrics[train_eval]
@@ -604,64 +590,3 @@ class AxolotlTrainer(
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
else (PreTrainedModel, PeftModel)
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
):
self.accelerator.unwrap_model(
self.model, keep_torch_compile=False
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -2,7 +2,6 @@
Mixin for correctly saving fsdp
"""
from accelerate import PartialState
from transformers import Trainer
@@ -19,15 +18,3 @@ class DistributedParallelMixin(Trainer):
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)
def create_accelerator_and_postprocess(self):
super().create_accelerator_and_postprocess()
if (
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
PartialState().distributed_type = "MULTI_GPU"

View File

@@ -243,18 +243,3 @@ class AxolotlTrainingMixins:
)
# end of multi-modal section
dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)

View File

@@ -26,11 +26,9 @@ import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -643,24 +641,3 @@ class BaseOptimizerFactory:
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass
# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.
This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', or variation of 'norm')
"""
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_parameters = get_parameter_names(
model, [nn.LayerNorm], forbidden_name_patterns
)
return decay_parameters

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
```
## Usage
@@ -31,7 +31,6 @@ plugins:
## Supported Models
- arcee
- cohere
- cohere2
- gemma
@@ -42,17 +41,13 @@ plugins:
- gemma3n_text
- glm
- glm4
- gpt_oss
- granite
- granitemoe
- hunyuan_v1_dense
- hunyuan_v1_moe
- llama
- llama4
- llama4_text
- mistral
- mistral3
- mixtral
- mllama
- phi
- phi3

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
)

View File

@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
return sample
def _tokenize_single_prompt(self, prompt):
target_token_ids = prompt.get("target_token_ids", None)
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
if target_token_ids is not None:
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt

View File

@@ -14,7 +14,6 @@ from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
@@ -26,7 +25,6 @@ def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
@@ -39,54 +37,39 @@ def get_lora_parameters(
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weights, quantization state, LoRA A and B weights,
scaling factor, and base layer bias. Quant state, weights, and bias may be
`None` if not available.
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
b = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
linear_A = proj.lora_A[active_adapter]
linear_B = proj.lora_B[active_adapter]
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
# We fuse linear layers + LoRA adapters calculations into a single
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
# Note that we don't apply resharding later in this module (it gets messy quickly),
# but LoRA parameters are generally small enough that this is not an issue.
if isinstance(linear_A.weight, DTensor):
linear_A.unshard()
linear_B.unshard()
A = linear_A.weight
B = linear_B.weight
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
return W, b, quant_state, A, B, s
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor | None,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -107,22 +90,20 @@ def matmul_lora(
dtype = X.dtype
W = dequantize(W.t(), W_quant)
reshape = False
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
out += s * X @ A @ B
if b is not None:
out += b
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
@@ -136,20 +117,17 @@ class LoRA_MLP(torch.autograd.Function):
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_bias: torch.Tensor | None,
gate_quant: QuantState | None,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_bias: torch.Tensor | None,
up_quant: QuantState | None,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_bias: torch.Tensor | None,
down_quant: QuantState | None,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
@@ -164,22 +142,20 @@ class LoRA_MLP(torch.autograd.Function):
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_bias: Gate projection bias
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up projection weight
up_quant: Up projection quantization state
up_A: Up projection LoRA A matrix
up_B: Up projection LoRA B matrix
up_scale: Up projection LoRA scale
down_weight: Down projection weight
down_bias: Down projection bias
down_quant: Down projection quantization state
down_A: Down projection LoRA A matrix
down_B: Down projection LoRA B matrix
down_scale: Down projection LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
@@ -188,17 +164,15 @@ class LoRA_MLP(torch.autograd.Function):
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
)
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
@@ -221,26 +195,22 @@ class LoRA_MLP(torch.autograd.Function):
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
@@ -252,7 +222,7 @@ class LoRA_MLP(torch.autograd.Function):
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/biases/quantization states
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
@@ -295,10 +265,9 @@ class LoRA_MLP(torch.autograd.Function):
dtype = X.dtype
# Down projection
grad_down = matmul_lora(
DW = matmul_lora(
grad_output,
down_weight.t(),
None,
down_quant,
down_B,
down_A,
@@ -306,7 +275,7 @@ class LoRA_MLP(torch.autograd.Function):
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -346,8 +315,8 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None and gate_B is not None:
@@ -365,26 +334,22 @@ class LoRA_MLP(torch.autograd.Function):
dX,
None,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
None,
)
@@ -399,26 +364,23 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -442,25 +404,22 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -487,19 +446,16 @@ class LoRA_QKV(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_bias: torch.Tensor | None,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
@@ -513,19 +469,16 @@ class LoRA_QKV(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_bias: Query projection bias
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_bias: Key projection bias
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_bias: Value projection bias
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
@@ -535,21 +488,20 @@ class LoRA_QKV(torch.autograd.Function):
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.biases = (q_bias, k_bias, v_bias)
ctx.inplace = inplace
return Q, K, V
@staticmethod
@torch_amp_custom_bwd
@torch_amp_custom_fwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
@@ -559,19 +511,16 @@ class LoRA_QKV(torch.autograd.Function):
torch.Tensor,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
@@ -659,31 +608,31 @@ class LoRA_QKV(torch.autograd.Function):
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
d_B_v = d_B_v.t() # type: ignore[union-attr]
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
None,
d_A_v,
d_B_v,
None,
@@ -704,25 +653,22 @@ def apply_lora_qkv(
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
Qb,
QW_quant,
QA,
QB,
QS,
KW,
Kb,
KW_quant,
KA,
KB,
KS,
VW,
Vb,
VW_quant,
VA,
VB,
@@ -742,11 +688,10 @@ class LoRA_O(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor,
B: torch.Tensor,
s: float,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
@@ -755,20 +700,19 @@ class LoRA_O(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
W: Output projection weight
b: Output projection bias
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
s: LoRA scaling factor
S: LoRA scaling factor
Returns:
Output projection result
Output projection tensor
"""
XW = matmul_lora(X, W, b, W_quant, A, B, s)
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
s,
S,
)
ctx.save_for_backward(A, B, X)
@@ -783,9 +727,8 @@ class LoRA_O(torch.autograd.Function):
torch.Tensor,
None,
None,
None,
torch.Tensor,
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
@@ -798,7 +741,7 @@ class LoRA_O(torch.autograd.Function):
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, s = ctx.custom_saved_tensors
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
@@ -808,19 +751,17 @@ class LoRA_O(torch.autograd.Function):
# Weight projection
dY_X = X.t() @ dY
d_A = s * dY_X @ B
d_B = s * A @ dY_X
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -833,7 +774,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
Returns:
Transformed output tensor
"""
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output

View File

@@ -76,7 +76,6 @@ def load_lora(
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
@@ -107,7 +106,6 @@ def load_lora(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,

View File

@@ -1,5 +1,5 @@
"""
Model loader class implementation for loading, configuring, and patching various models.
"""Model loader class implementation for loading, configuring, and patching various
models.
"""
import gc
@@ -13,7 +13,7 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
PeftConfig,
@@ -22,7 +22,6 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
@@ -50,11 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
build_parallelism_config,
get_device_count,
get_device_type,
)
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -92,7 +87,6 @@ class ModelLoader:
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
device_mesh: DeviceMesh | None = None
def __init__(
self,
@@ -208,8 +202,6 @@ class ModelLoader:
self._set_device_map_config()
if self.cfg.revision_of_model:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
@@ -308,10 +300,7 @@ class ModelLoader:
)
# Handle DeepSpeed Zero3
if (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
if is_deepspeed_zero3_enabled():
self._set_z3_leaf_modules()
# Apply gradient checkpointing if needed
@@ -416,12 +405,85 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
if parallelism_config:
self.parallelism_config = parallelism_config
self.device_mesh = device_mesh
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
get_world_size(),
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
self.cfg.dp_shard_size,
self.cfg.dp_replicate_size,
bool(self.cfg.fsdp or self.cfg.fsdp_config),
)
if pc_kwargs:
self.parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = self.parallelism_config.build_device_mesh("cuda")
partial_state = PartialState()
# fmt: off
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
self.parallelism_config
)
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
device_mesh
)
# fmt: on
def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
@@ -503,17 +565,8 @@ class ModelLoader:
def _set_quantization_config(self):
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
if self.cfg.model_quantization_config == "Mxfp4Config":
from transformers import Mxfp4Config
mxfp4_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -548,9 +601,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -576,9 +627,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
@@ -599,9 +648,7 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
@@ -674,7 +721,7 @@ class ModelLoader:
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = self.device_mesh
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
@@ -690,18 +737,6 @@ class ModelLoader:
elif self.is_qlora_and_fsdp_enabled:
skip_move_to_device = True
if (
self.cfg.tensor_parallel_size <= 1
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and self.cfg.fsdp_version == 2
):
# setting device_map for TP is not supported
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if local_rank == 0:
self.model_kwargs["device_map"] = "cpu"
else:
self.model_kwargs["device_map"] = "meta"
if (
self.is_qlora_and_fsdp_enabled
and self.cfg.fsdp_config.cpu_ram_efficient_loading
@@ -810,9 +845,6 @@ class ModelLoader:
self.model._tp_size = self.cfg.tensor_parallel_size
self.model._device_mesh = self.model_kwargs["device_mesh"]
if self.cfg.experimental_skip_move_to_device is not None:
skip_move_to_device = self.cfg.experimental_skip_move_to_device
return skip_move_to_device
def _set_z3_leaf_modules(self):

View File

@@ -65,7 +65,6 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -76,20 +75,8 @@ class PatchManager:
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
)
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_prepare_from_posids()
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -116,14 +103,6 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
from axolotl.monkeypatch.accelerate.parallelism_config import (
patch_parallelism_config,
)
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
@@ -281,23 +260,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (
@@ -368,21 +330,31 @@ class PatchManager:
patch_self_attn_lora()
def _patch_llama_flash_attention(self):
def _patch_llama_flash_attention(self, packed=False):
"""Apply Flash Attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
if self.cfg.s2_attention:
if packed:
if self.cfg.device not in ["mps", "cpu"] and not self.inference:
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=True,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
@@ -413,7 +385,7 @@ class PatchManager:
and self.cfg.sample_packing
):
if self.cfg.flash_attention:
self._patch_llama_flash_attention()
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
@@ -436,12 +408,17 @@ class PatchManager:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
if self.cfg.flash_attn_fuse_qkv:
LOG.info("Patching with fused QKV...")
replace_llama_qkv_with_fused(model)
def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:

View File

@@ -7,7 +7,6 @@ import functools
import sys
import torch
import torch.distributed as dist
from torch import nn
from axolotl.utils.bench import log_gpu_memory_usage
@@ -37,49 +36,25 @@ def fsdp2_load_full_state_dict(
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, sharded_meta_param in meta_sharded_sd.items():
full_tensor = None
if _accelerator.is_main_process:
full_tensor = full_sd[param_name]
full_tensor = full_tensor.to(sharded_meta_param.dtype)
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
if hasattr(sharded_meta_param, "device_mesh"):
device_mesh = sharded_meta_param.device_mesh
if _accelerator.is_main_process:
full_tensor = full_tensor.to(device_mesh.device_type)
else:
full_tensor = torch.empty(
sharded_meta_param.size(),
device=device_mesh.device_type,
dtype=sharded_meta_param.dtype,
)
sharded_param = distribute_tensor(
full_tensor,
device_mesh,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
src_data_rank=0,
)
else:
# Non-sharded parameters
if _accelerator.is_main_process:
sharded_param = full_tensor.to(torch.device("cuda"))
else:
# broadcast manually
sharded_param = torch.empty_like(
sharded_meta_param,
device=torch.device("cuda"),
dtype=sharded_meta_param.dtype,
)
dist.broadcast(sharded_param, src=0)
sharded_param = full_tensor
if offload_to_cpu:
sharded_param = sharded_param.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_param)
del full_tensor
full_sd[param_name] = None
model.load_state_dict(sharded_sd, assign=True, strict=True)
end_time = time.time()
LOG.debug(

View File

@@ -1,77 +0,0 @@
"""
workaround to allow parallelism config for pure CP
"""
# pylint: disable=protected-access
import os
import warnings
from accelerate import DistributedType
def _validate_accelerator(self, accelerator):
_warnings = set()
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
# We need this to ensure DDP works
if self.total_size == 1:
self._set_size("dp_replicate", accelerator.num_processes)
if self.total_size != accelerator.num_processes:
raise ValueError(
f"ParallelismConfig total_size ({self.total_size}) does not match "
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
f"dp_shard_size/tp_size/cp_size."
)
# allow parallelism config when not using fsdp if using pure context parallelism
allow_parallelism_config = False
if (
self.cp_size > 1 # pylint: disable=chained-comparison
and self.dp_shard_size <= 1
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
):
allow_parallelism_config = True
if (
self.total_size > 1
and not allow_parallelism_config
and not (accelerator.is_fsdp2 or accelerator.multi_device)
):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
)
for parallelism, size in self._sizes.items():
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
_warnings.add(
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
)
if _warnings and accelerator.is_main_process:
warnings.warn(
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
UserWarning,
)
def patched_is_fsdp2(self) -> bool:
"""
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
"""
# The new logic checks if fsdp_plugin exists before accessing its attributes
return (
self.distributed_type == DistributedType.FSDP
and self.fsdp_plugin
and self.fsdp_plugin.fsdp_version == 2
)
def patch_parallelism_config():
from accelerate.accelerator import AcceleratorState, ParallelismConfig
ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)

View File

@@ -1,205 +0,0 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
"""
import importlib
import inspect
import torch
from torch.nn import Parameter
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patched_torch_function(cls, func, types, args=(), kwargs=None):
"""
Patched version of Params4bit.__torch_function__ for preserving Params4bit
class identity and attributes.
"""
if kwargs is None:
kwargs = {}
if func in [torch.chunk, torch.split]:
tensor = args[0]
result = Parameter.__torch_function__(func, types, args, kwargs)
if isinstance(result, tuple):
return tuple(
cls(
data=chunk,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
for chunk in result
)
return cls(
data=result,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
return Parameter.__torch_function__(func, types, args, kwargs)
# pylint: disable=protected-access
def apply_bnb_torch_function_patch():
"""
Patch Params4bit.__torch_function__ using Axolotl-style approach.
Returns:
True if patching succeeded, False otherwise.
"""
from bitsandbytes.nn.modules import Params4bit
Params4bit.__torch_function__ = classmethod(patched_torch_function)
LOG.info("Successfully patched Params4bit.__torch_function__")
# pylint: disable=protected-access
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam._init_sharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
patched_param_creation = """ import bitsandbytes as bnb
if isinstance(param, bnb.nn.modules.Params4bit):
self.sharded_param = bnb.nn.modules.Params4bit(
data=sharded_param,
requires_grad=param.requires_grad,
quant_state=param.quant_state,
blocksize=param.blocksize,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
quant_storage=param.quant_storage,
module=param.module,
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def _init_sharded_param(",
"def patched_init_sharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
patched_param_creation = """ import bitsandbytes as bnb
local_tensor = self.sharded_param._local_tensor
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
self._unsharded_param = bnb.nn.modules.Params4bit(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
quant_state=local_tensor.quant_state,
blocksize=local_tensor.blocksize,
compress_statistics=local_tensor.compress_statistics,
quant_type=local_tensor.quant_type,
quant_storage=local_tensor.quant_storage,
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def init_unsharded_param(",
"def patched_init_unsharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")

View File

@@ -3,26 +3,39 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
)
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import set_module_name
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.utils.logging import get_logger
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
@@ -69,6 +82,19 @@ def replace_llama_mlp_with_swiglu(model):
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def patch_fa_llama_cross_entropy():
LOG.info(
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
@@ -116,6 +142,7 @@ def patch_llama_rms_norm():
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
use_shifted_sparse_attn: Optional[bool] = False,
@@ -127,6 +154,16 @@ def replace_llama_attn_with_flash_attn(
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward_with_s2attn
)
else:
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
)
# skip only if explicitly disabled
if cross_entropy:
@@ -137,6 +174,49 @@ def replace_llama_attn_with_flash_attn(
patch_llama_rms_norm()
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -275,3 +355,576 @@ def flashattn_forward_with_s2attn(
.reshape(bsz, q_len, nheads, self.head_dim)
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(
*inputs,
)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -156,11 +156,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Llama4TextAttention
if model_type == "mistral3":
from transformers.models.mistral.modeling_mistral import MistralAttention
return MistralAttention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -395,6 +390,7 @@ def apply_lora_kernel_patches(
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -404,8 +400,7 @@ def apply_lora_kernel_patches(
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
@@ -414,6 +409,7 @@ def apply_lora_kernel_patches(
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -422,14 +418,14 @@ def apply_lora_kernel_patches(
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
@@ -439,8 +435,7 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)

View File

@@ -3,14 +3,53 @@
# pylint: disable=duplicate-code
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,
):
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
)
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
mistral_model_forward
)
def patch_mistral_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -18,3 +57,604 @@ def patch_mistral_cross_entropy():
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None or sliding_window is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask
def flashattn_forward(
self: OriginalMistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
use_sliding_windows = (
getattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
window_size=window_size,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = (
self._gradient_checkpointing_func( # pylint: disable=protected-access
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralDecoderLayer(OriginalMistralDecoderLayer):
"""
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -36,8 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"gpt_oss",
"arcee",
"granite",
"granitemoe",
]

View File

@@ -0,0 +1,78 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
training_loop = inspect.getsource(Trainer.evaluation_loop)
return training_loop
def check_evaluation_loop_is_patchable() -> bool:
eval_loop = get_evaluation_loop_code()
eval_loop, _ = detab_code(eval_loop)
return ORIGINAL_TRAINER_CODE in eval_loop
def patch_evaluation_loop_for_fsdp2():
"""
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
"""
try:
evaluation_loop = get_evaluation_loop_code()
except OSError:
return
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
evaluation_loop
)
evaluation_loop, _ = detab_code(evaluation_loop)
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
return
evaluation_loop = evaluation_loop.replace(
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
)
evaluation_loop = evaluation_loop.replace(
"def evaluation_loop(",
"def _fixed_evaluation_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in evaluation_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer.evaluation_loop = ( # pylint: disable=protected-access
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -1,165 +0,0 @@
"""
Module for patching transformers Trainer loss calculation to use nanmean.
This is needed for context parallelism since chunks of the input sequences may be fully
masked and return NaNs in the loss calculation.
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
the other evaluation_loop patch because we can't patch the same code twice without
raising an OSError.
"""
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
}
PATCHED_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_FSDP2_CODE = """
model.eval()
"""
PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
def check_evaluation_loop_is_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
# pylint: disable=protected-access
def patch_evaluation_loop(patch_fsdp2: bool):
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
return
# Check if the patterns exist
try:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
except OSError:
return
Trainer.evaluation = evaluation_loop_source
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
# Apply the nanmean patches
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
)
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
)
# Apply FSDP2 eval guard patch if needed
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
)
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
# Rename the function to avoid conflicts
evaluation_loop_source = evaluation_loop_source.replace(
"def evaluation_loop(",
"def axolotl_evaluation_loop(",
1,
)
# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in evaluation_loop_source:
items_to_import.append(item)
# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = (
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)
def check_maybe_log_save_evaluate_is_patchable() -> bool:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
return ORIGINAL_MAYBE_CODE in maybe_log_source
# pylint: disable=protected-access
def patch_maybe_log_save_evaluate():
"""Patch the _maybe_log_save_evaluate method."""
# Check if already patched
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
LOG.info("Trainer._maybe_log_save_evaluate already patched")
return
# Check if the patterns exist
try:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
except OSError:
return
Trainer._original_maybe_log_save_evaluate = maybe_log_source
maybe_log_source, _ = detab_code(maybe_log_source)
# Apply the patch
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
# Rename the function to avoid conflicts
maybe_log_source = maybe_log_source.replace(
"def _maybe_log_save_evaluate(",
"def axolotl_maybe_log_save_evaluate(",
1,
)
# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in maybe_log_source:
items_to_import.append(item)
# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821

View File

@@ -41,9 +41,7 @@ class ChatTemplatePrompter(Prompter):
field_messages: str = "messages",
field_system: str = "system",
field_tools: str = "tools",
field_thinking: str = "reasoning_content",
roles: dict[str, list[str]] | None = None,
template_thinking_key: str | None = "reasoning_content",
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False,
):
@@ -52,9 +50,8 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -77,12 +74,10 @@ class ChatTemplatePrompter(Prompter):
self.field_messages = field_messages
self.field_system = field_system
self.field_tools = field_tools
self.field_thinking = field_thinking
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
self.chat_template_kwargs = chat_template_kwargs or {}
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -747,9 +742,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message[self.prompter.template_thinking_key] = (
thinking_content.strip()
)
transformed_message["reasoning_content"] = thinking_content.strip()
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
@@ -960,10 +953,6 @@ class StrategyLoader:
None,
),
"field_messages": dataset_config.get("field_messages", "messages"),
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
"template_thinking_key": dataset_config.get(
"template_thinking_key", "reasoning_content"
),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.

View File

@@ -218,7 +218,6 @@ def execute_training(
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
gather_outputs=cfg.rl is RLType.GRPO,
device_mesh=trainer.accelerator.torch_device_mesh,
)
)
@@ -275,7 +274,7 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if trainer.is_fsdp_enabled:
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -567,10 +566,6 @@ def train(
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# clear cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
create_model_card(cfg, trainer)

View File

@@ -0,0 +1,62 @@
{# Alias tools -> available_tools #}
{%- if tools and not available_tools -%}
{%- set available_tools = tools -%}
{%- endif -%}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if available_tools and documents %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif available_tools %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif thinking %}
{%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if available_tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
{{- available_tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{%- for document in documents %}
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
' }}
{{- document['text'] }}
{{- '<|end_of_text|>
' }}
{%- endfor %}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}

View File

@@ -0,0 +1,64 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if tools and documents %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif tools %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if tools %}
{{- '<|start_of_role|>tools<|end_of_role|>' }}
{{- tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{{- '<|start_of_role|>documents<|end_of_role|>' }}
{%- for document in documents %}
{{- 'Document ' + loop.index0 | string + '
' }}
{{- document['text'] }}
{%- if not loop.last %}
{{- '
'}}
{%- endif%}
{%- endfor %}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}

View File

@@ -161,8 +161,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
Collator for multipack specific to the using the BatchSampler
"""
squash_position_ids: bool = False
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features: List[List[dict]] = [features]
@@ -178,15 +176,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
elif feature == "position_ids" and self.squash_position_ids:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
# concatenate, get total length and create arange of new total position ids
position_ids = np.concatenate(arrays)
total_length = position_ids.shape[0]
position_ids = np.arange(total_length)
out_features[i][feature] = position_ids
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item

View File

@@ -5,8 +5,8 @@ import inspect
import torch
import torch.distributed as dist
from accelerate import PartialState
from torch import nn
from torch.distributed import DeviceMesh
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
@@ -194,7 +194,6 @@ class SequenceParallelContextManager:
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
gather_outputs: bool,
device_mesh: DeviceMesh | None = None,
):
self.models = models
self.context_parallel_size = context_parallel_size
@@ -202,7 +201,6 @@ class SequenceParallelContextManager:
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
self.gather_outputs = gather_outputs
self.device_mesh = device_mesh
self._register_ring_attn()
@@ -242,8 +240,9 @@ class SequenceParallelContextManager:
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
partial_state = PartialState()
register_ring_attn_from_device_mesh(
device_mesh=self.device_mesh,
device_mesh=partial_state.device_mesh,
context_parallel_dim=("cp",),
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,

View File

@@ -1,7 +1,6 @@
"""Data handling specific to SFT."""
import functools
import os
import tempfile
from typing import Literal
@@ -105,9 +104,6 @@ def _prepare_standard_dataset(
finally:
loader.cleanup()
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)

View File

@@ -8,7 +8,6 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from accelerate.utils import ParallelismConfig
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
@@ -51,10 +50,7 @@ def init_distributed_state():
global distributed_state # pylint: disable=global-statement
if distributed_state is None:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
try:
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
except ValueError:
pass
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
def get_distributed_state() -> PartialState | None:
@@ -294,77 +290,3 @@ def reduce_and_broadcast(fn1, fn2):
# Use compute_and_broadcast to compute the reduced value on the main process
# and then broadcast it to all ranks
return compute_and_broadcast(lambda: fn2(gathered_values))
def build_parallelism_config(cfg):
pc_kwargs = _get_parallel_config_kwargs(
get_world_size(),
cfg.tensor_parallel_size,
cfg.context_parallel_size,
cfg.dp_shard_size,
cfg.dp_replicate_size,
bool(cfg.fsdp or cfg.fsdp_config),
)
if pc_kwargs:
parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = parallelism_config.build_device_mesh("cuda")
return parallelism_config, device_mesh
return None, None
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs

View File

@@ -1,28 +0,0 @@
"""
Helper for importing modules from strings
"""
import importlib
def get_cls_from_module_str(module_str: str):
# use importlib to dynamically load the reward function from the module
if not isinstance(module_str, str) or not module_str.strip():
raise ValueError("module_str must be a non-empty string")
parts = module_str.split(".")
if len(parts) < 2:
raise ValueError(f"Invalid module string format: {module_str}")
try:
cls_name = parts[-1]
module_path = ".".join(parts[:-1])
mod = importlib.import_module(module_path)
mod_cls = getattr(mod, cls_name)
return mod_cls
except ImportError as e:
raise ImportError(f"Failed to import module '{module_path}': {e}") from e
except AttributeError as e:
raise AttributeError(
f"Class '{cls_name}' not found in module '{module_path}': {e}"
) from e

View File

@@ -4,7 +4,6 @@ import math
from functools import partial
from typing import Sequence
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -46,10 +45,8 @@ class RexLR(LRScheduler):
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
initial_lr = group["lr"]
if isinstance(initial_lr, Tensor):
initial_lr = initial_lr.clone()
group.setdefault("initial_lr", initial_lr)
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)

View File

@@ -110,13 +110,6 @@ class AxolotlInputConfig(
},
)
trainer_cls: str | None = Field(
default=None,
json_schema_extra={
"description": "module to custom trainer class to use for training"
},
)
rl: RLType | None = Field(
default=None,
json_schema_extra={
@@ -538,6 +531,12 @@ class AxolotlInputConfig(
"description": "Whether to use flash-attention rms norm implementation - advanced use only"
},
)
flash_attn_fuse_qkv: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to fuse QKV into a single operation"
},
)
flash_attn_fuse_mlp: bool | None = Field(
default=None,
json_schema_extra={
@@ -551,13 +550,6 @@ class AxolotlInputConfig(
eager_attention: bool | None = None
attn_implementation: str | None = Field(
default=None,
json_schema_extra={
"description": "Specify a custom attention implementation, used mostly for kernels."
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None

View File

@@ -118,18 +118,6 @@ class SFTDataset(BaseModel):
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
},
)
field_thinking: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
},
)
template_thinking_key: str | None = Field(
default=None,
json_schema_extra={
"description": "The key the chat template expects that indicates the reasoning trace."
},
)
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings

View File

@@ -67,6 +67,8 @@ class ChatTemplate(str, Enum):
command_a_tool_use = "command_a_tool_use"
command_a_rag = "command_a_rag"
aya = "aya"
granite = "granite"
granitemoe = "granitemoe"
class CustomSupportedOptimizers(str, Enum):
@@ -79,7 +81,6 @@ class CustomSupportedOptimizers(str, Enum):
adopt_adamw = "adopt_adamw"
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
class RingAttnFunc(str, Enum):

View File

@@ -1,7 +1,5 @@
"""Pydantic models for model input / output, etc. configuration"""
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.logging import get_logger
@@ -64,28 +62,6 @@ class ModelInputConfig(BaseModel):
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
experimental_skip_move_to_device: bool | None = Field(
default=None,
json_schema_extra={
"description": "Don't move the model to the device before sharding. "
"This is an experimental feature that may be included in the future as the default."
},
)
use_kernels: bool | None = Field(
default=None,
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
)
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
)
model_quantization_config_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={"description": "kwargs for model quantization config"},
)
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):

View File

@@ -54,7 +54,6 @@ class LoraConfig(BaseModel):
lora_alpha: int | None = None
lora_fan_in_fan_out: bool | None = None
lora_target_modules: str | list[str] | None = None
lora_target_parameters: str | list[str] | None = None
lora_target_linear: bool | None = Field(
default=None,
json_schema_extra={"description": "If true, will target all linear modules"},

View File

@@ -138,26 +138,6 @@ class HyperparametersConfig(BaseModel):
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
dion_lr: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
)
dion_momentum: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
)
dion_rank_fraction: float | None = Field(
default=1.0,
json_schema_extra={
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
},
)
dion_rank_multiple_of: int | None = Field(
default=1,
json_schema_extra={
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
},
)
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)

View File

@@ -559,6 +559,20 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
@@ -577,7 +591,9 @@ class LoRAValidationMixin:
@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
if self.adapter in ["lora", "qlora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@@ -603,7 +619,7 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernels_8bit(cls, data):
def check_lora_kernel_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
@@ -611,39 +627,36 @@ class LoRAValidationMixin:
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with 8-bit LoRA a the moment."
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_dora(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_rl(cls, data):
def check_lora_kernel_rl(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("rl"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with RL at the moment."
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_dropout_parameters(cls, data):
if (
data.get("lora_dropout", 0.0)
and data.get("lora_dropout") > 0.0
and data.get("lora_target_parameters")
):
# lora.ParamWrapper does not work with lora_dropout != 0
raise ValueError(
"`lora_dropout` does not work when using `lora_target_parameters`"
)
class RLValidationMixin:
"""Validation methods related to RL training configuration."""
@@ -972,16 +985,6 @@ class SystemValidationMixin:
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_model_quantization_config_vs_bnb(cls, data):
if data.get("model_quantization_config"):
if data.get("load_in_8bit") or data.get("load_in_4bit"):
raise ValueError(
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
@@ -1147,19 +1150,6 @@ class ModelCompatibilityValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_gpt_oss_fsdp_loading(cls, data):
if data.get("model_quantization_config", "") == "Mxfp4Config":
if (
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
is True
):
raise ValueError(
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
)
return data
class ComplexValidationMixin:
"""Complex validation methods that involve multiple systems."""
@@ -1205,7 +1195,7 @@ class ComplexValidationMixin:
"ReLoRA is not compatible with the one_cycle scheduler"
)
if self.flash_attn_fuse_mlp:
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
return self

View File

@@ -15,6 +15,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
@@ -596,25 +597,6 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = False
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
@@ -633,7 +615,6 @@ def prepare_optim_env(cfg):
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
setup_parallelism_envs(cfg)
setup_torch_compile_env(cfg)
if cfg.fp8:
@@ -686,6 +667,8 @@ def setup_trainer(
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
patch_evaluation_loop_for_fsdp2()
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref

View File

@@ -64,7 +64,6 @@ def sample_tensors():
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
@@ -104,24 +103,23 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert b.shape == (128,)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
@@ -129,7 +127,6 @@ def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -141,20 +138,19 @@ def test_matmul_lora(sample_tensors):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, b, None, None, None, None)
matmul = torch.matmul(X, W.t())
expected1 = matmul + b
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, b, None, A, B, scale)
out2 = matmul_lora(X, W, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = matmul + lora_term + b
expected2 = expected1 + lora_term
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
out3 = matmul_lora(X_3d, W, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -179,19 +175,16 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
@@ -250,19 +243,16 @@ def test_lora_mlp_with_adapters(
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
@@ -333,7 +323,6 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True
# Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
@@ -341,19 +330,16 @@ def test_lora_qkv(sample_tensors):
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
None,
True,
)
@@ -370,19 +356,16 @@ def test_lora_qkv(sample_tensors):
X,
q_weight,
None,
None,
q_A,
q_B,
scale,
k_weight,
None,
None,
k_A,
k_B,
scale,
v_weight,
None,
None,
v_A,
v_B,
scale,
@@ -416,7 +399,6 @@ def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -429,7 +411,7 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, b, None, A, B, scale)
output = LoRA_O.apply(X, W, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -443,7 +425,6 @@ def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
b = sample_tensors["b"] # [out]
scale = 0.5
shapes = sample_tensors["shapes"]
@@ -455,13 +436,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@@ -478,12 +459,11 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(out, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, b, None, A, B, scale)
result = matmul_lora(X, W, None, A, B, scale)
assert result.shape == (batch, seq, out)
@@ -491,7 +471,6 @@ def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
b = sample_tensors["b"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -507,7 +486,7 @@ def test_gradient_flow(sample_tensors):
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, b, None, A, B, scale)
out = matmul_lora(X, W, None, A, B, scale)
loss = out.sum()
# Backward pass

View File

@@ -174,69 +174,6 @@ class TestFSDP2:
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_lora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
@@ -299,70 +236,6 @@ class TestFSDP2:
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_qlora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(

View File

@@ -1,131 +0,0 @@
"""Integration tests for FSDP Params4bit patches."""
from unittest.mock import Mock, patch
import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)
@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance
class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""
def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)
# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)
assert isinstance(result, tuple)
assert len(result) == 2
# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize
# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])
with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)
# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()
class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""
@pytest.mark.integration
def test_all_patches_together(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"
# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param
!= original_init_sharded
), "_init_sharded_param was not patched"
assert (
FSDPParam.init_unsharded_param != original_init_unsharded
), "init_unsharded_param was not patched"

View File

@@ -29,6 +29,7 @@ class TestFusedLlama(unittest.TestCase):
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,

View File

@@ -13,7 +13,6 @@ from .utils import (
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)
@@ -161,49 +160,6 @@ class TestCustomOptimizers(unittest.TestCase):
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "dion",
"dion_lr": 0.01,
"dion_momentum": 0.95,
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -1,28 +0,0 @@
"""Unit tests for trainer loss calc monkeypatch."""
import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)
class TestTrainerLossCalc(unittest.TestCase):
"""
Unit test class for trainer loss calc monkeypatch
"""
def test_trainer_loss_calc_is_patchable(self):
"""
Test that the upstream transformers code is still patchable. This will fail if
the patched code changes upstream.
"""
assert check_evaluation_loop_is_patchable()
assert check_evaluation_loop_is_fsdp2_patchable()
assert check_maybe_log_save_evaluate_is_patchable()
if __name__ == "__main__":
unittest.main()

View File

@@ -9,7 +9,6 @@ from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import _get_parallel_config_kwargs
class TestModelsUtils:
@@ -194,13 +193,15 @@ class TestModelsUtils:
is_fsdp,
expected,
):
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
res = (
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
)
if expected[0] > 1:

View File

@@ -1,37 +0,0 @@
"""
test cases for axolotl.utils.import_helper
"""
import pytest
from axolotl.utils.import_helper import get_cls_from_module_str
def test_get_cls_from_module_str():
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
assert cls.__name__ == "AxolotlTrainer"
def test_get_cls_from_module_str_empty_string():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str("")
def test_get_cls_from_module_str_whitespace_only():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str(" ")
def test_get_cls_from_module_str_invalid_format():
with pytest.raises(ValueError, match="Invalid module string format"):
get_cls_from_module_str("single_part")
def test_get_cls_from_module_str_nonexistent_module():
with pytest.raises(ImportError, match="Failed to import module"):
get_cls_from_module_str("nonexistent.module.Class")
def test_get_cls_from_module_str_nonexistent_class():
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")