Compare commits
24 Commits
fix/cp-was
...
lhl-moe-au
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6636e5de7e | ||
|
|
0a566d7a15 | ||
|
|
5acb1b0ade | ||
|
|
4009a2ba5f | ||
|
|
66b2ab8414 | ||
|
|
676d5e855d | ||
|
|
966a4555db | ||
|
|
ad0c825bcb | ||
|
|
46d677876e | ||
|
|
6eac9ac372 | ||
|
|
949cdf01eb | ||
|
|
a0019021dd | ||
|
|
2af7475fdf | ||
|
|
3e4688289c | ||
|
|
5b2e3f00ce | ||
|
|
fc3b3d1d4e | ||
|
|
c9df6efdc2 | ||
|
|
0ee98a0309 | ||
|
|
2c05847a5f | ||
|
|
b0294b3427 | ||
|
|
1bcfc08c90 | ||
|
|
5a5cf30b26 | ||
|
|
7ddfb2d8a0 | ||
|
|
c57acef2c7 |
@@ -128,11 +128,9 @@ quartodoc:
|
||||
- monkeypatch.mistral_attn_hijack_flash
|
||||
- monkeypatch.multipack
|
||||
- monkeypatch.relora
|
||||
- monkeypatch.llama_expand_mask
|
||||
- monkeypatch.lora_kernels
|
||||
- monkeypatch.utils
|
||||
- monkeypatch.btlm_attn_hijack_flash
|
||||
- monkeypatch.llama_patch_multipack
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
|
||||
@@ -3,7 +3,7 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Gradient Checkpointing and Activation Offloading
|
||||
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
|
||||
|
||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
|
||||
### Enabling Layer Offloading
|
||||
|
||||
```yaml
|
||||
layer_offloading: true
|
||||
```
|
||||
|
||||
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
||||
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
||||
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
||||
trainable adapter weights stay on GPU permanently.
|
||||
|
||||
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
||||
|
||||
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
||||
prefetched asynchronously on a separate CUDA stream for overlap.
|
||||
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
||||
previous layer is prefetched.
|
||||
|
||||
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
||||
|
||||
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
||||
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
||||
that is kept on GPU at any given time.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- CUDA GPU (CPU-only training is not supported for this feature)
|
||||
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
||||
- Best combined with LoRA/QLoRA where most parameters are frozen
|
||||
|
||||
@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Layer Offloading
|
||||
|
||||
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
|
||||
|
||||
- **Config:** `layer_offloading: true`
|
||||
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
@@ -6,9 +6,6 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
Note: Training this model requires weights in BF16 which we will link to later.
|
||||
Users interested in training can convert / descale the existing FP8 weights.
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
|
||||
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,84 @@
|
||||
base_model: Qwen/Qwen3.5-122B-A10B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
#lora_target_parameters:
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
@@ -52,7 +56,6 @@ learning_rate: 0.0002
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
# Uncomment below to also target the linear attention projections.
|
||||
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||
# - linear_attn.in_proj_qkv
|
||||
# - linear_attn.in_proj_z
|
||||
# - linear_attn.out_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,9 +1,7 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
|
||||
# the text-only path. For multimodal (image+text) fine-tuning, add image
|
||||
# columns to your dataset following axolotl's multimodal dataset format.
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
base_model: Qwen/Qwen3.5-35B-A3B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
#lora_target_parameters:
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
|
||||
@@ -26,8 +26,6 @@ lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
# Targets the language model attention and MLP layers.
|
||||
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
|
||||
# the same transformer stack, so standard attention targets work for both modalities.
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
|
||||
@@ -2,20 +2,6 @@
|
||||
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||
|
||||
Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type | Peak VRAM |
|
||||
|---|---|---|---|
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
|
||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
@@ -23,43 +9,69 @@ Available configs:
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
4. Pick any config from the table below and run:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3.5/<config>.yaml
|
||||
```
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type | Peak VRAM |
|
||||
|---|---|---|---|
|
||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
|
||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
|
||||
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
|
||||
### Gated DeltaNet Linear Attention
|
||||
|
||||
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
|
||||
|
||||
```yaml
|
||||
lora_target_modules:
|
||||
# ... standard projections ...
|
||||
- linear_attn.in_proj_qkv
|
||||
- linear_attn.in_proj_z
|
||||
- linear_attn.out_proj
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
4. Run a finetuning example:
|
||||
### Routed Experts (MoE)
|
||||
|
||||
```bash
|
||||
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
|
||||
axolotl train examples/qwen3.5/27b-qlora.yaml
|
||||
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
|
||||
|
||||
# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
|
||||
axolotl train examples/qwen3.5/27b-fft.yaml
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
# MoE 35B-A3B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
|
||||
### Shared Experts (MoE)
|
||||
|
||||
# MoE 122B-A10B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
|
||||
|
||||
# 9B vision+text (LoRA, multimodal dataset)
|
||||
axolotl train examples/qwen3.5/9b-lora-vision.yaml
|
||||
|
||||
# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
|
||||
axolotl train examples/qwen3.5/9b-fft-vision.yaml
|
||||
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
|
||||
|
||||
```yaml
|
||||
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
|
||||
- For inference hyp, please see the respective model card details.
|
||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
|
||||
@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-m 'not slow'"
|
||||
markers = [
|
||||
"slow: marks tests as slow",
|
||||
]
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
|
||||
13
setup.py
13
setup.py
@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 9):
|
||||
if (major, minor) >= (2, 10):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.5.0",
|
||||
"fbgemm-gpu-genai==1.5.0",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
||||
elif (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpcore
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except HTTPError:
|
||||
except (HTTPError, httpcore.ConnectError):
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
|
||||
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adamw":
|
||||
from flashoptim import FlashAdamW
|
||||
|
||||
optimizer_cls = FlashAdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adam":
|
||||
from flashoptim import FlashAdam
|
||||
|
||||
optimizer_cls = FlashAdam
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_sgd":
|
||||
from flashoptim import FlashSGD
|
||||
|
||||
optimizer_cls = FlashSGD
|
||||
elif self.cfg.optimizer == "flash_sgdw":
|
||||
from flashoptim import FlashSGDW
|
||||
|
||||
optimizer_cls = FlashSGDW
|
||||
elif self.cfg.optimizer == "flash_lion":
|
||||
from flashoptim import FlashLion
|
||||
|
||||
optimizer_cls = FlashLion
|
||||
if "betas" in adam_kwargs:
|
||||
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.layer_offloading:
|
||||
training_args_kwargs["layer_offloading"] = True
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
|
||||
@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and self.peft_config
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
||||
):
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
|
||||
@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||
from trl.experimental.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
LayerOffloadingMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -67,6 +67,7 @@ class AxolotlTrainer(
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
LayerOffloadingMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
|
||||
1
src/axolotl/core/trainers/constants.py
Normal file
1
src/axolotl/core/trainers/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
@@ -2,7 +2,8 @@
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
rpo_alpha: Optional[float] = field(default=None)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .layer_offloading import LayerOffloadingMixin
|
||||
from .distributed_parallel import DistributedParallelMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
|
||||
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Trainer mixin for layer-wise parameter offloading to CPU.
|
||||
|
||||
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
|
||||
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
|
||||
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
|
||||
|
||||
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
|
||||
transfer stream), post-hook offloads layer N-1's frozen params.
|
||||
Backward: same in reverse order.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
|
||||
"""Recursively search the model for the decoder layer ModuleList.
|
||||
|
||||
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
|
||||
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
|
||||
where layers are at model.language_model.layers).
|
||||
"""
|
||||
# BFS to find the first ModuleList containing decoder layers
|
||||
queue = [model]
|
||||
while queue:
|
||||
m = queue.pop(0)
|
||||
for _name, child in m.named_children():
|
||||
if isinstance(child, nn.ModuleList) and len(child) > 0:
|
||||
first_type = type(child[0]).__name__
|
||||
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
|
||||
layer_types = list({type(layer).__name__ for layer in child})
|
||||
return child, layer_types
|
||||
else:
|
||||
queue.append(child)
|
||||
|
||||
return None, []
|
||||
|
||||
|
||||
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
|
||||
"""Get all non-trainable parameters in a layer."""
|
||||
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
|
||||
|
||||
|
||||
class LayerOffloadManager:
|
||||
"""Manages offloading frozen decoder layer params to CPU and streaming
|
||||
them back during forward/backward with CUDA stream overlap.
|
||||
|
||||
Only frozen (requires_grad=False) parameters are offloaded.
|
||||
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
num_prefetch: int = 1,
|
||||
):
|
||||
self.model = model
|
||||
self.num_prefetch = num_prefetch
|
||||
self._hooks: list = []
|
||||
self._device = None
|
||||
|
||||
# Find decoder layers
|
||||
self.layers, layer_types = _find_decoder_layers(model)
|
||||
if self.layers is None:
|
||||
LOG.warning(
|
||||
"LayerOffloadManager: no decoder layers found, offloading disabled"
|
||||
)
|
||||
self.enabled = False
|
||||
return
|
||||
|
||||
self.enabled = True
|
||||
self.n_layers = len(self.layers)
|
||||
LOG.info(
|
||||
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
|
||||
)
|
||||
|
||||
# Determine GPU device
|
||||
for p in model.parameters():
|
||||
if p.device.type == "cuda":
|
||||
self._device = p.device
|
||||
break
|
||||
if self._device is None:
|
||||
LOG.warning("LayerOffloadManager: no CUDA parameters found")
|
||||
self.enabled = False
|
||||
return
|
||||
|
||||
# Transfer stream for async prefetch
|
||||
self._transfer_stream = torch.cuda.Stream(device=self._device)
|
||||
|
||||
# Track which layers have their frozen params on GPU
|
||||
self._on_gpu: set[int] = set(range(self.n_layers))
|
||||
|
||||
# Cache: frozen param references per layer (list of (name, param) tuples)
|
||||
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
|
||||
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
|
||||
]
|
||||
|
||||
# CPU storage: pinned tensors for each layer's frozen params
|
||||
# Populated on first offload
|
||||
self._cpu_data: list[dict[str, torch.Tensor]] = [
|
||||
{} for _ in range(self.n_layers)
|
||||
]
|
||||
|
||||
# Offload all layers upfront
|
||||
self._offload_all()
|
||||
|
||||
# Release cached memory blocks back to the driver
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _offload_all(self):
|
||||
"""Move all frozen params in all decoder layers to CPU."""
|
||||
mem_before = torch.cuda.memory_allocated(self._device)
|
||||
for i in range(self.n_layers):
|
||||
self._offload_layer(i)
|
||||
mem_after = torch.cuda.memory_allocated(self._device)
|
||||
freed = (mem_before - mem_after) / 1e6
|
||||
LOG.info(
|
||||
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
|
||||
f"freed {freed:.0f} MB GPU memory"
|
||||
)
|
||||
|
||||
def _offload_layer(self, idx: int):
|
||||
"""Move frozen params of layer idx to CPU pinned memory."""
|
||||
if idx not in self._on_gpu:
|
||||
return
|
||||
for name, param in self._frozen_params[idx]:
|
||||
if param.device.type != "cuda":
|
||||
continue
|
||||
# Allocate pinned CPU tensor on first offload
|
||||
if name not in self._cpu_data[idx]:
|
||||
self._cpu_data[idx][name] = torch.empty_like(
|
||||
param.data, device="cpu", pin_memory=True
|
||||
)
|
||||
cpu_buf = self._cpu_data[idx][name]
|
||||
# Async copy GPU -> CPU (on transfer stream for overlap)
|
||||
cpu_buf.copy_(param.data, non_blocking=True)
|
||||
# Point parameter at a dummy CPU tensor to free GPU memory
|
||||
param.data = cpu_buf
|
||||
self._on_gpu.discard(idx)
|
||||
|
||||
def _load_layer(self, idx: int, stream=None):
|
||||
"""Move frozen params of layer idx back to GPU."""
|
||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||
return
|
||||
ctx = (
|
||||
torch.cuda.stream(stream)
|
||||
if stream is not None
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with ctx:
|
||||
for _name, param in self._frozen_params[idx]:
|
||||
if param.device.type == "cuda":
|
||||
continue
|
||||
gpu_data = param.data.to(self._device, non_blocking=True)
|
||||
param.data = gpu_data
|
||||
self._on_gpu.add(idx)
|
||||
|
||||
def _prefetch_layer(self, idx: int):
|
||||
"""Async prefetch layer idx on the transfer stream."""
|
||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||
return
|
||||
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
|
||||
self._load_layer(idx, stream=self._transfer_stream)
|
||||
|
||||
def _wait_transfer(self):
|
||||
"""Make default stream wait for any in-flight transfers."""
|
||||
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
|
||||
|
||||
def setup_hooks(self):
|
||||
"""Register forward and backward hooks on each decoder layer."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
for idx in range(self.n_layers):
|
||||
layer = self.layers[idx]
|
||||
|
||||
def make_pre_fwd(i):
|
||||
def hook(module, args):
|
||||
# Ensure this layer is on GPU
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
self._wait_transfer()
|
||||
# Prefetch next layer(s)
|
||||
for offset in range(1, self.num_prefetch + 1):
|
||||
self._prefetch_layer(i + offset)
|
||||
|
||||
return hook
|
||||
|
||||
def make_post_fwd(i):
|
||||
def hook(module, args, output):
|
||||
# Offload previous layer (no longer needed in forward)
|
||||
if i > 0:
|
||||
self._offload_layer(i - 1)
|
||||
# Offload last layer after forward
|
||||
if i == self.n_layers - 1:
|
||||
self._offload_layer(i)
|
||||
|
||||
return hook
|
||||
|
||||
def make_pre_bwd(i):
|
||||
def hook(module, grad_output):
|
||||
# Load this layer for backward
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
self._wait_transfer()
|
||||
# Prefetch previous layer(s)
|
||||
for offset in range(1, self.num_prefetch + 1):
|
||||
self._prefetch_layer(i - offset)
|
||||
|
||||
return hook
|
||||
|
||||
def make_post_bwd(i):
|
||||
def hook(module, grad_input, grad_output):
|
||||
# Offload the layer above
|
||||
if i < self.n_layers - 1:
|
||||
self._offload_layer(i + 1)
|
||||
# Offload first layer after backward
|
||||
if i == 0:
|
||||
self._offload_layer(i)
|
||||
|
||||
return hook
|
||||
|
||||
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
|
||||
h2 = layer.register_forward_hook(make_post_fwd(idx))
|
||||
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
|
||||
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
|
||||
self._hooks.extend([h1, h2, h3, h4])
|
||||
|
||||
def remove_hooks(self):
|
||||
"""Remove all hooks and restore layers to GPU."""
|
||||
for h in self._hooks:
|
||||
h.remove()
|
||||
self._hooks.clear()
|
||||
if self.enabled:
|
||||
for i in range(self.n_layers):
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
|
||||
def pre_step(self):
|
||||
"""Called before each training step — ensure layers start offloaded."""
|
||||
if not self.enabled:
|
||||
return
|
||||
for i in list(self._on_gpu):
|
||||
self._offload_layer(i)
|
||||
# Prefetch layer 0 for forward
|
||||
self._prefetch_layer(0)
|
||||
|
||||
def post_step(self):
|
||||
"""Called after each training step — ensure layers are offloaded."""
|
||||
if not self.enabled:
|
||||
return
|
||||
for i in list(self._on_gpu):
|
||||
self._offload_layer(i)
|
||||
# Prefetch layer 0 for next step
|
||||
self._prefetch_layer(0)
|
||||
|
||||
|
||||
class _LayerOffloadContext:
|
||||
"""Context manager wrapping pre_step / post_step around a training step."""
|
||||
|
||||
def __init__(self, manager: LayerOffloadManager):
|
||||
self.manager = manager
|
||||
|
||||
def __enter__(self):
|
||||
self.manager.pre_step()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.manager.post_step()
|
||||
|
||||
|
||||
class LayerOffloadingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin class for layer-wise parameter offloading to CPU.
|
||||
|
||||
Offloads frozen decoder layer params to CPU at init, then streams them
|
||||
on/off GPU one layer at a time during each training step.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if getattr(self.args, "layer_offloading", False):
|
||||
LOG.info("Layer parameter offloading enabled")
|
||||
self._layer_offload_manager = LayerOffloadManager(
|
||||
model=self.model,
|
||||
num_prefetch=1,
|
||||
)
|
||||
self._layer_offload_manager.setup_hooks()
|
||||
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
|
||||
else:
|
||||
self._layer_offload_manager = None
|
||||
self._layer_offload_ctx = contextlib.nullcontext()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
with self._layer_offload_ctx:
|
||||
return super().training_step(*args, **kwargs)
|
||||
@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
|
||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||
)
|
||||
|
||||
layer_offloading: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
50
src/axolotl/integrations/aux_free_router/README.md
Normal file
50
src/axolotl/integrations/aux_free_router/README.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Aux-Loss-Free MoE Router Plugin
|
||||
|
||||
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
|
||||
|
||||
Summary
|
||||
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
|
||||
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
|
||||
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
|
||||
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
|
||||
|
||||
Enable
|
||||
- Add the plugin to your YAML, then set the aux-free toggle:
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
|
||||
|
||||
moe_balance_type: noaux_tc
|
||||
moe_update_rate: 0.01 # default if unset
|
||||
moe_update_momentum: 0.9 # default if unset
|
||||
moe_bias_cap: 2.0 # default if unset
|
||||
moe_afb_warmup_steps: 100 # optional
|
||||
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
||||
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
||||
|
||||
Config keys
|
||||
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
||||
- moe_update_rate: bias update rate (gamma). Default: 0.01.
|
||||
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
|
||||
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||
|
||||
Compatibility
|
||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
|
||||
|
||||
Notes
|
||||
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainer’s `logging_steps` cadence.
|
||||
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
|
||||
|
||||
Telemetry metrics
|
||||
- `moe_afb/l{idx}_load_min|mean|max`: token frequency per expert after reduction (0–1 range, sums to 1).
|
||||
- `moe_afb/l{idx}_bias_abs_max`: absolute maximum of the learned bias for the layer.
|
||||
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).
|
||||
|
||||
Usage tips
|
||||
- Increase `logging_steps` if router telemetry becomes noisy for large jobs—the plugin follows the Trainer’s logging cadence.
|
||||
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.
|
||||
9
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
9
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Aux-loss-free (AFB) MoE router integration package."""
|
||||
|
||||
from .args import AuxFreeRouterArgs
|
||||
from .plugin import AuxFreeMoEPlugin
|
||||
|
||||
__all__ = [
|
||||
"AuxFreeMoEPlugin",
|
||||
"AuxFreeRouterArgs",
|
||||
]
|
||||
393
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
393
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Architecture-specific adapters for aux-loss-free MoE routing.
|
||||
|
||||
Each adapter discovers MoE layers for a model family and patches only the
|
||||
router/gate to inject per-expert bias into expert selection while keeping
|
||||
mixture weights from unbiased logits. Expert dispatch is left untouched so
|
||||
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .core import AuxFreeShim
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerHandle:
|
||||
layer: nn.Module
|
||||
layer_idx: int
|
||||
num_experts: int
|
||||
top_k: int
|
||||
|
||||
|
||||
class BaseMoEAdapter:
|
||||
"""Base adapter that discovers MoE layers and patches their routing.
|
||||
|
||||
Concrete adapters implement discovery, attribute extraction, and
|
||||
architecture-specific router patching.
|
||||
"""
|
||||
|
||||
family: str = "generic"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||
return False
|
||||
|
||||
def find_moe_layers(
|
||||
self, model: nn.Module
|
||||
) -> Iterable[nn.Module]: # pragma: no cover
|
||||
return []
|
||||
|
||||
def get_top_k(self, moe_layer: nn.Module) -> int:
|
||||
"""Resolve top_k from the MoE layer, checking common attribute paths."""
|
||||
for attr_path in [
|
||||
("top_k",),
|
||||
("num_experts_per_tok",),
|
||||
("gate", "top_k"),
|
||||
("router", "top_k"),
|
||||
]:
|
||||
obj: object = moe_layer
|
||||
for attr in attr_path:
|
||||
obj = getattr(obj, attr, None)
|
||||
if obj is None:
|
||||
break
|
||||
if isinstance(obj, int):
|
||||
return obj
|
||||
return 2
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||
"""Resolve num_experts from the MoE layer, checking common attribute paths."""
|
||||
for attr_path in [
|
||||
("num_experts",),
|
||||
("num_local_experts",),
|
||||
("gate", "num_experts"),
|
||||
("router", "num_experts"),
|
||||
("experts", "num_experts"),
|
||||
]:
|
||||
obj: object = moe_layer
|
||||
for attr in attr_path:
|
||||
obj = getattr(obj, attr, None)
|
||||
if obj is None:
|
||||
break
|
||||
if isinstance(obj, int):
|
||||
return obj
|
||||
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||
|
||||
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||
# Best-effort: zero router aux loss coef if present
|
||||
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||
try:
|
||||
model_or_layer.router_aux_loss_coef = 0.0
|
||||
except Exception: # pragma: no cover - non-critical
|
||||
LOG.debug(
|
||||
"disable_aux_loss: failed to set router_aux_loss_coef on %s",
|
||||
type(model_or_layer).__name__,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _register_aux_buffers(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
p = next(moe_layer.parameters(), None)
|
||||
b = next(moe_layer.buffers(), None)
|
||||
device = (
|
||||
p.device
|
||||
if p is not None
|
||||
else (b.device if b is not None else torch.device("cpu"))
|
||||
)
|
||||
if not hasattr(moe_layer, "_afb_bias"):
|
||||
moe_layer.register_buffer(
|
||||
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
if not hasattr(moe_layer, "_afb_counts"):
|
||||
moe_layer.register_buffer(
|
||||
"_afb_counts", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
if not hasattr(moe_layer, "_afb_ema"):
|
||||
moe_layer.register_buffer(
|
||||
"_afb_ema", torch.zeros(handle.num_experts, device=device)
|
||||
)
|
||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
"""Attach per-layer buffers. Subclasses override to also patch routing."""
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
|
||||
def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
|
||||
"""Return True when a kernel backend (SonicMoE / ScatterMoE) has
|
||||
already replaced the block forward, meaning the routing is handled
|
||||
inside the kernel forward and we should NOT patch the router."""
|
||||
cls = type(moe_layer)
|
||||
# SonicMoE stores the original forward when it patches a class.
|
||||
if hasattr(cls, "_original_forward"):
|
||||
return True
|
||||
# ScatterMoE replaces via kernels library; check for the marker.
|
||||
if hasattr(cls, "_kernel_forward"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class MixtralAdapter(BaseMoEAdapter):
|
||||
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
|
||||
routing so that biased logits drive expert *selection* while unbiased
|
||||
softmax scores drive mixture *weights*.
|
||||
|
||||
Works with transformers v5 where experts are fused 3D tensors and
|
||||
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
|
||||
"""
|
||||
|
||||
family = "mixtral"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return (
|
||||
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||
)
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||
yield m
|
||||
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
if not self.uses_kernel_routing(moe_layer):
|
||||
self._patch_router(moe_layer)
|
||||
else:
|
||||
LOG.info(
|
||||
"AuxFreeMoE: kernel backend detected on %s; "
|
||||
"skipping router patch (kernel routing handles bias)",
|
||||
type(moe_layer).__name__,
|
||||
)
|
||||
|
||||
def _patch_router(self, moe_layer: nn.Module) -> None:
|
||||
"""Patch the TopKRouter to inject aux-free bias into expert selection."""
|
||||
gate = getattr(moe_layer, "gate", None)
|
||||
if gate is None:
|
||||
LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
|
||||
return
|
||||
if getattr(gate, "_afb_patched", False):
|
||||
return
|
||||
|
||||
# Capture reference to the MoE block for bias / counts access.
|
||||
block_ref = moe_layer
|
||||
|
||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||
router_logits = F.linear(hidden_states, self.weight)
|
||||
router_probs = F.softmax(router_logits.float(), dim=-1)
|
||||
|
||||
# Biased selection, unbiased weights
|
||||
bias = block_ref._afb_bias
|
||||
biased = router_probs + bias
|
||||
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
|
||||
router_scores = torch.gather(router_probs, 1, router_indices)
|
||||
|
||||
# Renormalize (Mixtral always normalizes; Qwen checks config)
|
||||
if getattr(self, "norm_topk_prob", True):
|
||||
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Accumulate counts for the bias-update callback
|
||||
flat_idx = router_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=self.num_experts)
|
||||
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
|
||||
|
||||
return router_probs, router_scores, router_indices
|
||||
|
||||
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||
gate._afb_patched = True
|
||||
moe_layer._afb_patched = True
|
||||
|
||||
|
||||
class Qwen3Adapter(MixtralAdapter):
|
||||
family = "qwen3_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
"qwen3_moe",
|
||||
"qwen2_moe",
|
||||
)
|
||||
|
||||
|
||||
class Qwen35MoeAdapter(MixtralAdapter):
|
||||
"""Adapter for Qwen 3.5 MoE models.
|
||||
|
||||
Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
|
||||
is handled by the block forward (untouched by router-level patching).
|
||||
"""
|
||||
|
||||
family = "qwen3_5_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
"qwen3_5_moe",
|
||||
"qwen3_5_moe_text",
|
||||
)
|
||||
|
||||
|
||||
class BailingAdapter(BaseMoEAdapter):
|
||||
family = "bailing_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
cfg = getattr(model, "config", None)
|
||||
if cfg is None:
|
||||
return False
|
||||
model_type = getattr(cfg, "model_type", "") or ""
|
||||
if model_type in ("bailing_moe", "bailing_moe_v2", "ring_moe", "ring"):
|
||||
return True
|
||||
cfg_name = cfg.__class__.__name__.lower()
|
||||
return "bailingmoev2" in cfg_name or "ring" in cfg_name
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
|
||||
yield m
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||
if hasattr(moe_layer, "num_experts"):
|
||||
return int(moe_layer.num_experts)
|
||||
cfg = getattr(moe_layer, "config", None)
|
||||
if cfg is None:
|
||||
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||
return int(cfg.num_experts)
|
||||
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_bailing_gate(moe_layer)
|
||||
|
||||
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
|
||||
gate = getattr(moe_layer, "gate", None)
|
||||
if gate is None:
|
||||
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
|
||||
return
|
||||
if getattr(gate, "_afb_patched", False):
|
||||
return
|
||||
|
||||
def afb_gate_forward(self, hidden_states: torch.Tensor):
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
logits = F.linear(flat.float(), self.weight.float())
|
||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||
bias = moe_layer._afb_bias
|
||||
biased_scores = scores_unbiased + bias
|
||||
_, topk_idx = self.group_limited_topk(biased_scores)
|
||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||
if self.top_k > 1:
|
||||
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
||||
weights = weights / denom
|
||||
weights = weights * self.routed_scaling_factor
|
||||
|
||||
flat_topk = topk_idx.reshape(-1)
|
||||
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||
|
||||
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||
gate._afb_patched = True
|
||||
|
||||
|
||||
class Llama4Adapter(BaseMoEAdapter):
|
||||
family = "llama4"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "Llama4TextMoe":
|
||||
yield m
|
||||
|
||||
def prepare(
|
||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||
) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_llama4_router(moe_layer)
|
||||
|
||||
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
|
||||
router = getattr(moe_layer, "router", None)
|
||||
if router is None:
|
||||
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
|
||||
return
|
||||
if getattr(router, "_afb_patched", False):
|
||||
return
|
||||
|
||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||
flat = (
|
||||
hidden_states
|
||||
if hidden_states.dim() == 2
|
||||
else hidden_states.view(-1, hidden_states.shape[-1])
|
||||
)
|
||||
router_logits = F.linear(flat, self.weight, self.bias)
|
||||
bias = moe_layer._afb_bias
|
||||
biased_logits = router_logits + bias
|
||||
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||
router_scores = torch.full_like(router_logits, float("-inf"))
|
||||
router_scores.scatter_(1, router_indices, unbiased_top)
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||
|
||||
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return router_scores, router_logits
|
||||
|
||||
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||
router._afb_patched = True
|
||||
|
||||
|
||||
def discover_and_prepare_layers(
|
||||
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
|
||||
) -> list[LayerHandle]:
|
||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||
|
||||
Returns a list of layer handles for later routing patching and updates.
|
||||
"""
|
||||
handles: list[LayerHandle] = []
|
||||
adapter: Optional[BaseMoEAdapter] = None
|
||||
for a in adapters:
|
||||
if a.matches(model):
|
||||
adapter = a
|
||||
break
|
||||
|
||||
if adapter is None:
|
||||
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
|
||||
return handles
|
||||
|
||||
# disable aux loss at model level if possible
|
||||
adapter.disable_aux_loss(getattr(model, "config", model))
|
||||
|
||||
idx = 0
|
||||
for layer in adapter.find_moe_layers(model):
|
||||
try:
|
||||
top_k = adapter.get_top_k(layer)
|
||||
nE = adapter.get_num_experts(layer)
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
continue
|
||||
|
||||
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||
adapter.prepare(layer, handle, shim)
|
||||
handles.append(handle)
|
||||
idx += 1
|
||||
|
||||
LOG.info(
|
||||
"AuxFreeMoE: prepared %d %s layers for aux-free routing",
|
||||
len(handles),
|
||||
adapter.family,
|
||||
)
|
||||
return handles
|
||||
71
src/axolotl/integrations/aux_free_router/args.py
Normal file
71
src/axolotl/integrations/aux_free_router/args.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Plugin args for the Aux-Loss-Free MoE router integration.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AuxFreeRouterArgs(BaseModel):
|
||||
"""
|
||||
Input args for Aux-Loss-Free MoE routing.
|
||||
"""
|
||||
|
||||
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
|
||||
"'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
|
||||
"Defaults to model's native behavior when unset."
|
||||
},
|
||||
)
|
||||
moe_update_rate: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
|
||||
"If unset, plugin default is 0.01."
|
||||
},
|
||||
)
|
||||
moe_update_momentum: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "EMA momentum for expert load smoothing (0-1). "
|
||||
"If unset, plugin default is 0.9."
|
||||
},
|
||||
)
|
||||
moe_bias_cap: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Absolute clamp for expert bias magnitude. "
|
||||
"If unset, plugin default is 2.0."
|
||||
},
|
||||
)
|
||||
moe_afb_warmup_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of initial steps to delay aux-free bias updates, "
|
||||
"allowing routing to stabilize. If unset, plugin default is 0."
|
||||
},
|
||||
)
|
||||
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Reduction group for expert load counts: 'world' (DP) or "
|
||||
"'ep' (expert-parallel group if available). Defaults to 'world' when unset."
|
||||
},
|
||||
)
|
||||
166
src/axolotl/integrations/aux_free_router/core.py
Normal file
166
src/axolotl/integrations/aux_free_router/core.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuxFreeConfig:
|
||||
rate: float = 0.01
|
||||
momentum: float = 0.9
|
||||
bias_cap: float = 2.0
|
||||
warmup_steps: int = 0
|
||||
sync_group: str = "world" # or "ep"
|
||||
|
||||
|
||||
class AuxFreeState:
|
||||
"""Holds per-layer bias and EMA load buffers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_experts: int,
|
||||
device: torch.device,
|
||||
cfg: AuxFreeConfig,
|
||||
):
|
||||
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||
self.ema_load = [
|
||||
torch.zeros(num_experts, device=device) for _ in range(num_layers)
|
||||
]
|
||||
self.cfg = cfg
|
||||
self.steps = 0
|
||||
|
||||
|
||||
class AuxFreeShim:
|
||||
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: AuxFreeState,
|
||||
ep_group: Optional[dist.ProcessGroup] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
):
|
||||
self.state = state
|
||||
self.ep_group = ep_group
|
||||
self._ep_size = ep_size
|
||||
self._ep_group_pending = (
|
||||
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||
)
|
||||
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_experts(
|
||||
self, layer_idx: int, logits: torch.Tensor, top_k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
||||
module = self._layer_modules.get(layer_idx)
|
||||
if module is not None and hasattr(module, "_afb_bias"):
|
||||
b = module._afb_bias
|
||||
else:
|
||||
b = self.state.bias[layer_idx]
|
||||
biased = logits + b # bias is a buffer
|
||||
_topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
||||
return topk_idx, weights
|
||||
|
||||
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
||||
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
||||
self._layer_modules[layer_idx] = module
|
||||
bias = module._afb_bias
|
||||
ema = module._afb_ema
|
||||
# Keep state views pointing to the same tensors to avoid drift.
|
||||
if layer_idx < len(self.state.bias):
|
||||
self.state.bias[layer_idx] = bias
|
||||
if layer_idx < len(self.state.ema_load):
|
||||
self.state.ema_load[layer_idx] = ema
|
||||
|
||||
def begin_step(self) -> None:
|
||||
"""Call once per optimizer step before per-layer updates."""
|
||||
self.state.steps += 1
|
||||
|
||||
def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
|
||||
return self._prev_bias_sign.get(layer_idx)
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||
self._maybe_init_ep_group()
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
return counts
|
||||
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
|
||||
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
|
||||
return counts
|
||||
|
||||
@torch.no_grad()
|
||||
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
|
||||
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
|
||||
cfg = self.state.cfg
|
||||
if self.state.steps <= cfg.warmup_steps:
|
||||
return
|
||||
|
||||
nE = step_counts.numel()
|
||||
if tokens_seen <= 0:
|
||||
return
|
||||
module = self._layer_modules.get(layer_idx)
|
||||
if module is not None and hasattr(module, "_afb_ema"):
|
||||
ema = module._afb_ema
|
||||
bias = module._afb_bias
|
||||
else:
|
||||
ema = self.state.ema_load[layer_idx]
|
||||
bias = self.state.bias[layer_idx]
|
||||
counts = step_counts.to(ema.device)
|
||||
freq = counts.float() / float(tokens_seen)
|
||||
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
|
||||
target = 1.0 / float(nE)
|
||||
delta = cfg.rate * (target - ema)
|
||||
# optional mean-centering to keep sum(bias) ~ 0
|
||||
delta = delta - delta.mean()
|
||||
bias.add_(delta)
|
||||
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||
self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
|
||||
|
||||
def _maybe_init_ep_group(self) -> None:
|
||||
if not self._ep_group_pending:
|
||||
return
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
return
|
||||
ep_size = self._ep_size
|
||||
if not ep_size or ep_size <= 1:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
|
||||
)
|
||||
self.ep_group = dist.group.WORLD
|
||||
self._ep_group_pending = False
|
||||
return
|
||||
world = dist.get_world_size()
|
||||
if world % ep_size != 0:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
|
||||
ep_size,
|
||||
world,
|
||||
)
|
||||
self.ep_group = dist.group.WORLD
|
||||
self._ep_group_pending = False
|
||||
return
|
||||
if ep_size == world:
|
||||
self.ep_group = dist.group.WORLD
|
||||
else:
|
||||
rank = dist.get_rank()
|
||||
group_start = (rank // ep_size) * ep_size
|
||||
ranks = tuple(range(group_start, group_start + ep_size))
|
||||
self.ep_group = dist.new_group(ranks)
|
||||
LOG.info(
|
||||
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
|
||||
ep_size,
|
||||
dist.get_world_size(),
|
||||
)
|
||||
self._ep_group_pending = False
|
||||
267
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
267
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Aux-loss-free MoE Router Plugin for Axolotl.
|
||||
|
||||
This plugin wires an aux-free gating option into compatible MoE models using
|
||||
unbiased logits for mixture weights and per-expert biases for top-k selection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .adapters import (
|
||||
BailingAdapter,
|
||||
BaseMoEAdapter,
|
||||
Llama4Adapter,
|
||||
MixtralAdapter,
|
||||
Qwen3Adapter,
|
||||
Qwen35MoeAdapter,
|
||||
discover_and_prepare_layers,
|
||||
)
|
||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
"""Post-step callback to update aux-free biases from accumulated expert counts."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shim: AuxFreeShim,
|
||||
layer_modules: list[torch.nn.Module],
|
||||
trainer: Any,
|
||||
):
|
||||
self.shim = shim
|
||||
self.layer_modules = layer_modules
|
||||
self.trainer = trainer
|
||||
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||
self._telemetry_buffer: dict[int, dict[str, float]] = {}
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||
# Iterate prepared MoE layers and apply the bias update rule.
|
||||
self.shim.begin_step()
|
||||
for layer in self.layer_modules:
|
||||
if not hasattr(layer, "_afb_counts") or not hasattr(
|
||||
layer, "_afb_layer_idx"
|
||||
):
|
||||
continue
|
||||
counts = layer._afb_counts
|
||||
if counts is None:
|
||||
continue
|
||||
counts = self.shim.all_reduce_counts(counts)
|
||||
layer_idx = getattr(layer, "_afb_layer_idx", None)
|
||||
if layer_idx is None:
|
||||
counts.zero_()
|
||||
continue
|
||||
bias = layer._afb_bias
|
||||
counts_for_update = counts.to(bias.device)
|
||||
tokens_seen = int(counts_for_update.sum().item())
|
||||
# local layer-state EMA and bias update
|
||||
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||
self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
|
||||
# reset step counts
|
||||
counts.zero_()
|
||||
|
||||
if self._should_log(args, state) and self._telemetry_buffer:
|
||||
logs: dict[str, float] = {}
|
||||
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
|
||||
prefix = f"moe_afb/l{layer_idx}_"
|
||||
for key, value in metrics.items():
|
||||
logs[f"{prefix}{key}"] = value
|
||||
if logs and hasattr(self.trainer, "log"):
|
||||
self.trainer.log(logs)
|
||||
self._telemetry_buffer.clear()
|
||||
return control
|
||||
|
||||
def _collect_telemetry(
|
||||
self,
|
||||
layer_idx: int,
|
||||
counts: torch.Tensor,
|
||||
tokens_seen: int,
|
||||
bias: torch.Tensor,
|
||||
) -> None:
|
||||
if tokens_seen <= 0:
|
||||
return
|
||||
freq = counts.float() / float(tokens_seen)
|
||||
load_min = freq.min().item()
|
||||
load_mean = freq.mean().item()
|
||||
load_max = freq.max().item()
|
||||
bias_abs_max = bias.abs().max().item()
|
||||
|
||||
prev_sign = self._prev_bias_sign.get(layer_idx)
|
||||
current_sign = torch.sign(bias.detach())
|
||||
if prev_sign is None or prev_sign.shape != current_sign.shape:
|
||||
oscillation = 0.0
|
||||
else:
|
||||
changed = (current_sign != prev_sign) & (
|
||||
(current_sign != 0) | (prev_sign != 0)
|
||||
)
|
||||
if changed.numel() == 0:
|
||||
oscillation = 0.0
|
||||
else:
|
||||
oscillation = changed.float().mean().item()
|
||||
self._prev_bias_sign[layer_idx] = current_sign.clone()
|
||||
|
||||
self._telemetry_buffer[layer_idx] = {
|
||||
"load_min": load_min,
|
||||
"load_mean": load_mean,
|
||||
"load_max": load_max,
|
||||
"bias_abs_max": bias_abs_max,
|
||||
"bias_sign_flip_frac": oscillation,
|
||||
}
|
||||
|
||||
def _should_log(self, args, state) -> bool:
|
||||
interval = getattr(args, "logging_steps", 0)
|
||||
if not interval:
|
||||
return False
|
||||
try:
|
||||
interval = max(1, int(interval))
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return interval > 0 and state.global_step % interval == 0
|
||||
|
||||
|
||||
class AuxFreeMoEPlugin(BasePlugin):
|
||||
"""Plugin that enables aux-loss-free routing when configured."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._handles: list = []
|
||||
self._shim: Optional[AuxFreeShim] = None
|
||||
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"
|
||||
|
||||
def post_model_build(self, cfg, model):
|
||||
# Enable only when explicitly requested
|
||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||
return
|
||||
|
||||
# Be conservative — skip known native aux-free families
|
||||
native_auxfree = getattr(
|
||||
getattr(model, "config", object()), "model_type", ""
|
||||
) in (
|
||||
"deepseek_v3",
|
||||
"glm4_moe",
|
||||
)
|
||||
if native_auxfree:
|
||||
LOG.info(
|
||||
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
|
||||
)
|
||||
return
|
||||
|
||||
# Build aux-free state and shim
|
||||
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
|
||||
momentum = (
|
||||
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
|
||||
)
|
||||
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||
af_cfg = AuxFreeConfig(
|
||||
rate=rate,
|
||||
momentum=momentum,
|
||||
bias_cap=bias_cap,
|
||||
warmup_steps=warmup,
|
||||
sync_group=sync_group,
|
||||
)
|
||||
|
||||
# Discover layers to count the number and experts for state sizing
|
||||
adapters: list[BaseMoEAdapter] = [
|
||||
MixtralAdapter(),
|
||||
Qwen3Adapter(),
|
||||
Qwen35MoeAdapter(),
|
||||
BailingAdapter(),
|
||||
Llama4Adapter(),
|
||||
]
|
||||
|
||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||
n_layers = 0
|
||||
n_experts = None
|
||||
for _m in model.modules():
|
||||
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||
device = next(model.parameters(), torch.tensor(0)).device
|
||||
if n_layers <= 0:
|
||||
n_layers = 1
|
||||
if n_experts is None:
|
||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||
n_experts = 2
|
||||
state = AuxFreeState(
|
||||
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
|
||||
)
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
ep_group = None
|
||||
if sync_group == "ep":
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
ep_group = self._resolve_ep_group(cfg)
|
||||
else:
|
||||
LOG.info(
|
||||
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
|
||||
)
|
||||
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
|
||||
|
||||
# Discover and prepare layers (attach per-layer buffers)
|
||||
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
||||
|
||||
LOG.info(
|
||||
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
||||
)
|
||||
|
||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
|
||||
)
|
||||
return None
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
if not ep_size or ep_size <= 1:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
|
||||
)
|
||||
return None
|
||||
world = dist.get_world_size()
|
||||
if world % ep_size != 0:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
|
||||
ep_size,
|
||||
world,
|
||||
)
|
||||
return None
|
||||
if ep_size == world:
|
||||
return dist.group.WORLD
|
||||
|
||||
rank = dist.get_rank()
|
||||
# All ranks must collectively create all EP subgroups in the same order
|
||||
# to avoid deadlocks (dist.new_group is a collective operation).
|
||||
world_size = world
|
||||
my_group = None
|
||||
for group_start in range(0, world_size, ep_size):
|
||||
ranks = tuple(range(group_start, group_start + ep_size))
|
||||
if ranks not in self._ep_group_cache:
|
||||
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
||||
if rank in ranks:
|
||||
my_group = self._ep_group_cache[ranks]
|
||||
return my_group
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||
return []
|
||||
if self._shim is None:
|
||||
return []
|
||||
# gather concrete layer modules from handles
|
||||
layers = [h.layer for h in self._handles]
|
||||
cb = MoeAuxFreeBiasUpdateCallback(
|
||||
self._shim,
|
||||
layers,
|
||||
trainer,
|
||||
)
|
||||
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||
return [cb]
|
||||
@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
|
||||
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
||||
@@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
|
||||
|
||||
cls_names = entry if isinstance(entry, list) else [entry]
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = importlib.import_module(module_path)
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ModuleNotFoundError:
|
||||
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
|
||||
if model_type.endswith("_text"):
|
||||
parent_type = model_type.removesuffix("_text")
|
||||
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
|
||||
module = importlib.import_module(module_path)
|
||||
else:
|
||||
raise
|
||||
|
||||
classes = []
|
||||
for cls_name in cls_names:
|
||||
|
||||
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_N: {32, 64}
|
||||
BLOCK_K: {32, 64, 128}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
@@ -371,7 +371,7 @@ def _scatter2scatter_lora_configs():
|
||||
configs = []
|
||||
for block_m, block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128} (token tile)
|
||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||
BLOCK_K: {32, 64, 128} (output tile)
|
||||
BLOCK_N: {32, 64} (reduction tile)
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
"""
|
||||
configs = []
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||
[32, 64, 128], # BLOCK_K (output dimension)
|
||||
[32, 64], # BLOCK_N (reduction dimension)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128, 256}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_M: {32, 64, 128} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128}
|
||||
BLOCK_N: {32, 64}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
@@ -1289,9 +1289,9 @@ def _group_bwd_lora_configs():
|
||||
"""
|
||||
configs = []
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128, 256], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[32, 64], # BLOCK_N
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
|
||||
@@ -240,7 +240,16 @@ def _softmax_topk_route(
|
||||
|
||||
top_k = base_gate.top_k
|
||||
num_experts = base_gate.num_experts
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||
|
||||
# Aux-free bias: biased selection, unbiased weights
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = routing_weights + afb_bias
|
||||
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
|
||||
routing_weights = routing_weights.gather(1, selected_experts)
|
||||
_accumulate_afb_counts(moe_block, selected_experts)
|
||||
else:
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||
|
||||
if getattr(base_gate, "norm_topk_prob", True):
|
||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
|
||||
else:
|
||||
scores_for_choice = router_probs
|
||||
|
||||
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = scores_for_choice + afb_bias
|
||||
|
||||
# Group-based selection: pick top groups, mask the rest
|
||||
n_group = getattr(moe_block, "n_group", 1)
|
||||
if n_group > 1:
|
||||
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
|
||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Optional renormalization + scaling
|
||||
if getattr(moe_block, "norm_topk_prob", True):
|
||||
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||
"""Accumulate per-expert token counts for aux-free bias updates."""
|
||||
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||
if afb_counts is None:
|
||||
return
|
||||
flat_idx = topk_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
|
||||
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shared expert helpers
|
||||
# =============================================================================
|
||||
|
||||
@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
|
||||
|
||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||
|
||||
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
|
||||
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
|
||||
as buffers. The routing functions transparently inject the bias into expert
|
||||
*selection* (biased topk) while keeping mixture *weights* from unbiased
|
||||
scores, then accumulate per-expert token counts for the post-step bias update.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -101,17 +107,25 @@ def softmax_topk_routing(
|
||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||
|
||||
# Aux-free bias: biased selection, unbiased weights
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
scores_for_choice = router_probs
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = router_probs + afb_bias
|
||||
|
||||
# Select top-k experts per token
|
||||
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
|
||||
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
|
||||
|
||||
# When aux-free bias is active, gather unbiased weights and accumulate counts
|
||||
if afb_bias is not None:
|
||||
top_values = router_probs.gather(1, top_indices)
|
||||
_accumulate_afb_counts(moe_block, top_indices)
|
||||
|
||||
# Renormalize if configured (default True for models without the attribute,
|
||||
# e.g. Mixtral/MiniMax which always normalize)
|
||||
if getattr(gate, "norm_topk_prob", True):
|
||||
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
||||
|
||||
# no-op: matches transformers which casts to softmax output dtype (float32).
|
||||
# top_values = top_values.to(router_probs.dtype)
|
||||
|
||||
# Flatten for moe_general_routing_inputs.
|
||||
# Token indices are naturally sorted ascending from the [T, K] layout:
|
||||
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
||||
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
|
||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||
|
||||
# Aux-free bias: inject before group selection / topk
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
scores_for_choice = router_probs
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = router_probs + afb_bias
|
||||
|
||||
# Group selection: pick top groups, mask the rest
|
||||
if n_group > 1:
|
||||
@@ -159,11 +177,17 @@ def softmax_group_topk_routing(
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||
)
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||
scores_for_choice = scores_for_choice.masked_fill(
|
||||
~score_mask.bool(), -float("inf")
|
||||
)
|
||||
|
||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Renormalization + scaling
|
||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||
if norm_topk_prob:
|
||||
@@ -233,6 +257,11 @@ def sigmoid_topk_routing(
|
||||
)
|
||||
scores_for_choice = router_probs + e_score_correction_bias
|
||||
|
||||
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||
if afb_bias is not None:
|
||||
scores_for_choice = scores_for_choice + afb_bias
|
||||
|
||||
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
||||
if n_group > 1:
|
||||
group_scores = (
|
||||
@@ -248,7 +277,9 @@ def sigmoid_topk_routing(
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||
)
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||
scores_for_choice = scores_for_choice.masked_fill(
|
||||
~score_mask.bool(), -float("inf")
|
||||
)
|
||||
|
||||
# Final topk from (possibly masked) scores
|
||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||
@@ -256,6 +287,10 @@ def sigmoid_topk_routing(
|
||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||
topk_weights = router_probs.gather(1, topk_indices)
|
||||
|
||||
# Accumulate counts for aux-free bias update
|
||||
if afb_bias is not None:
|
||||
_accumulate_afb_counts(moe_block, topk_indices)
|
||||
|
||||
# Optional renormalization + scaling
|
||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||
if norm_topk_prob:
|
||||
@@ -276,3 +311,21 @@ def sigmoid_topk_routing(
|
||||
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||
|
||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
|
||||
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||
"""Accumulate per-expert token counts for the aux-free bias update.
|
||||
|
||||
Called when ``moe_block._afb_bias`` is present (registered by the
|
||||
``aux_free_router`` plugin). The counts are later consumed by the
|
||||
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
||||
"""
|
||||
if hasattr(moe_block, "training") and not moe_block.training:
|
||||
return
|
||||
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||
if afb_counts is None:
|
||||
return
|
||||
num_experts = afb_counts.numel()
|
||||
flat_idx = topk_indices.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=num_experts)
|
||||
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||
|
||||
@@ -571,15 +571,6 @@ class PatchManager:
|
||||
LOG.info("Patching with xformers attention...")
|
||||
hijack_llama_attention()
|
||||
|
||||
def _patch_llama_sample_packing(self):
|
||||
"""Apply sample packing patches for LLaMA models."""
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
)
|
||||
|
||||
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
|
||||
def _patch_llama_derived_model(self):
|
||||
"""Modify all llama derived models in one block."""
|
||||
if self.cfg.is_llama_derived_model and not (
|
||||
@@ -591,8 +582,6 @@ class PatchManager:
|
||||
self._patch_llama_flash_attention()
|
||||
elif self.cfg.xformers_attention:
|
||||
self._patch_llama_xformers_attention()
|
||||
elif self.cfg.sample_packing:
|
||||
self._patch_llama_sample_packing()
|
||||
elif self.cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
|
||||
@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
LOG.warning(
|
||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||
tokenizer.eos_token,
|
||||
)
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def hijack_expand_mask():
|
||||
import transformers
|
||||
|
||||
transformers.models.llama.modeling_llama._expand_mask = _expand_mask
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
from transformers import modeling_attn_mask_utils
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
modeling_llama._prepare_4d_causal_attention_mask = (
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1, device=mask.device).to(dtype),
|
||||
torch.tensor(0, device=mask.device).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def detab_code(code: str) -> Tuple[str, str]:
|
||||
try:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
|
||||
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Synthetic dataset generator for benchmarking and testing.
|
||||
|
||||
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
||||
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
||||
weighted dataset mixes.
|
||||
|
||||
YAML configuration example:
|
||||
|
||||
datasets:
|
||||
- path: synthetic
|
||||
type: _synthetic
|
||||
length: 1000
|
||||
sequence_length: 2048
|
||||
min_input_id: 100
|
||||
max_input_id: 32000
|
||||
seed: 42
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
|
||||
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sequence_length: int = 2048,
|
||||
length: int = 1000,
|
||||
min_input_id: int = 100,
|
||||
max_input_id: int = 32000,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.sequence_length = sequence_length
|
||||
self.length = length
|
||||
self.min_input_id = min_input_id
|
||||
self.max_input_id = max_input_id
|
||||
self.seed = seed
|
||||
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
) -> Dataset:
|
||||
LOG.info(
|
||||
f"Generating synthetic dataset: {self.length} samples, "
|
||||
f"sequence_length={self.sequence_length}, "
|
||||
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
|
||||
)
|
||||
|
||||
rng = np.random.default_rng(self.seed)
|
||||
input_ids = rng.integers(
|
||||
low=self.min_input_id,
|
||||
high=self.max_input_id,
|
||||
size=(self.length, self.sequence_length),
|
||||
).tolist()
|
||||
|
||||
attention_mask = [[1] * self.sequence_length] * self.length
|
||||
# labels == input_ids means we train on all tokens
|
||||
labels = [row[:] for row in input_ids]
|
||||
|
||||
return Dataset.from_dict(
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
|
||||
length = ds_cfg.get("length", 1000)
|
||||
min_input_id = ds_cfg.get("min_input_id", 100)
|
||||
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
|
||||
seed = ds_cfg.get("seed", None)
|
||||
|
||||
return SyntheticDatasetStrategy(
|
||||
sequence_length=sequence_length,
|
||||
length=length,
|
||||
min_input_id=min_input_id,
|
||||
max_input_id=max_input_id,
|
||||
seed=seed,
|
||||
)
|
||||
@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
|
||||
|
||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||
model, peft_config = model_loader.load()
|
||||
if model.generation_config is not None:
|
||||
if getattr(model, "generation_config", None) is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_properties = model.config.to_dict()
|
||||
|
||||
@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
|
||||
if (
|
||||
isinstance(mod, FakeQuantizedLinear)
|
||||
and mod.activation_fake_quantizer is not None
|
||||
and hasattr(mod.activation_fake_quantizer, "enabled")
|
||||
):
|
||||
mod.activation_fake_quantizer.enabled = enable
|
||||
mod.weight_fake_quantizer.enabled = enable
|
||||
if hasattr(mod.weight_fake_quantizer, "enabled"):
|
||||
mod.weight_fake_quantizer.enabled = enable
|
||||
|
||||
|
||||
class QATCallback(TrainerCallback):
|
||||
|
||||
@@ -12,12 +12,11 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
|
||||
|
||||
class TokensPerSecondCallback(TrainerCallback):
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||
from axolotl.utils.schemas.datasets import (
|
||||
DPODataset,
|
||||
KTODataset,
|
||||
SFTDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -294,6 +299,7 @@ def validate_config(
|
||||
AxolotlInputConfig = AxolotlInputConfigBase
|
||||
|
||||
if cfg.plugins:
|
||||
prepare_plugins(cfg)
|
||||
(
|
||||
AxolotlConfigWCapabilities,
|
||||
AxolotlInputConfig,
|
||||
@@ -308,6 +314,14 @@ def validate_config(
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
elif (
|
||||
ds_cfg.get("type")
|
||||
if isinstance(ds_cfg, dict)
|
||||
else getattr(ds_cfg, "type", None)
|
||||
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
||||
cfg["datasets"][idx] = SyntheticDataset(
|
||||
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
||||
)
|
||||
elif not isinstance(ds_cfg, SFTDataset):
|
||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||
|
||||
|
||||
@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
|
||||
streaming: bool = False,
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Load and process a single dataset based on the passed config."""
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
||||
if dataset_config.type == "_synthetic":
|
||||
dataset = Dataset.from_dict({"text": [""]})
|
||||
else:
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
|
||||
# Parse dataset type
|
||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||
|
||||
@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
)
|
||||
|
||||
@@ -173,6 +175,70 @@ def quantize_model(
|
||||
)
|
||||
|
||||
|
||||
def _make_qat_config(
|
||||
base_config: AOBaseConfig,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
activation_dtype: TorchAOQuantDType | None,
|
||||
group_size: int | None,
|
||||
) -> QATConfig:
|
||||
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
|
||||
group_size and other params are properly propagated (torchao's QATConfig(base_config)
|
||||
does not always map these correctly)."""
|
||||
from torchao.quantization.qat.fake_quantize_config import (
|
||||
Float8FakeQuantizeConfig,
|
||||
IntxFakeQuantizeConfig,
|
||||
)
|
||||
|
||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||
return QATConfig(
|
||||
activation_config=base_config,
|
||||
weight_config=base_config,
|
||||
)
|
||||
|
||||
# Build explicit weight config
|
||||
weight_fq_config: (
|
||||
Int4WeightFakeQuantizeConfig
|
||||
| IntxFakeQuantizeConfig
|
||||
| Float8FakeQuantizeConfig
|
||||
| None
|
||||
) = None
|
||||
if weight_dtype == TorchAOQuantDType.int4:
|
||||
gs = (
|
||||
group_size
|
||||
if group_size is not None
|
||||
else getattr(base_config, "group_size", 128)
|
||||
)
|
||||
activation_dt = None
|
||||
if activation_dtype == TorchAOQuantDType.int8:
|
||||
activation_dt = torch.bfloat16
|
||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
activation_dt = torch.float8_e4m3fn
|
||||
kwargs = {"group_size": gs}
|
||||
if activation_dt is not None:
|
||||
kwargs["activation_dtype"] = activation_dt
|
||||
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
|
||||
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||
|
||||
# Build explicit activation config
|
||||
activation_fq_config = None
|
||||
if activation_dtype == TorchAOQuantDType.int8:
|
||||
activation_fq_config = IntxFakeQuantizeConfig(
|
||||
dtype=torch.int8, granularity="per_token", is_symmetric=False
|
||||
)
|
||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||
|
||||
if weight_fq_config is not None:
|
||||
return QATConfig(
|
||||
weight_config=weight_fq_config,
|
||||
activation_config=activation_fq_config,
|
||||
)
|
||||
|
||||
# Fallback to base_config for unhandled combos
|
||||
return QATConfig(base_config)
|
||||
|
||||
|
||||
def prepare_model_for_qat(
|
||||
model,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
|
||||
activation_dtype=activation_dtype,
|
||||
group_size=group_size,
|
||||
)
|
||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||
qat_config = QATConfig(
|
||||
activation_config=base_config,
|
||||
weight_config=base_config,
|
||||
)
|
||||
else:
|
||||
qat_config = QATConfig(base_config)
|
||||
qat_config = _make_qat_config(
|
||||
base_config, weight_dtype, activation_dtype, group_size
|
||||
)
|
||||
quantize_(model, qat_config)
|
||||
if quantize_embedding:
|
||||
# activation fake quantization is not supported for embedding layers
|
||||
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
|
||||
activation_dtype=None,
|
||||
group_size=group_size,
|
||||
)
|
||||
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
|
||||
embedding_qat_config = QATConfig(
|
||||
weight_config=embedding_base_config,
|
||||
)
|
||||
else:
|
||||
embedding_qat_config = QATConfig(embedding_base_config)
|
||||
embedding_qat_config = _make_qat_config(
|
||||
embedding_base_config, weight_dtype, None, group_size
|
||||
)
|
||||
quantize_(
|
||||
model,
|
||||
embedding_qat_config,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
|
||||
return [lr * scale for lr in original]
|
||||
|
||||
return original * scale
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""Return serializable state, saving inner_schedule as its own state_dict."""
|
||||
state = {
|
||||
key: value
|
||||
for key, value in self.__dict__.items()
|
||||
if key not in ("optimizer", "inner_schedule")
|
||||
}
|
||||
state["inner_schedule_state"] = self.inner_schedule.state_dict()
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||
"""Restore state, including inner_schedule."""
|
||||
inner_state = state_dict.pop("inner_schedule_state")
|
||||
self.__dict__.update(state_dict)
|
||||
self.inner_schedule.load_state_dict(inner_state)
|
||||
|
||||
@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
|
||||
PretrainingDataset,
|
||||
SFTDataset,
|
||||
StepwiseSupervisedDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
|
||||
|
||||
datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
|
||||
|
||||
test_datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
@@ -433,6 +446,12 @@ class AxolotlInputConfig(
|
||||
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
||||
},
|
||||
)
|
||||
layer_offloading: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||
},
|
||||
)
|
||||
|
||||
unfrozen_parameters: list[str] | None = Field(
|
||||
default=None,
|
||||
@@ -817,6 +836,12 @@ class AxolotlInputConfig(
|
||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||
},
|
||||
)
|
||||
expert_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
||||
},
|
||||
)
|
||||
special_tokens: SpecialTokensConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
|
||||
revision: str | None = None
|
||||
|
||||
|
||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||
class SyntheticDataset(BaseModel):
|
||||
"""Synthetic dataset configuration for benchmarking and testing.
|
||||
|
||||
Generates datasets with configurable sequence length, dataset size, and token ID
|
||||
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
|
||||
validating weighted dataset mixes.
|
||||
"""
|
||||
|
||||
path: Literal["synthetic"] = "synthetic"
|
||||
type: Literal["_synthetic"] = "_synthetic"
|
||||
length: int = Field(
|
||||
default=1000,
|
||||
json_schema_extra={"description": "Number of rows to generate"},
|
||||
)
|
||||
sequence_length: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Sequence length per row (defaults to sequence_len from config)"
|
||||
},
|
||||
)
|
||||
min_input_id: int = Field(
|
||||
default=100,
|
||||
json_schema_extra={"description": "Minimum token ID for generation"},
|
||||
)
|
||||
max_input_id: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
|
||||
},
|
||||
)
|
||||
seed: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Random seed for reproducibility"},
|
||||
)
|
||||
|
||||
|
||||
DatasetConfig = (
|
||||
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
||||
)
|
||||
|
||||
@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
dion = "dion"
|
||||
flash_adamw = "flash_adamw"
|
||||
flash_adam = "flash_adam"
|
||||
flash_sgd = "flash_sgd"
|
||||
flash_sgdw = "flash_sgdw"
|
||||
flash_lion = "flash_lion"
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
|
||||
@@ -790,6 +790,14 @@ class OptimizationValidationMixin:
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _resolve_fsdp_version(data):
|
||||
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
return fsdp_version
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
@@ -799,15 +807,32 @@ class OptimizationValidationMixin:
|
||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_flashoptim_deepspeed_fsdp(cls, data):
|
||||
optimizer = data.get("optimizer") or ""
|
||||
if str(optimizer).startswith("flash_"):
|
||||
if data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is incompatible with DeepSpeed. "
|
||||
"Flash optimizers only support DDP and FSDP2."
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is only compatible with FSDP2. "
|
||||
"Set fsdp_version: 2 to use flash optimizers with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
@@ -1361,6 +1386,14 @@ class ComplexValidationMixin:
|
||||
self.tensor_parallel_size = 1
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_expert_parallel_size(self):
|
||||
if not getattr(self, "expert_parallel_size", None):
|
||||
self.expert_parallel_size = 1
|
||||
elif self.expert_parallel_size < 1:
|
||||
raise ValueError("expert_parallel_size must be >= 1")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_context_parallel_size(self):
|
||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||
|
||||
@@ -15,6 +15,8 @@ import datasets
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
import transformers.utils as _transformers_utils
|
||||
import transformers.utils.import_utils as _import_utils
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
@@ -29,6 +31,26 @@ from tests.hf_offline_utils import (
|
||||
|
||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
# Shim for deepseek v3
|
||||
if not hasattr(_import_utils, "is_torch_fx_available"):
|
||||
|
||||
def _is_torch_fx_available():
|
||||
try:
|
||||
import torch.fx # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
_import_utils.is_torch_fx_available = _is_torch_fx_available
|
||||
|
||||
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
|
||||
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
|
||||
|
||||
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
|
||||
_is_flash_attn_gte("2.10")
|
||||
)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
def decorator(func):
|
||||
|
||||
@@ -20,6 +20,7 @@ Test strategy:
|
||||
- Tolerances account for tf32 accumulation in Triton kernels
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -34,6 +35,21 @@ pytestmark = pytest.mark.skipif(
|
||||
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
||||
|
||||
|
||||
def skip_on_out_of_resources(func):
|
||||
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
if "OutOfResources" in type(exc).__name__:
|
||||
pytest.skip(f"GPU shared memory too small: {exc}")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helpers
|
||||
# =============================================================================
|
||||
@@ -209,6 +225,7 @@ def make_test_data(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestForwardPass:
|
||||
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
||||
|
||||
@@ -288,6 +305,7 @@ class TestForwardPass:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestForwardGrouped:
|
||||
"""Test forward pass with grouped_in/grouped_out configurations."""
|
||||
|
||||
@@ -377,6 +395,7 @@ class TestForwardGrouped:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLoRAGradients:
|
||||
"""Test backward LoRA gradient computation (dA, dB)."""
|
||||
|
||||
@@ -452,6 +471,7 @@ class TestLoRAGradients:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestAutograd:
|
||||
"""Test full autograd integration through ScatterMoELoRA."""
|
||||
|
||||
@@ -620,6 +640,7 @@ class TestAutograd:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestBaseEquivalence:
|
||||
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
||||
|
||||
@@ -692,6 +713,7 @@ class TestBaseEquivalence:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLoRAAdditivity:
|
||||
"""Test that the LoRA component is correctly additive."""
|
||||
|
||||
@@ -749,6 +771,7 @@ class TestLoRAAdditivity:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestParallelExpertsModule:
|
||||
"""Test the ParallelExperts module with LoRA."""
|
||||
|
||||
@@ -816,6 +839,7 @@ class TestParallelExpertsModule:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestEdgeCases:
|
||||
"""Edge cases and boundary conditions."""
|
||||
|
||||
@@ -913,6 +937,7 @@ class TestEdgeCases:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestFusedDX:
|
||||
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
||||
|
||||
@@ -980,6 +1005,7 @@ class TestFusedDX:
|
||||
def test_basic(self):
|
||||
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_large(self):
|
||||
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||
|
||||
@@ -1122,6 +1148,7 @@ class TestFusedDX:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestFusedGatherBackward:
|
||||
"""Test fused gather + backward dA/dB kernel."""
|
||||
|
||||
@@ -1174,6 +1201,7 @@ class TestFusedGatherBackward:
|
||||
def test_basic(self):
|
||||
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_large(self):
|
||||
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||
|
||||
@@ -1183,6 +1211,7 @@ class TestFusedGatherBackward:
|
||||
def test_k1(self):
|
||||
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_many_experts(self):
|
||||
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
||||
|
||||
@@ -1269,6 +1298,8 @@ class TestFusedGatherBackward:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="flaky", strict=False)
|
||||
class TestTokenRounding:
|
||||
"""Test token rounding utility and its integration with backward kernels."""
|
||||
|
||||
@@ -1315,6 +1346,7 @@ class TestTokenRounding:
|
||||
)
|
||||
prev = padded_offsets[e].item()
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_round_with_fused_gather(self):
|
||||
"""Token rounding + fused gather gives same result as plain fused gather."""
|
||||
from importlib import import_module
|
||||
@@ -1414,6 +1446,7 @@ class TestTokenRounding:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestCombinedOptimizations:
|
||||
"""Test all optimizations together."""
|
||||
|
||||
@@ -1583,6 +1616,7 @@ def _make_mock_sigmoid_moe_block(
|
||||
return moe_block, T, H, FF, E, K
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFScatterMoESigmoidRouting:
|
||||
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
||||
|
||||
@@ -1724,6 +1758,7 @@ class TestHFScatterMoESigmoidRouting:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFScatterMoESigmoidWithSharedExperts:
|
||||
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
||||
|
||||
|
||||
@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
|
||||
def _get_repo_path():
|
||||
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
||||
return (
|
||||
Path(__file__).parent.parent.parent
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "integrations"
|
||||
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
|
||||
|
||||
# Kernelize
|
||||
repo_path = (
|
||||
Path(__file__).parent.parent.parent
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "integrations"
|
||||
|
||||
@@ -9,8 +9,8 @@ import subprocess
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from axolotl.utils.dict import DictDefault
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
|
||||
)
|
||||
class TestDeepseekV3:
|
||||
"""
|
||||
Test case for DeepseekV3 models
|
||||
|
||||
@@ -262,6 +262,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
75
tests/e2e/test_llama4_moe_aux_free.py
Normal file
75
tests/e2e/test_llama4_moe_aux_free.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
E2E smoke test for Llama 4 aux-loss-free routing via plugin
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestLlama4MoeAuxFree(unittest.TestCase):
|
||||
"""Smoke test to ensure aux-free plugin patches Llama 4 MoE correctly."""
|
||||
|
||||
@with_temp_dir
|
||||
def test_llama4_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "yujiepan/llama-4-tiny-random",
|
||||
"tokenizer_config": "yujiepan/llama-4-tiny-random",
|
||||
"trust_remote_code": False,
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||
assert patched is not None, (
|
||||
"Llama 4 MoE layer was not patched by aux-free plugin"
|
||||
)
|
||||
assert patched._afb_bias.ndim == 1
|
||||
assert patched._afb_counts.ndim == 1
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
79
tests/e2e/test_moe_aux_free.py
Normal file
79
tests/e2e/test_moe_aux_free.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestMoeAuxFree(unittest.TestCase):
|
||||
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
|
||||
|
||||
@with_temp_dir
|
||||
def test_mixtral_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
# Aux-free plugin and toggles
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
# Inspect model modules for a patched MoE layer
|
||||
patched = None
|
||||
for m in model.modules():
|
||||
if hasattr(m, "_afb_patched") and m._afb_patched is True:
|
||||
patched = m
|
||||
break
|
||||
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
|
||||
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
|
||||
# ensure counts buffer got reset by callback (best effort)
|
||||
assert torch.all(patched._afb_counts == 0)
|
||||
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
91
tests/e2e/test_moe_aux_parity.py
Normal file
91
tests/e2e/test_moe_aux_parity.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
|
||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
|
||||
def _last_logged_loss(trainer) -> float | None:
|
||||
# Scan log_history for the most recent entry with a 'loss' key
|
||||
for entry in reversed(trainer.state.log_history):
|
||||
if isinstance(entry, dict) and "loss" in entry:
|
||||
return float(entry["loss"])
|
||||
return None
|
||||
|
||||
|
||||
class TestMoeAuxParity(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
|
||||
base_cfg = {
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 8,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
"seed": 42,
|
||||
"logging_steps": 1,
|
||||
}
|
||||
|
||||
# Baseline: aux-loss (gshard)
|
||||
cfg0 = DictDefault(dict(base_cfg))
|
||||
cfg0.output_dir = f"{temp_dir}/baseline"
|
||||
cfg0 = validate_config(cfg0)
|
||||
normalize_config(cfg0)
|
||||
# baseline uses default aux-loss routing; no plugin registration
|
||||
dataset_meta0 = load_datasets(cfg=cfg0)
|
||||
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
|
||||
loss0 = _last_logged_loss(trainer0)
|
||||
assert loss0 is not None
|
||||
|
||||
# Release baseline resources before starting aux-free run
|
||||
del model0, trainer0, dataset_meta0
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Aux-free: plugin + noaux_tc
|
||||
cfg1 = DictDefault(dict(base_cfg))
|
||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||
cfg1.plugins = [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
]
|
||||
cfg1.moe_balance_type = "noaux_tc"
|
||||
cfg1.moe_update_rate = 0.01
|
||||
cfg1.moe_update_momentum = 0.9
|
||||
cfg1.moe_bias_cap = 2.0
|
||||
prepare_plugins(cfg1)
|
||||
cfg1 = validate_config(cfg1)
|
||||
normalize_config(cfg1)
|
||||
dataset_meta1 = load_datasets(cfg=cfg1)
|
||||
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
|
||||
loss1 = _last_logged_loss(trainer1)
|
||||
assert loss1 is not None
|
||||
|
||||
# Assert aux-free loss is within 10% of aux-loss baseline
|
||||
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
|
||||
@@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -282,3 +284,60 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"optimizer_name,expected_class,learning_rate",
|
||||
[
|
||||
("flash_adamw", "FlashAdamW", 0.00001),
|
||||
("flash_adam", "FlashAdam", 0.00001),
|
||||
("flash_sgd", "FlashSGD", 0.01),
|
||||
("flash_sgdw", "FlashSGDW", 0.01),
|
||||
("flash_lion", "FlashLion", 0.0001),
|
||||
],
|
||||
)
|
||||
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
|
||||
pytest.importorskip("flashoptim")
|
||||
temp_dir = str(tmp_path)
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": learning_rate,
|
||||
"optimizer": optimizer_name,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class
|
||||
|
||||
@@ -35,6 +35,14 @@ from tests.e2e.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_fake_quant_config_dtype(config):
|
||||
"""Get the weight dtype from a fake quantize config, handling different config types."""
|
||||
if hasattr(config, "dtype"):
|
||||
return config.dtype
|
||||
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
|
||||
return torch.int4
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def model():
|
||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -157,6 +165,18 @@ class TestQuantization:
|
||||
expected_exception,
|
||||
expected_tensor_class,
|
||||
):
|
||||
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
|
||||
# (see https://pypi.org/project/mslk-cuda/)
|
||||
if expected_tensor_class is Int4Tensor and activation_dtype is None:
|
||||
try:
|
||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
|
||||
int4_row_quantize_zp,
|
||||
)
|
||||
|
||||
if int4_row_quantize_zp is None:
|
||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||
except ImportError:
|
||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception):
|
||||
quantize_model(
|
||||
@@ -252,28 +272,24 @@ class TestQuantization:
|
||||
if quantize_embedding:
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
||||
== weight_dtype.value
|
||||
)
|
||||
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
|
||||
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
|
||||
if group_size:
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||
== group_size
|
||||
)
|
||||
assert embed_config.group_size == group_size
|
||||
|
||||
for child in list(model.children()):
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
assert isinstance(child, FakeQuantizedLinear)
|
||||
assert hasattr(child, "weight_fake_quantizer")
|
||||
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
||||
w_config = child.weight_fake_quantizer.config
|
||||
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
|
||||
if group_size:
|
||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||
assert w_config.group_size == group_size
|
||||
if activation_dtype:
|
||||
assert hasattr(child, "activation_fake_quantizer")
|
||||
a_config = child.activation_fake_quantizer.config
|
||||
assert (
|
||||
child.activation_fake_quantizer.config.dtype
|
||||
== activation_dtype.value
|
||||
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
|
||||
)
|
||||
else:
|
||||
assert child.activation_fake_quantizer is None
|
||||
@@ -374,9 +390,16 @@ class TestQuantizationCallback:
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
# Only test enable/disable toggling if the fake quantizer supports it
|
||||
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
|
||||
supports_toggle = hasattr(
|
||||
model.model.embed_tokens.weight_fake_quantizer, "enabled"
|
||||
)
|
||||
if supports_toggle:
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
|
||||
@@ -388,9 +411,10 @@ class TestQuantizationCallback:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been disabled
|
||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||
if supports_toggle:
|
||||
# quantization should have been disabled
|
||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
trainer_state.global_step = 100
|
||||
qat_callback.on_step_begin(
|
||||
@@ -400,9 +424,10 @@ class TestQuantizationCallback:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been enabled
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if supports_toggle:
|
||||
# quantization should have been enabled
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@require_torch_2_8_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||
@@ -424,9 +449,10 @@ class TestQuantizationCallback:
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
# simulate first training step
|
||||
@@ -438,5 +464,6 @@ class TestQuantizationCallback:
|
||||
)
|
||||
|
||||
# quantization should be enabled from the get-go
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestQwen3MoeAuxFree(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
# Aux-free plugin and toggles
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
# check that at least one sparse MoE block has been patched
|
||||
found = False
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
|
||||
m, "_afb_patched"
|
||||
):
|
||||
assert m._afb_patched is True
|
||||
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||
found = True
|
||||
break
|
||||
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
|
||||
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
74
tests/e2e/test_ring_moe_aux_free.py
Normal file
74
tests/e2e/test_ring_moe_aux_free.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
E2E smoke test for Ring 2.0 aux-loss-free routing via plugin
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestRingMoeAuxFree(unittest.TestCase):
|
||||
"""Smoke test to ensure aux-free plugin patches Ring Mini 2.0 correctly."""
|
||||
|
||||
@with_temp_dir
|
||||
def test_ring_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "yujiepan/ring-tiny-random",
|
||||
"tokenizer_config": "yujiepan/ring-tiny-random",
|
||||
"trust_remote_code": True,
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
# Aux-free plugin config
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||
assert patched is not None, "Ring MoE layer was not patched by aux-free plugin"
|
||||
assert patched._afb_bias.ndim == 1
|
||||
assert patched._afb_counts.ndim == 1
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -12,7 +12,11 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from tbparse import SummaryReader
|
||||
|
||||
try:
|
||||
from tbparse import SummaryReader
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
SummaryReader = None
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -185,6 +189,10 @@ def check_tensorboard(
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
"""
|
||||
if SummaryReader is None:
|
||||
raise unittest.SkipTest(
|
||||
"tbparse is not installed; skipping tensorboard assertions"
|
||||
)
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
|
||||
@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.schemas.datasets import SFTDataset
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
|
||||
assert new_cfg.dataloader_num_workers == 8
|
||||
assert new_cfg.dataloader_pin_memory is True
|
||||
assert new_cfg.dataloader_prefetch_factor == 256
|
||||
|
||||
|
||||
class TestSyntheticDatasetValidation(BaseValidation):
|
||||
"""
|
||||
Tests for synthetic dataset config validation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_cfg(minimal_cfg, datasets):
|
||||
raw = dict(minimal_cfg)
|
||||
raw["datasets"] = datasets
|
||||
return DictDefault(raw)
|
||||
|
||||
def test_synthetic_dict_config_validates(self, minimal_cfg):
|
||||
"""Synthetic dataset passed as a raw dict should not raise."""
|
||||
cfg = self._make_cfg(
|
||||
minimal_cfg,
|
||||
[
|
||||
{
|
||||
"path": "synthetic",
|
||||
"type": "_synthetic",
|
||||
"length": 100,
|
||||
"sequence_length": 64,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||
|
||||
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
|
||||
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
|
||||
sft = SFTDataset(path="synthetic", type="_synthetic")
|
||||
cfg = self._make_cfg(minimal_cfg, [sft])
|
||||
|
||||
# Before the fix, this raised:
|
||||
# AttributeError: 'SFTDataset' object has no attribute 'get'
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||
|
||||
def test_non_synthetic_sft_validates(self, minimal_cfg):
|
||||
"""A regular SFT dataset should validate without being treated as synthetic."""
|
||||
cfg = self._make_cfg(
|
||||
minimal_cfg,
|
||||
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"
|
||||
|
||||
125
tests/prompt_strategies/test_synthetic.py
Normal file
125
tests/prompt_strategies/test_synthetic.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for the synthetic dataset generator."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestSyntheticDatasetStrategy(unittest.TestCase):
|
||||
def test_generates_correct_shape(self):
|
||||
strategy = SyntheticDatasetStrategy(
|
||||
sequence_length=128,
|
||||
length=50,
|
||||
min_input_id=1,
|
||||
max_input_id=1000,
|
||||
seed=42,
|
||||
)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
result = strategy.wrap_dataset(dummy)
|
||||
|
||||
assert len(result) == 50
|
||||
assert len(result[0]["input_ids"]) == 128
|
||||
assert len(result[0]["attention_mask"]) == 128
|
||||
assert len(result[0]["labels"]) == 128
|
||||
|
||||
def test_attention_mask_all_ones(self):
|
||||
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
result = strategy.wrap_dataset(dummy)
|
||||
|
||||
for row in result:
|
||||
assert all(v == 1 for v in row["attention_mask"])
|
||||
|
||||
def test_labels_equal_input_ids(self):
|
||||
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
result = strategy.wrap_dataset(dummy)
|
||||
|
||||
for row in result:
|
||||
assert row["input_ids"] == row["labels"]
|
||||
|
||||
def test_input_id_range(self):
|
||||
strategy = SyntheticDatasetStrategy(
|
||||
sequence_length=64,
|
||||
length=100,
|
||||
min_input_id=500,
|
||||
max_input_id=600,
|
||||
seed=42,
|
||||
)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
result = strategy.wrap_dataset(dummy)
|
||||
|
||||
for row in result:
|
||||
for token_id in row["input_ids"]:
|
||||
assert 500 <= token_id < 600
|
||||
|
||||
def test_seed_reproducibility(self):
|
||||
kwargs = dict(
|
||||
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
|
||||
)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
|
||||
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||
|
||||
for r1, r2 in zip(result1, result2, strict=True):
|
||||
assert r1["input_ids"] == r2["input_ids"]
|
||||
|
||||
def test_different_seeds_differ(self):
|
||||
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
|
||||
dummy = Dataset.from_dict({"text": [""]})
|
||||
|
||||
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
|
||||
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
|
||||
|
||||
any_different = any(
|
||||
r1["input_ids"] != r2["input_ids"]
|
||||
for r1, r2 in zip(result1, result2, strict=True)
|
||||
)
|
||||
assert any_different
|
||||
|
||||
def test_load_function_with_ds_cfg(self):
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.vocab_size = 32000
|
||||
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
|
||||
ds_cfg = {
|
||||
"sequence_length": 256,
|
||||
"length": 5,
|
||||
"min_input_id": 10,
|
||||
"max_input_id": 100,
|
||||
"seed": 0,
|
||||
}
|
||||
|
||||
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
|
||||
assert isinstance(strategy, SyntheticDatasetStrategy)
|
||||
assert strategy.sequence_length == 256
|
||||
assert strategy.length == 5
|
||||
assert strategy.min_input_id == 10
|
||||
assert strategy.max_input_id == 100
|
||||
|
||||
def test_load_defaults_from_cfg(self):
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.vocab_size = 32000
|
||||
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
|
||||
|
||||
strategy = load(tokenizer, cfg, ds_cfg={})
|
||||
assert strategy.sequence_length == 1024
|
||||
assert strategy.max_input_id == 32000
|
||||
assert strategy.length == 1000
|
||||
|
||||
def test_load_with_no_ds_cfg(self):
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.vocab_size = 50000
|
||||
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
|
||||
|
||||
strategy = load(tokenizer, cfg)
|
||||
assert strategy.sequence_length == 2048
|
||||
assert strategy.max_input_id == 50000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Unit tests for the monkey patch for expand mask to handle packed sequences
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
||||
|
||||
|
||||
class TestExpandMask(unittest.TestCase):
|
||||
"""
|
||||
Test class for attention mask expansion for packed sequences
|
||||
"""
|
||||
|
||||
def test_output(self):
|
||||
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
|
||||
dtype = torch.float32
|
||||
expected_output = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
666
tests/unit/test_aux_free_adapters.py
Normal file
666
tests/unit/test_aux_free_adapters.py
Normal file
@@ -0,0 +1,666 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||
|
||||
|
||||
def _cfg(**overrides):
|
||||
defaults = dict(
|
||||
moe_balance_type="noaux_tc",
|
||||
moe_update_rate=0.1,
|
||||
moe_update_momentum=0.9,
|
||||
moe_bias_cap=2.0,
|
||||
moe_afb_warmup_steps=0,
|
||||
moe_bias_sync_group="world",
|
||||
expert_parallel_size=1,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
def _load_bailing_modules():
|
||||
repo_dir = snapshot_download(
|
||||
repo_id="inclusionAI/Ring-mini-2.0",
|
||||
allow_patterns=[
|
||||
"configuration_bailing_moe_v2.py",
|
||||
"modeling_bailing_moe_v2.py",
|
||||
"__init__.py",
|
||||
],
|
||||
)
|
||||
repo = Path(repo_dir)
|
||||
config_path = repo / "configuration_bailing_moe_v2.py"
|
||||
modeling_path = repo / "modeling_bailing_moe_v2.py"
|
||||
|
||||
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
|
||||
if config_name not in sys.modules:
|
||||
spec = importlib_util.spec_from_file_location(config_name, config_path)
|
||||
module = importlib_util.module_from_spec(spec)
|
||||
sys.modules[config_name] = module
|
||||
sys.modules["configuration_bailing_moe_v2"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
config_module = sys.modules[config_name]
|
||||
|
||||
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
|
||||
if modeling_name not in sys.modules:
|
||||
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
|
||||
module = importlib_util.module_from_spec(spec)
|
||||
sys.modules[modeling_name] = module
|
||||
sys.modules["modeling_bailing_moe_v2"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
modeling_module = sys.modules[modeling_name]
|
||||
|
||||
BailingMoeV2Config = config_module.BailingMoeV2Config
|
||||
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
|
||||
|
||||
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
|
||||
|
||||
|
||||
def _build_bailing_model():
|
||||
BailingConfig, BailingBlock = _load_bailing_modules()
|
||||
config = BailingConfig(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
moe_intermediate_size=32,
|
||||
num_experts=4,
|
||||
num_shared_experts=None,
|
||||
num_experts_per_tok=2,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
block = BailingBlock(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, layer):
|
||||
super().__init__()
|
||||
self.block = layer
|
||||
self.config = SimpleNamespace(model_type="bailing_moe")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.block(hidden_states)
|
||||
|
||||
return DummyModel(block), block
|
||||
|
||||
|
||||
def _build_llama4_model():
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||
|
||||
# Build config without __post_init__ validation (works around a
|
||||
# huggingface_hub strict-dataclass type mismatch for layer_types).
|
||||
config = object.__new__(__import__("transformers").Llama4TextConfig)
|
||||
config.__dict__.update(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_experts_per_tok=2,
|
||||
num_hidden_layers=2,
|
||||
hidden_act="silu",
|
||||
layer_types=None,
|
||||
)
|
||||
layer = Llama4TextMoe(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="llama4")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _build_mixtral_model():
|
||||
from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
config = MixtralConfig(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
layer = MixtralSparseMoeBlock(config)
|
||||
layer.config = config
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="mixtral")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _build_qwen35_moe_model():
|
||||
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
|
||||
Qwen3_5MoeTextConfig,
|
||||
)
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
Qwen3_5MoeSparseMoeBlock,
|
||||
)
|
||||
|
||||
config = Qwen3_5MoeTextConfig(
|
||||
hidden_size=16,
|
||||
moe_intermediate_size=32,
|
||||
shared_expert_intermediate_size=32,
|
||||
num_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
layer = Qwen3_5MoeSparseMoeBlock(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="qwen3_5_moe")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
if state is None:
|
||||
state = SimpleNamespace(global_step=1, log_history=[])
|
||||
if control is None:
|
||||
control = SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
|
||||
class DummyTrainer:
|
||||
def __init__(self, state_obj, control_obj):
|
||||
self.state = state_obj
|
||||
self.control = control_obj
|
||||
|
||||
def log(self, logs):
|
||||
output = dict(logs)
|
||||
output["step"] = self.state.global_step
|
||||
self.state.log_history.append(output)
|
||||
self.control.should_log = True
|
||||
|
||||
dummy_trainer = DummyTrainer(state, control)
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
callback.on_step_end(args=args, state=state, control=control)
|
||||
return state, control
|
||||
|
||||
|
||||
class TestAuxFreeAdapters(unittest.TestCase):
|
||||
def test_bailing_adapter_updates_counts_and_bias(self):
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertTrue(hasattr(block, "_afb_bias"))
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||
self.assertFalse(
|
||||
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
|
||||
)
|
||||
|
||||
def test_llama4_adapter_biases_router_selection(self):
|
||||
model, layer = _build_llama4_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||
hidden = torch.randn(2, 4, layer.hidden_dim)
|
||||
layer(hidden)
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_bias_warmup_respected(self):
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg(moe_afb_warmup_steps=2)
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
def _step():
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
_step()
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
# Third step exceeds warmup -> bias should update.
|
||||
_step()
|
||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||
|
||||
def test_mixtral_adapter_patches_router_not_forward(self):
|
||||
"""Verify that aux-free patches the router (gate) only, and the
|
||||
v5 block forward signature (single tensor return) is preserved."""
|
||||
model, layer = _build_mixtral_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Gate should be patched, not the block forward
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
|
||||
# v5 block forward returns a single tensor (not a tuple with logits)
|
||||
hidden = torch.randn(2, 3, layer.config.hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
def test_mixtral_adapter_bias_affects_selection(self):
|
||||
"""When bias is large for one expert, it should be selected more often."""
|
||||
model, layer = _build_mixtral_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Set a large bias for expert 0 to force its selection
|
||||
layer._afb_bias.zero_()
|
||||
layer._afb_bias[0] = 10.0
|
||||
|
||||
hidden = torch.randn(2, 8, layer.config.hidden_size)
|
||||
num_tokens = 2 * 8 # batch * seq
|
||||
layer(hidden)
|
||||
|
||||
# With top_k=2, expert 0 should appear in every token's selection
|
||||
# (once per token = num_tokens counts, not num_tokens * top_k)
|
||||
counts = layer._afb_counts.clone()
|
||||
self.assertEqual(
|
||||
int(counts[0].item()),
|
||||
num_tokens,
|
||||
msg="Expert 0 should be selected for every token when heavily biased",
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
|
||||
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
|
||||
output includes shared expert contribution."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Gate should be patched
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
# Shared expert should be unmodified
|
||||
self.assertTrue(hasattr(layer, "shared_expert"))
|
||||
self.assertTrue(hasattr(layer, "shared_expert_gate"))
|
||||
|
||||
# Forward should return a single tensor (shared + routed)
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 3, hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
|
||||
def test_qwen35_moe_adapter_bias_updates(self):
|
||||
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 4, hidden_size)
|
||||
layer(hidden)
|
||||
|
||||
# Bias should start at zero
|
||||
self.assertTrue(
|
||||
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
|
||||
)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# After callback: counts reset, EMA updated, bias updated
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_model_type_matching(self):
|
||||
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
|
||||
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
|
||||
|
||||
adapter = Qwen35MoeAdapter()
|
||||
|
||||
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
|
||||
model_text = SimpleNamespace(
|
||||
config=SimpleNamespace(model_type="qwen3_5_moe_text")
|
||||
)
|
||||
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
|
||||
|
||||
self.assertTrue(adapter.matches(model_moe))
|
||||
self.assertTrue(adapter.matches(model_text))
|
||||
self.assertFalse(adapter.matches(model_other))
|
||||
|
||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
self.skipTest(
|
||||
"Cannot safely test deferred EP group resolution when a process group is already initialized"
|
||||
)
|
||||
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertIsNotNone(plugin._shim)
|
||||
self.assertIsNone(plugin._shim.ep_group)
|
||||
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
dist.init_process_group(
|
||||
backend="gloo", init_method=init_method, world_size=1, rank=0
|
||||
)
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
_run_callback(
|
||||
plugin,
|
||||
cfg,
|
||||
args=SimpleNamespace(logging_steps=1),
|
||||
state=SimpleNamespace(global_step=1, log_history=[]),
|
||||
control=SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
),
|
||||
)
|
||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||
finally:
|
||||
dist.destroy_process_group()
|
||||
os.unlink(tmp_init.name)
|
||||
|
||||
def test_telemetry_logging(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
hidden_dim = layer.config.hidden_size
|
||||
hidden = torch.randn(2, 3, hidden_dim)
|
||||
layer(hidden)
|
||||
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
state = SimpleNamespace(global_step=1, log_history=[])
|
||||
control = SimpleNamespace(
|
||||
should_log=False,
|
||||
should_evaluate=False,
|
||||
should_save=False,
|
||||
should_training_stop=False,
|
||||
)
|
||||
_run_callback(plugin, cfg, args=args, state=state, control=control)
|
||||
|
||||
self.assertTrue(control.should_log)
|
||||
self.assertTrue(state.log_history)
|
||||
telemetry = state.log_history[-1]
|
||||
self.assertEqual(telemetry["step"], state.global_step)
|
||||
self.assertIn("moe_afb/l0_load_min", telemetry)
|
||||
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||
|
||||
def test_get_num_experts_v5_attribute_paths(self):
|
||||
"""Verify get_num_experts works with v5 attribute layout where
|
||||
num_experts is on gate/experts sub-modules, not the block."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
|
||||
block = SimpleNamespace(
|
||||
gate=SimpleNamespace(num_experts=8),
|
||||
experts=SimpleNamespace(num_experts=8),
|
||||
)
|
||||
self.assertEqual(adapter.get_num_experts(block), 8)
|
||||
|
||||
# Also works when num_experts is directly on block
|
||||
block2 = SimpleNamespace(num_experts=4)
|
||||
self.assertEqual(adapter.get_num_experts(block2), 4)
|
||||
|
||||
|
||||
class TestAuxFreeKernelComposition(unittest.TestCase):
|
||||
"""Tests that aux-free bias composes correctly with kernel routing."""
|
||||
|
||||
def test_sonicmoe_softmax_routing_with_afb_bias(self):
|
||||
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
# Build a mock MoE block with gate attributes
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline: no bias
|
||||
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
self.assertEqual(scores_base.shape[0], T * top_k)
|
||||
|
||||
# Now register aux-free buffers and set heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
|
||||
# Expert 0 should be selected for every token
|
||||
self.assertTrue(
|
||||
(exp_biased == 0).any(),
|
||||
"Expert 0 should appear in selections when heavily biased",
|
||||
)
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
# Total counts should equal T * top_k
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_sonicmoe_routing_without_bias_unchanged(self):
|
||||
"""Without _afb_bias, routing should produce identical results."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(4, hidden_dim)
|
||||
|
||||
# Without _afb_bias attribute
|
||||
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
# With _afb_bias = zeros (should be equivalent)
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
torch.testing.assert_close(scores1, scores2)
|
||||
torch.testing.assert_close(exp1, exp2)
|
||||
|
||||
@unittest.skipUnless(
|
||||
importlib_util.find_spec("triton") is not None,
|
||||
"triton not installed (required by scattermoe)",
|
||||
)
|
||||
def test_scattermoe_softmax_routing_with_afb_bias(self):
|
||||
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
gate_weight = torch.randn(num_experts, hidden_dim)
|
||||
base_gate = SimpleNamespace(
|
||||
top_k=top_k,
|
||||
num_experts=num_experts,
|
||||
norm_topk_prob=True,
|
||||
weight=gate_weight,
|
||||
)
|
||||
|
||||
moe_block = SimpleNamespace()
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline without bias
|
||||
w_base, e_base, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# With heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
w_biased, e_biased, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# Expert 0 should appear in all selections
|
||||
self.assertTrue((e_biased == 0).any())
|
||||
# Counts accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_kernel_routing_skips_router_patch(self):
|
||||
"""When a kernel backend has patched the block class, the adapter
|
||||
should skip patching the router (buffers are still registered)."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Create a mock layer whose class has _original_forward (SonicMoE marker)
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True # SonicMoE marker
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16) # placeholder
|
||||
|
||||
layer = PatchedBlock()
|
||||
self.assertTrue(adapter.uses_kernel_routing(layer))
|
||||
|
||||
# Gate should NOT be patched (kernel handles routing)
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
def test_adapter_buffers_registered_even_with_kernel(self):
|
||||
"""Even when kernel routing is active, aux-free buffers must be
|
||||
registered on the MoE block so the kernel routing can find them."""
|
||||
from axolotl.integrations.aux_free_router.adapters import (
|
||||
LayerHandle,
|
||||
MixtralAdapter,
|
||||
)
|
||||
from axolotl.integrations.aux_free_router.core import (
|
||||
AuxFreeConfig,
|
||||
AuxFreeShim,
|
||||
AuxFreeState,
|
||||
)
|
||||
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16)
|
||||
|
||||
layer = PatchedBlock()
|
||||
adapter = MixtralAdapter()
|
||||
cfg = AuxFreeConfig()
|
||||
state = AuxFreeState(
|
||||
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
|
||||
)
|
||||
shim = AuxFreeShim(state=state)
|
||||
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
|
||||
|
||||
adapter.prepare(layer, handle, shim)
|
||||
|
||||
# Buffers should be registered for kernel routing to use
|
||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||
self.assertTrue(hasattr(layer, "_afb_counts"))
|
||||
self.assertTrue(hasattr(layer, "_afb_ema"))
|
||||
# But gate should NOT be patched
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user