Compare commits

..

8 Commits

Author SHA1 Message Date
Wing Lian
42922f8f8b register pressure estimation and pruning for h200/b200 2026-03-19 06:39:16 -04:00
Wing Lian
7041592ca7 fix casting for H200 and B200 2026-03-19 05:57:54 -04:00
Wing Lian
fec0c3a99e chore: lint 2026-03-19 07:27:23 +00:00
Wing Lian
31d8d068bb handle base+lora split kernel for older moe models 2026-03-19 07:11:30 +00:00
Wing Lian
66fea258c7 add correctness unit tests and benchmarks for scattermoe + lora 2026-03-19 06:40:04 +00:00
Wing Lian
07ff389be8 selective dequant 2026-03-19 06:40:04 +00:00
Wing Lian
2dcca15f65 more scattermoe optims 2026-03-19 06:40:04 +00:00
Wing Lian
c5db90aa3f optimize moe + lora 2026-03-19 06:40:04 +00:00
101 changed files with 454 additions and 1769 deletions

View File

@@ -128,9 +128,11 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -3,13 +3,11 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \

View File

@@ -68,6 +68,10 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")

View File

@@ -37,7 +37,6 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,33 +27,3 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -20,7 +20,6 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -192,14 +191,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -54,13 +54,6 @@ These techniques save VRAM by changing how activations are handled.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
]
},
{

View File

@@ -1,5 +1,8 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -24,11 +27,6 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,5 +1,8 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -24,11 +27,6 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,5 +1,9 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -20,11 +24,6 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -6,6 +6,9 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
Note: Training this model requires weights in BF16 which we will link to later.
Users interested in training can convert / descale the existing FP8 weights.
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -1,4 +1,4 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:

View File

@@ -1,4 +1,4 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:

View File

@@ -1,57 +0,0 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -1,84 +0,0 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,11 +32,7 @@ lora_target_modules:
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
@@ -56,6 +52,7 @@ learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

View File

@@ -1,59 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,81 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,7 +1,9 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,85 +0,0 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,11 +32,7 @@ lora_target_modules:
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj

View File

@@ -1,6 +1,10 @@
base_model: Qwen/Qwen3.5-9B
base_model: Qwen/Qwen3.5-7B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
@@ -26,6 +30,8 @@ lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj

View File

@@ -1,49 +0,0 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,6 +1,15 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
## Getting started
@@ -9,69 +18,35 @@
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Pick any config from the table below and run:
```bash
axolotl train examples/qwen3.5/<config>.yaml
```
Available configs:
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
### Gated DeltaNet Linear Attention
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
```yaml
lora_target_modules:
# ... standard projections ...
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
### Routed Experts (MoE)
4. Run a finetuning example:
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
### Shared Experts (MoE)
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
```yaml
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
```
### TIPS
- For inference hyp, please see the respective model card details.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
## Optimization Guides

View File

@@ -61,11 +61,5 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
)

View File

@@ -81,23 +81,16 @@ def parse_requirements(extras_require_map):
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"]
elif (major, minor) >= (2, 9):
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:

View File

@@ -3,7 +3,6 @@
import os
from pathlib import Path
import httpcore
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
@@ -48,7 +47,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except (HTTPError, httpcore.ConnectError):
except HTTPError:
LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
)

View File

@@ -353,30 +353,6 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
@@ -508,8 +484,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.layer_offloading:
training_args_kwargs["layer_offloading"] = True
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False

View File

@@ -421,13 +421,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -208,11 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (

View File

@@ -29,12 +29,10 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
LayerOffloadingMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -53,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
@@ -67,7 +67,6 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
LayerOffloadingMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,

View File

@@ -1 +0,0 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -2,8 +2,7 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass, field
from typing import Optional
from dataclasses import dataclass
from trl import DPOConfig
@@ -17,4 +16,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -4,7 +4,6 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .layer_offloading import LayerOffloadingMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -1,304 +0,0 @@
"""
Trainer mixin for layer-wise parameter offloading to CPU.
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
transfer stream), post-hook offloads layer N-1's frozen params.
Backward: same in reverse order.
"""
import contextlib
import torch
import torch.nn as nn
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
"""Recursively search the model for the decoder layer ModuleList.
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
where layers are at model.language_model.layers).
"""
# BFS to find the first ModuleList containing decoder layers
queue = [model]
while queue:
m = queue.pop(0)
for _name, child in m.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
first_type = type(child[0]).__name__
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
layer_types = list({type(layer).__name__ for layer in child})
return child, layer_types
else:
queue.append(child)
return None, []
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
"""Get all non-trainable parameters in a layer."""
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
class LayerOffloadManager:
"""Manages offloading frozen decoder layer params to CPU and streaming
them back during forward/backward with CUDA stream overlap.
Only frozen (requires_grad=False) parameters are offloaded.
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
"""
def __init__(
self,
model: nn.Module,
num_prefetch: int = 1,
):
self.model = model
self.num_prefetch = num_prefetch
self._hooks: list = []
self._device = None
# Find decoder layers
self.layers, layer_types = _find_decoder_layers(model)
if self.layers is None:
LOG.warning(
"LayerOffloadManager: no decoder layers found, offloading disabled"
)
self.enabled = False
return
self.enabled = True
self.n_layers = len(self.layers)
LOG.info(
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
)
# Determine GPU device
for p in model.parameters():
if p.device.type == "cuda":
self._device = p.device
break
if self._device is None:
LOG.warning("LayerOffloadManager: no CUDA parameters found")
self.enabled = False
return
# Transfer stream for async prefetch
self._transfer_stream = torch.cuda.Stream(device=self._device)
# Track which layers have their frozen params on GPU
self._on_gpu: set[int] = set(range(self.n_layers))
# Cache: frozen param references per layer (list of (name, param) tuples)
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
]
# CPU storage: pinned tensors for each layer's frozen params
# Populated on first offload
self._cpu_data: list[dict[str, torch.Tensor]] = [
{} for _ in range(self.n_layers)
]
# Offload all layers upfront
self._offload_all()
# Release cached memory blocks back to the driver
torch.cuda.empty_cache()
def _offload_all(self):
"""Move all frozen params in all decoder layers to CPU."""
mem_before = torch.cuda.memory_allocated(self._device)
for i in range(self.n_layers):
self._offload_layer(i)
mem_after = torch.cuda.memory_allocated(self._device)
freed = (mem_before - mem_after) / 1e6
LOG.info(
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
f"freed {freed:.0f} MB GPU memory"
)
def _offload_layer(self, idx: int):
"""Move frozen params of layer idx to CPU pinned memory."""
if idx not in self._on_gpu:
return
for name, param in self._frozen_params[idx]:
if param.device.type != "cuda":
continue
# Allocate pinned CPU tensor on first offload
if name not in self._cpu_data[idx]:
self._cpu_data[idx][name] = torch.empty_like(
param.data, device="cpu", pin_memory=True
)
cpu_buf = self._cpu_data[idx][name]
# Async copy GPU -> CPU (on transfer stream for overlap)
cpu_buf.copy_(param.data, non_blocking=True)
# Point parameter at a dummy CPU tensor to free GPU memory
param.data = cpu_buf
self._on_gpu.discard(idx)
def _load_layer(self, idx: int, stream=None):
"""Move frozen params of layer idx back to GPU."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
ctx = (
torch.cuda.stream(stream)
if stream is not None
else contextlib.nullcontext()
)
with ctx:
for _name, param in self._frozen_params[idx]:
if param.device.type == "cuda":
continue
gpu_data = param.data.to(self._device, non_blocking=True)
param.data = gpu_data
self._on_gpu.add(idx)
def _prefetch_layer(self, idx: int):
"""Async prefetch layer idx on the transfer stream."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
self._load_layer(idx, stream=self._transfer_stream)
def _wait_transfer(self):
"""Make default stream wait for any in-flight transfers."""
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
def setup_hooks(self):
"""Register forward and backward hooks on each decoder layer."""
if not self.enabled:
return
for idx in range(self.n_layers):
layer = self.layers[idx]
def make_pre_fwd(i):
def hook(module, args):
# Ensure this layer is on GPU
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch next layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i + offset)
return hook
def make_post_fwd(i):
def hook(module, args, output):
# Offload previous layer (no longer needed in forward)
if i > 0:
self._offload_layer(i - 1)
# Offload last layer after forward
if i == self.n_layers - 1:
self._offload_layer(i)
return hook
def make_pre_bwd(i):
def hook(module, grad_output):
# Load this layer for backward
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch previous layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i - offset)
return hook
def make_post_bwd(i):
def hook(module, grad_input, grad_output):
# Offload the layer above
if i < self.n_layers - 1:
self._offload_layer(i + 1)
# Offload first layer after backward
if i == 0:
self._offload_layer(i)
return hook
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
h2 = layer.register_forward_hook(make_post_fwd(idx))
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
self._hooks.extend([h1, h2, h3, h4])
def remove_hooks(self):
"""Remove all hooks and restore layers to GPU."""
for h in self._hooks:
h.remove()
self._hooks.clear()
if self.enabled:
for i in range(self.n_layers):
if i not in self._on_gpu:
self._load_layer(i)
def pre_step(self):
"""Called before each training step — ensure layers start offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for forward
self._prefetch_layer(0)
def post_step(self):
"""Called after each training step — ensure layers are offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for next step
self._prefetch_layer(0)
class _LayerOffloadContext:
"""Context manager wrapping pre_step / post_step around a training step."""
def __init__(self, manager: LayerOffloadManager):
self.manager = manager
def __enter__(self):
self.manager.pre_step()
return self
def __exit__(self, *args):
self.manager.post_step()
class LayerOffloadingMixin(Trainer):
"""
Trainer mixin class for layer-wise parameter offloading to CPU.
Offloads frozen decoder layer params to CPU at init, then streams them
on/off GPU one layer at a time during each training step.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if getattr(self.args, "layer_offloading", False):
LOG.info("Layer parameter offloading enabled")
self._layer_offload_manager = LayerOffloadManager(
model=self.model,
num_prefetch=1,
)
self._layer_offload_manager.setup_hooks()
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
else:
self._layer_offload_manager = None
self._layer_offload_ctx = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self._layer_offload_ctx:
return super().training_step(*args, **kwargs)

View File

@@ -235,13 +235,6 @@ class AxolotlTrainingMixins:
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
layer_offloading: bool | None = field(
default=None,
metadata={
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
)

View File

@@ -15,7 +15,6 @@ SPARSE_MOE_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
@@ -59,16 +58,7 @@ def resolve_moe_block_classes(model_type: str):
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
if model_type.endswith("_text"):
parent_type = model_type.removesuffix("_text")
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
module = importlib.import_module(module_path)
else:
raise
module = importlib.import_module(module_path)
classes = []
for cls_name in cls_names:

View File

@@ -199,30 +199,24 @@ def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
"""Estimate per-thread register count from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Each tile of shape (rows, cols) requires rows*cols elements distributed
across 32 threads per warp, but each thread in the warp holds a fragment.
For Triton GEMM-style kernels, the register footprint per thread is
approximately sum(rows * cols) / 32 for each live tile, plus ~40 for
scalar overhead (loop counters, pointers, masks, etc.).
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
# Each thread in a warp holds 1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# Maximum registers per thread on NVIDIA GPUs
_MAX_REGS_PER_THREAD = 255
# =============================================================================
@@ -363,7 +357,7 @@ def _scatter2scatter_lora_configs():
Search space:
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64}
BLOCK_N: {32, 64, 128, 256}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
@@ -371,7 +365,7 @@ def _scatter2scatter_lora_configs():
configs = []
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64], # BLOCK_N
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
[3, 4, 5], # num_stages
@@ -425,7 +419,7 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
if est_regs > _MAX_REGS_PER_THREAD:
continue
scored.append((smem, config))
@@ -943,16 +937,16 @@ def _scatter2scatter_lora_dX_configs():
Search space:
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128} (output tile)
BLOCK_N: {32, 64} (reduction tile)
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K (output dimension)
[32, 64], # BLOCK_N (reduction dimension)
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
@@ -1005,7 +999,7 @@ def _prune_dX_configs(configs, named_args, **kwargs):
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
if est_regs > _MAX_REGS_PER_THREAD:
continue
scored.append((smem, config))
@@ -1278,9 +1272,9 @@ def _group_bwd_lora_configs():
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space:
BLOCK_M: {32, 64, 128} (token-loop tile)
BLOCK_K: {32, 64, 128}
BLOCK_N: {32, 64}
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
BLOCK_K: {32, 64, 128, 256}
BLOCK_N: {32, 64, 128, 256}
num_warps: {4, 8}
num_stages: {3, 4, 5}
@@ -1289,9 +1283,9 @@ def _group_bwd_lora_configs():
"""
configs = []
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K
[32, 64], # BLOCK_N
[32, 64, 128, 256], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K
[32, 64, 128, 256], # BLOCK_N
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
@@ -1338,7 +1332,7 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
if est_regs > _MAX_REGS_PER_THREAD:
continue
scored.append((smem, config))
@@ -1587,7 +1581,7 @@ def _prune_split_configs(configs, named_args, **kwargs):
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
if est_regs > _MAX_REGS_PER_THREAD:
continue
if smem <= smem_cap - _SMEM_SLACK:

View File

@@ -640,9 +640,7 @@ class LoRA_QKV(torch.autograd.Function):
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
# This is 65x fewer FLOPs than materializing B@A into [out, in]
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
@@ -650,7 +648,7 @@ class LoRA_QKV(torch.autograd.Function):
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
@@ -658,7 +656,7 @@ class LoRA_QKV(torch.autograd.Function):
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
# Transpose gradients if needed
if d_A_q is not None:
@@ -821,8 +819,7 @@ class LoRA_O(torch.autograd.Function):
del W
A, B = A.to(dtype), B.to(dtype)
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
dX.addmm_(torch.mm(dY, B), A, alpha=s)
dX += s * dY @ B @ A
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None

View File

@@ -571,6 +571,15 @@ class PatchManager:
LOG.info("Patching with xformers attention...")
hijack_llama_attention()
def _patch_llama_sample_packing(self):
"""Apply sample packing patches for LLaMA models."""
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
hijack_llama_prepare_4d_mask()
def _patch_llama_derived_model(self):
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
@@ -582,6 +591,8 @@ class PatchManager:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
self._patch_llama_sample_packing()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."

View File

@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()

View File

@@ -78,21 +78,30 @@ def patch_parallelism_config():
def patch_prepare_cp():
import contextlib
import functools
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
self._cp_context = _noop_cp_context
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -0,0 +1,24 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = _expand_mask

View File

@@ -0,0 +1,26 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)

View File

@@ -51,29 +51,6 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -322,8 +299,6 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(

View File

@@ -59,7 +59,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"ministral3",
"mistral4",
"afmoe",
"nemotron",
]

View File

@@ -3,10 +3,15 @@ Shared utils for the monkeypatches
"""
import re
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -165,6 +170,65 @@ def set_module_name(model, name, value):
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)

View File

@@ -1,96 +0,0 @@
"""
Synthetic dataset generator for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
Useful for benchmarking memory usage and speed by sequence length, and for validating
weighted dataset mixes.
YAML configuration example:
datasets:
- path: synthetic
type: _synthetic
length: 1000
sequence_length: 2048
min_input_id: 100
max_input_id: 32000
seed: 42
"""
from typing import Any, Dict, Optional
import numpy as np
from datasets import Dataset
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
def __init__(
self,
sequence_length: int = 2048,
length: int = 1000,
min_input_id: int = 100,
max_input_id: int = 32000,
seed: Optional[int] = None,
):
self.sequence_length = sequence_length
self.length = length
self.min_input_id = min_input_id
self.max_input_id = max_input_id
self.seed = seed
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
LOG.info(
f"Generating synthetic dataset: {self.length} samples, "
f"sequence_length={self.sequence_length}, "
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
)
rng = np.random.default_rng(self.seed)
input_ids = rng.integers(
low=self.min_input_id,
high=self.max_input_id,
size=(self.length, self.sequence_length),
).tolist()
attention_mask = [[1] * self.sequence_length] * self.length
# labels == input_ids means we train on all tokens
labels = [row[:] for row in input_ids]
return Dataset.from_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
length = ds_cfg.get("length", 1000)
min_input_id = ds_cfg.get("min_input_id", 100)
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
seed = ds_cfg.get("seed", None)
return SyntheticDatasetStrategy(
sequence_length=sequence_length,
length=length,
min_input_id=min_input_id,
max_input_id=max_input_id,
seed=seed,
)

View File

@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if getattr(model, "generation_config", None) is not None:
if model.generation_config is not None:
model.generation_config.do_sample = True
model_properties = model.config.to_dict()

View File

@@ -25,11 +25,9 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
if (
isinstance(mod, FakeQuantizedLinear)
and mod.activation_fake_quantizer is not None
and hasattr(mod.activation_fake_quantizer, "enabled")
):
mod.activation_fake_quantizer.enabled = enable
if hasattr(mod.weight_fake_quantizer, "enabled"):
mod.weight_fake_quantizer.enabled = enable
mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):

View File

@@ -12,11 +12,12 @@ from transformers import (
TrainingArguments,
)
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state.json"
class TokensPerSecondCallback(TrainerCallback):
"""

View File

@@ -22,12 +22,7 @@ from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.datasets import (
DPODataset,
KTODataset,
SFTDataset,
SyntheticDataset,
)
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
LOG = get_logger(__name__)
@@ -313,14 +308,6 @@ def validate_config(
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
cfg["datasets"][idx] = SyntheticDataset(
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
)
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

View File

@@ -376,14 +376,10 @@ def _load_and_process_single_dataset(
streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# For synthetic datasets, create a minimal placeholder instead of loading from path
if dataset_config.type == "_synthetic":
dataset = Dataset.from_dict({"text": [""]})
else:
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -10,11 +10,9 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)
@@ -175,70 +173,6 @@ def quantize_model(
)
def _make_qat_config(
base_config: AOBaseConfig,
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None,
group_size: int | None,
) -> QATConfig:
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
group_size and other params are properly propagated (torchao's QATConfig(base_config)
does not always map these correctly)."""
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
if isinstance(base_config, MXFakeQuantizeConfig):
return QATConfig(
activation_config=base_config,
weight_config=base_config,
)
# Build explicit weight config
weight_fq_config: (
Int4WeightFakeQuantizeConfig
| IntxFakeQuantizeConfig
| Float8FakeQuantizeConfig
| None
) = None
if weight_dtype == TorchAOQuantDType.int4:
gs = (
group_size
if group_size is not None
else getattr(base_config, "group_size", 128)
)
activation_dt = None
if activation_dtype == TorchAOQuantDType.int8:
activation_dt = torch.bfloat16
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_dt = torch.float8_e4m3fn
kwargs = {"group_size": gs}
if activation_dt is not None:
kwargs["activation_dtype"] = activation_dt
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
# Build explicit activation config
activation_fq_config = None
if activation_dtype == TorchAOQuantDType.int8:
activation_fq_config = IntxFakeQuantizeConfig(
dtype=torch.int8, granularity="per_token", is_symmetric=False
)
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
if weight_fq_config is not None:
return QATConfig(
weight_config=weight_fq_config,
activation_config=activation_fq_config,
)
# Fallback to base_config for unhandled combos
return QATConfig(base_config)
def prepare_model_for_qat(
model,
weight_dtype: TorchAOQuantDType,
@@ -266,9 +200,13 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype,
group_size=group_size,
)
qat_config = _make_qat_config(
base_config, weight_dtype, activation_dtype, group_size
)
if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
quantize_(model, qat_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
@@ -277,9 +215,12 @@ def prepare_model_for_qat(
activation_dtype=None,
group_size=group_size,
)
embedding_qat_config = _make_qat_config(
embedding_base_config, weight_dtype, None, group_size
)
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
quantize_(
model,
embedding_qat_config,

View File

@@ -2,7 +2,7 @@
import math
from functools import partial
from typing import Any, Sequence
from typing import Sequence
from torch import Tensor
from torch.optim import Optimizer
@@ -340,19 +340,3 @@ class JaggedLRRestartScheduler(LRScheduler):
return [lr * scale for lr in original]
return original * scale
def state_dict(self) -> dict[str, Any]:
"""Return serializable state, saving inner_schedule as its own state_dict."""
state = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "inner_schedule")
}
state["inner_schedule_state"] = self.inner_schedule.state_dict()
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Restore state, including inner_schedule."""
inner_state = state_dict.pop("inner_schedule_state")
self.__dict__.update(state_dict)
self.inner_schedule.load_state_dict(inner_state)

View File

@@ -22,7 +22,6 @@ from axolotl.utils.schemas.datasets import (
PretrainingDataset,
SFTDataset,
StepwiseSupervisedDataset,
SyntheticDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
@@ -186,13 +185,7 @@ class AxolotlInputConfig(
datasets: (
Annotated[
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
]
| None
@@ -205,13 +198,7 @@ class AxolotlInputConfig(
test_datasets: (
Annotated[
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
]
| None
@@ -446,12 +433,6 @@ class AxolotlInputConfig(
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
)
layer_offloading: bool | None = Field(
default=False,
json_schema_extra={
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,

View File

@@ -296,42 +296,4 @@ class KTODataset(BaseModel):
revision: str | None = None
class SyntheticDataset(BaseModel):
"""Synthetic dataset configuration for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
validating weighted dataset mixes.
"""
path: Literal["synthetic"] = "synthetic"
type: Literal["_synthetic"] = "_synthetic"
length: int = Field(
default=1000,
json_schema_extra={"description": "Number of rows to generate"},
)
sequence_length: int | None = Field(
default=None,
json_schema_extra={
"description": "Sequence length per row (defaults to sequence_len from config)"
},
)
min_input_id: int = Field(
default=100,
json_schema_extra={"description": "Minimum token ID for generation"},
)
max_input_id: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
},
)
seed: int | None = Field(
default=None,
json_schema_extra={"description": "Random seed for reproducibility"},
)
DatasetConfig = (
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
)
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset

View File

@@ -87,11 +87,6 @@ class CustomSupportedOptimizers(str, Enum):
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
flash_adamw = "flash_adamw"
flash_adam = "flash_adam"
flash_sgd = "flash_sgd"
flash_sgdw = "flash_sgdw"
flash_lion = "flash_lion"
class RingAttnFunc(str, Enum):

View File

@@ -253,23 +253,6 @@ class TrainingValidationMixin:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@classmethod
def set_reward_model_defaults(cls, data):
if data.get("reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 1
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForSequenceClassification"
if data.get("process_reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 2
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForTokenClassification"
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -790,14 +773,6 @@ class OptimizationValidationMixin:
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@staticmethod
def _resolve_fsdp_version(data):
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
return fsdp_version
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
@@ -807,32 +782,15 @@ class OptimizationValidationMixin:
"Muon optimizer is currently incompatible with DeepSpeed"
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = cls._resolve_fsdp_version(data)
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
if str(fsdp_version) != "2":
raise ValueError(
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_flashoptim_deepspeed_fsdp(cls, data):
optimizer = data.get("optimizer") or ""
if str(optimizer).startswith("flash_"):
if data.get("deepspeed"):
raise ValueError(
f"{optimizer} optimizer is incompatible with DeepSpeed. "
"Flash optimizers only support DDP and FSDP2."
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = cls._resolve_fsdp_version(data)
if str(fsdp_version) != "2":
raise ValueError(
f"{optimizer} optimizer is only compatible with FSDP2. "
"Set fsdp_version: 2 to use flash optimizers with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):

View File

@@ -15,8 +15,6 @@ import datasets
import pytest
import requests
import torch
import transformers.utils as _transformers_utils
import transformers.utils.import_utils as _import_utils
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
@@ -31,26 +29,6 @@ from tests.hf_offline_utils import (
logging.getLogger("filelock").setLevel(logging.CRITICAL)
# Shim for deepseek v3
if not hasattr(_import_utils, "is_torch_fx_available"):
def _is_torch_fx_available():
try:
import torch.fx # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
_import_utils.is_torch_fx_available = _is_torch_fx_available
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
_is_flash_attn_gte("2.10")
)
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):

View File

@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
"cfg_string",
[
"sft_cfg",
"rm_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
"prm_cfg",
],
)

View File

@@ -20,7 +20,6 @@ Test strategy:
- Tolerances account for tf32 accumulation in Triton kernels
"""
from functools import wraps
from types import SimpleNamespace
import pytest
@@ -35,21 +34,6 @@ pytestmark = pytest.mark.skipif(
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
def skip_on_out_of_resources(func):
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
if "OutOfResources" in type(exc).__name__:
pytest.skip(f"GPU shared memory too small: {exc}")
raise
return wrapper
# =============================================================================
# Helpers
# =============================================================================
@@ -225,7 +209,6 @@ def make_test_data(
# =============================================================================
@pytest.mark.slow
class TestForwardPass:
"""Test forward pass of fused scatter2scatter_lora kernel."""
@@ -305,7 +288,6 @@ class TestForwardPass:
)
@pytest.mark.slow
class TestForwardGrouped:
"""Test forward pass with grouped_in/grouped_out configurations."""
@@ -395,7 +377,6 @@ class TestForwardGrouped:
# =============================================================================
@pytest.mark.slow
class TestLoRAGradients:
"""Test backward LoRA gradient computation (dA, dB)."""
@@ -471,7 +452,6 @@ class TestLoRAGradients:
# =============================================================================
@pytest.mark.slow
class TestAutograd:
"""Test full autograd integration through ScatterMoELoRA."""
@@ -640,7 +620,6 @@ class TestAutograd:
# =============================================================================
@pytest.mark.slow
class TestBaseEquivalence:
"""When scaling=0, fused kernel should match base scatter2scatter."""
@@ -713,7 +692,6 @@ class TestBaseEquivalence:
# =============================================================================
@pytest.mark.slow
class TestLoRAAdditivity:
"""Test that the LoRA component is correctly additive."""
@@ -771,7 +749,6 @@ class TestLoRAAdditivity:
# =============================================================================
@pytest.mark.slow
class TestParallelExpertsModule:
"""Test the ParallelExperts module with LoRA."""
@@ -839,7 +816,6 @@ class TestParallelExpertsModule:
# =============================================================================
@pytest.mark.slow
class TestEdgeCases:
"""Edge cases and boundary conditions."""
@@ -937,7 +913,6 @@ class TestEdgeCases:
# =============================================================================
@pytest.mark.slow
class TestFusedDX:
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
@@ -1005,7 +980,6 @@ class TestFusedDX:
def test_basic(self):
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
@skip_on_out_of_resources
def test_large(self):
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
@@ -1148,7 +1122,6 @@ class TestFusedDX:
# =============================================================================
@pytest.mark.slow
class TestFusedGatherBackward:
"""Test fused gather + backward dA/dB kernel."""
@@ -1201,7 +1174,6 @@ class TestFusedGatherBackward:
def test_basic(self):
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
@skip_on_out_of_resources
def test_large(self):
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
@@ -1211,7 +1183,6 @@ class TestFusedGatherBackward:
def test_k1(self):
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
@skip_on_out_of_resources
def test_many_experts(self):
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
@@ -1298,8 +1269,6 @@ class TestFusedGatherBackward:
# =============================================================================
@pytest.mark.slow
@pytest.mark.xfail(reason="flaky", strict=False)
class TestTokenRounding:
"""Test token rounding utility and its integration with backward kernels."""
@@ -1346,7 +1315,6 @@ class TestTokenRounding:
)
prev = padded_offsets[e].item()
@skip_on_out_of_resources
def test_round_with_fused_gather(self):
"""Token rounding + fused gather gives same result as plain fused gather."""
from importlib import import_module
@@ -1446,7 +1414,6 @@ class TestTokenRounding:
# =============================================================================
@pytest.mark.slow
class TestCombinedOptimizations:
"""Test all optimizations together."""
@@ -1616,7 +1583,6 @@ def _make_mock_sigmoid_moe_block(
return moe_block, T, H, FF, E, K
@pytest.mark.slow
class TestHFScatterMoESigmoidRouting:
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
@@ -1758,7 +1724,6 @@ class TestHFScatterMoESigmoidRouting:
)
@pytest.mark.slow
class TestHFScatterMoESigmoidWithSharedExperts:
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""

View File

@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
def _get_repo_path():
"""Get the path to scattermoe_lora within axolotl's plugin."""
return (
Path(__file__).parent.parent.parent.parent
Path(__file__).parent.parent.parent
/ "src"
/ "axolotl"
/ "integrations"
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
# Kernelize
repo_path = (
Path(__file__).parent.parent.parent.parent
Path(__file__).parent.parent.parent
/ "src"
/ "axolotl"
/ "integrations"

View File

@@ -86,5 +86,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -37,7 +37,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -37,7 +37,7 @@ def verify_fp8_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -94,5 +94,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
)

View File

@@ -90,7 +90,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -156,7 +156,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
@@ -233,7 +233,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -312,7 +312,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -385,7 +385,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -461,7 +461,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_6_0
@@ -543,7 +543,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -623,7 +623,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -708,7 +708,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.45, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -784,7 +784,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -859,7 +859,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
@pytest.mark.skip(
@@ -925,5 +925,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 4.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
)

View File

@@ -79,7 +79,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -138,7 +138,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -205,5 +205,5 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -64,5 +64,5 @@ class TestTensorParallel:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
)

View File

@@ -78,5 +78,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -77,5 +77,5 @@ class TestFAFlattening:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -4,7 +4,8 @@ E2E tests for lora llama
import unittest
from transformers.utils import is_torch_bf16_gpu_available
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -67,3 +68,51 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
def test_lora_gptq_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"max_steps": 20,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.train import train
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

View File

@@ -73,7 +73,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -124,7 +124,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -180,5 +180,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -14,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.mark.skip(
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
)
class TestDeepseekV3:
"""
Test case for DeepseekV3 models

View File

@@ -262,7 +262,6 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
@with_temp_dir
def test_orpo_lora(self, temp_dir):
cfg = DictDefault(

View File

@@ -57,7 +57,9 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
@@ -98,4 +100,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)

View File

@@ -66,7 +66,7 @@ class TestPretrainLlama:
loss_threshold = 6.5
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,8 +4,6 @@ E2E tests for custom optimizers using Llama
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -284,60 +282,3 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@require_torch_2_7_0
@pytest.mark.parametrize(
"optimizer_name,expected_class,learning_rate",
[
("flash_adamw", "FlashAdamW", 0.00001),
("flash_adam", "FlashAdam", 0.00001),
("flash_sgd", "FlashSGD", 0.01),
("flash_sgdw", "FlashSGDW", 0.01),
("flash_lion", "FlashLion", 0.0001),
],
)
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
pytest.importorskip("flashoptim")
temp_dir = str(tmp_path)
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": learning_rate,
"optimizer": optimizer_name,
"max_steps": 5,
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class

View File

@@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -57,7 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.7, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -128,7 +128,7 @@ class TestQATLlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -35,14 +35,6 @@ from tests.e2e.utils import (
)
def _get_fake_quant_config_dtype(config):
"""Get the weight dtype from a fake quantize config, handling different config types."""
if hasattr(config, "dtype"):
return config.dtype
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
return torch.int4
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
@@ -165,18 +157,6 @@ class TestQuantization:
expected_exception,
expected_tensor_class,
):
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
# (see https://pypi.org/project/mslk-cuda/)
if expected_tensor_class is Int4Tensor and activation_dtype is None:
try:
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
int4_row_quantize_zp,
)
if int4_row_quantize_zp is None:
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
except ImportError:
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
@@ -272,24 +252,28 @@ class TestQuantization:
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
assert (
model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value
)
if group_size:
assert embed_config.group_size == group_size
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
w_config = child.weight_fake_quantizer.config
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
if group_size:
assert w_config.group_size == group_size
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
a_config = child.activation_fake_quantizer.config
assert (
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
child.activation_fake_quantizer.config.dtype
== activation_dtype.value
)
else:
assert child.activation_fake_quantizer is None
@@ -390,16 +374,9 @@ class TestQuantizationCallback:
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
# Only test enable/disable toggling if the fake quantizer supports it
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
supports_toggle = hasattr(
model.model.embed_tokens.weight_fake_quantizer, "enabled"
)
if supports_toggle:
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
@@ -411,10 +388,9 @@ class TestQuantizationCallback:
model=model,
)
if supports_toggle:
# quantization should have been disabled
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
assert not model.lm_head.weight_fake_quantizer.enabled
# quantization should have been disabled
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
assert not model.lm_head.weight_fake_quantizer.enabled
trainer_state.global_step = 100
qat_callback.on_step_begin(
@@ -424,10 +400,9 @@ class TestQuantizationCallback:
model=model,
)
if supports_toggle:
# quantization should have been enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
# quantization should have been enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
@@ -449,10 +424,9 @@ class TestQuantizationCallback:
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
@@ -464,6 +438,5 @@ class TestQuantizationCallback:
)
# quantization should be enabled from the get-go
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled

View File

@@ -66,6 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -66,7 +66,7 @@ class TestStreamingDatasets:
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
3.0,
"Train Loss (%s) is too high",
)

View File

@@ -13,7 +13,6 @@ from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.schemas.datasets import SFTDataset
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
@@ -278,34 +277,6 @@ class TestValidation(BaseValidation):
new_cfg = validate_config(cfg)
assert new_cfg.type_of_model == "AutoModelForCausalLM"
def test_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 1
assert new_cfg.type_of_model == "AutoModelForSequenceClassification"
def test_process_reward_model_defaults(self, minimal_cfg):
cfg = (
DictDefault(
{
"process_reward_model": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.num_labels == 2
assert new_cfg.type_of_model == "AutoModelForTokenClassification"
def test_model_revision_remap(self, minimal_cfg):
cfg = (
DictDefault(
@@ -1732,52 +1703,3 @@ class TestDataloaderValidation(BaseValidation):
assert new_cfg.dataloader_num_workers == 8
assert new_cfg.dataloader_pin_memory is True
assert new_cfg.dataloader_prefetch_factor == 256
class TestSyntheticDatasetValidation(BaseValidation):
"""
Tests for synthetic dataset config validation
"""
@staticmethod
def _make_cfg(minimal_cfg, datasets):
raw = dict(minimal_cfg)
raw["datasets"] = datasets
return DictDefault(raw)
def test_synthetic_dict_config_validates(self, minimal_cfg):
"""Synthetic dataset passed as a raw dict should not raise."""
cfg = self._make_cfg(
minimal_cfg,
[
{
"path": "synthetic",
"type": "_synthetic",
"length": 100,
"sequence_length": 64,
}
],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
sft = SFTDataset(path="synthetic", type="_synthetic")
cfg = self._make_cfg(minimal_cfg, [sft])
# Before the fix, this raised:
# AttributeError: 'SFTDataset' object has no attribute 'get'
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_non_synthetic_sft_validates(self, minimal_cfg):
"""A regular SFT dataset should validate without being treated as synthetic."""
cfg = self._make_cfg(
minimal_cfg,
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"

View File

@@ -1,125 +0,0 @@
"""Tests for the synthetic dataset generator."""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
from axolotl.utils.dict import DictDefault
class TestSyntheticDatasetStrategy(unittest.TestCase):
def test_generates_correct_shape(self):
strategy = SyntheticDatasetStrategy(
sequence_length=128,
length=50,
min_input_id=1,
max_input_id=1000,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
assert len(result) == 50
assert len(result[0]["input_ids"]) == 128
assert len(result[0]["attention_mask"]) == 128
assert len(result[0]["labels"]) == 128
def test_attention_mask_all_ones(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert all(v == 1 for v in row["attention_mask"])
def test_labels_equal_input_ids(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert row["input_ids"] == row["labels"]
def test_input_id_range(self):
strategy = SyntheticDatasetStrategy(
sequence_length=64,
length=100,
min_input_id=500,
max_input_id=600,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
for token_id in row["input_ids"]:
assert 500 <= token_id < 600
def test_seed_reproducibility(self):
kwargs = dict(
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
for r1, r2 in zip(result1, result2, strict=True):
assert r1["input_ids"] == r2["input_ids"]
def test_different_seeds_differ(self):
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
any_different = any(
r1["input_ids"] != r2["input_ids"]
for r1, r2 in zip(result1, result2, strict=True)
)
assert any_different
def test_load_function_with_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
ds_cfg = {
"sequence_length": 256,
"length": 5,
"min_input_id": 10,
"max_input_id": 100,
"seed": 0,
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
assert isinstance(strategy, SyntheticDatasetStrategy)
assert strategy.sequence_length == 256
assert strategy.length == 5
assert strategy.min_input_id == 10
assert strategy.max_input_id == 100
def test_load_defaults_from_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
strategy = load(tokenizer, cfg, ds_cfg={})
assert strategy.sequence_length == 1024
assert strategy.max_input_id == 32000
assert strategy.length == 1000
def test_load_with_no_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 50000
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
strategy = load(tokenizer, cfg)
assert strategy.sequence_length == 2048
assert strategy.max_input_id == 50000
if __name__ == "__main__":
unittest.main()

Some files were not shown because too many files have changed in this diff Show More