Compare commits

...

22 Commits

Author SHA1 Message Date
Wing Lian
db6af43f3b chore: lint 2026-03-23 04:54:00 +00:00
Wing Lian
35d06c8087 add textui 2026-03-23 04:54:00 +00:00
Wing Lian
0e583efeaa increase rtol, codecov informational only, don't silently fail errors w curl (#3534) [skip ci] 2026-03-22 13:54:03 -04:00
Wing Lian
b3289fd190 feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
2026-03-22 13:53:19 -04:00
Wing Lian
a67392c427 liger support for qwen 3.5 and fused rmsnorm+gated (#3531) [skip ci]
* liger support for qwen 3.5 and fused rmsnorm+gated

* support for qwen 3.5 moe

* fix version ref

* fixups for PR code review
2026-03-22 13:19:21 -04:00
Wing Lian
5b2e3f00ce fix: handle connection errors when checking user whoami (#3529) 2026-03-22 09:11:17 -04:00
Wing Lian
fc3b3d1d4e synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
2026-03-21 22:47:26 -04:00
Wing Lian
c9df6efdc2 support offloading layers to CPU (#3512) [skip ci]
* support offloading layers to CPU

* chore: lint

* revert change

* update docs
2026-03-21 22:47:02 -04:00
Wing Lian
0ee98a0309 fix token state json and mistral tokenizer issue (#3522) [skip ci]
* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
2026-03-21 22:46:10 -04:00
Wing Lian
2c05847a5f reduce autotune search space (#3525) [skip ci]
* reduce autotune search space

* consistent docstrings
2026-03-21 18:30:15 -04:00
Wing Lian
b0294b3427 handle qwen3.5 moe loading (#3523) [skip ci] 2026-03-20 09:25:16 -04:00
Avaya Aggarwal
1bcfc08c90 feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]
* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 08:24:44 -04:00
NanoCode012
5a5cf30b26 fix: add dequant bf16 repo (#3507) [skip ci] 2026-03-20 17:11:46 +07:00
Avaya Aggarwal
7ddfb2d8a0 cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331
2026-03-20 17:10:41 +07:00
Owen Arliawan
c57acef2c7 Qwen3.5-MoE example config with lora_target_modules regex (#3515) [skip ci]
* lora target modules with regex

* updates

* fsdp for non moe

* update wording

* chore: cleanup and lint

* chore: cleanup docs from merge

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:52:46 +07:00
Lorenzo Baraldi
038ffe3f26 fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498) 2026-03-20 16:27:24 +07:00
VED
c13cb7c853 feat: add nemotron config (#3506)
* nemotron config exp

* Update examples/nemotron/nemotron-mini-4b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-03-20 16:23:42 +07:00
VED
b3823cc6b0 fix: gemma3 configs (#3500) [skip ci]
* gemma fft , text fix

* good lint
2026-03-20 16:14:06 +07:00
VED
113d275bd9 qwen docs + new config (#3499) [skip ci]
* qwen docs + new config

* docss lint

* simplify comments

* read me

* lint comments

* Update docs/multimodal.qmd

* Update docs/multimodal.qmd

* Update examples/qwen3.5/9b-fft-vision.yaml

* chore: fix link and incorrect points

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:13:34 +07:00
VED
7920fe74ec fix num_labels= 1 test fail (#3493) [skip ci]
* trl_num_lables=1

* casual num_lables=1,rwd model

* lint
2026-03-20 16:12:23 +07:00
Wing Lian
1fc86d5295 Scattermoe LoRA optimizations (#3513)
* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use udpated cce with zero3 support

* Fix decomposed backward for QKV and O projections

eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
2026-03-19 23:07:42 -04:00
Wing Lian
bb483ad4c4 make the CI fail GitHub Actions on test failures (#3517)
* make the CI fail GitHub Actions on test failures

* use model bundle

* install zstd for compressed model artifact
2026-03-19 08:29:24 -04:00
127 changed files with 9223 additions and 934 deletions

View File

@@ -128,11 +128,9 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils

View File

@@ -0,0 +1,284 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -3,11 +3,13 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \

View File

@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")

View File

@@ -37,6 +37,7 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing and Activation Offloading
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -20,6 +20,7 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -191,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,9 +1,5 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -6,9 +6,6 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
Note: Training this model requires weights in BF16 which we will link to later.
Users interested in training can convert / descale the existing FP8 weights.
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:

View File

@@ -0,0 +1,57 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -0,0 +1,84 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
@@ -52,7 +56,6 @@ learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

View File

@@ -0,0 +1,59 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,81 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,9 +1,7 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -0,0 +1,85 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj

View File

@@ -0,0 +1,49 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,10 +1,6 @@
base_model: Qwen/Qwen3.5-7B
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
@@ -30,8 +26,6 @@ lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj

View File

@@ -1,15 +1,6 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
## Getting started
@@ -18,35 +9,69 @@ Available configs:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Pick any config from the table below and run:
```bash
axolotl train examples/qwen3.5/<config>.yaml
```
Available configs:
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
### Gated DeltaNet Linear Attention
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
```yaml
lora_target_modules:
# ... standard projections ...
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Run a finetuning example:
### Routed Experts (MoE)
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
### Shared Experts (MoE)
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
```yaml
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
```
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For inference hyp, please see the respective model card details.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
## Optimization Guides

View File

@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
)

View File

@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:

View File

@@ -3,6 +3,7 @@
import os
from pathlib import Path
import httpcore
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:
except (HTTPError, httpcore.ConnectError):
LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
)

View File

@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@@ -101,6 +102,7 @@ def train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
tui: bool = False,
**kwargs,
):
"""
@@ -118,6 +120,10 @@ def train(
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Handle --tui flag: set env var so subprocess workers pick it up
if tui:
os.environ["AXOLOTL_TUI"] = "1"
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher

View File

@@ -2,6 +2,7 @@
import gc
import os
import queue
from pathlib import Path
from typing import Union
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
# Start TUI early (before data loading) so it captures preprocessing events
tui_renderer = None
tui_queue: queue.Queue | None = None
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
if is_rank_0:
from axolotl.train import _is_tui_enabled
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
if _is_tui_enabled(cfg):
import queue as _queue
del model, tokenizer, trainer
from axolotl.train import _get_tui_config
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
gc.collect()
tui_config_dict = _get_tui_config(cfg)
tui_config = (
TUIConfig(**tui_config_dict)
if isinstance(tui_config_dict, dict)
else tui_config_dict
)
tui_queue = _queue.Queue(maxsize=4096)
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
# Send initial run info
model_name = cfg.base_model or ""
training_mode = str(cfg.rl) if cfg.rl else "sft"
world_size = int(os.environ.get("WORLD_SIZE", 1))
try:
tui_queue.put_nowait(
{
"type": "run_info",
"model_name": model_name,
"training_mode": training_mode,
"world_size": world_size,
}
)
except _queue.Full:
pass
tui_renderer.start()
# Attach logging handler early
import logging
from axolotl.tui.callback import _TUILogHandler
_early_log_handler = _TUILogHandler(
tui_queue, min_level=tui_config.log_level
)
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to BOTH root and axolotl loggers because axolotl logger
# has propagate=False so root handler never sees axolotl.* messages
root_logger = logging.getLogger()
root_logger.addHandler(_early_log_handler)
axolotl_logger = logging.getLogger("axolotl")
axolotl_logger.addHandler(_early_log_handler)
# Stash refs on cfg so train() can reuse the renderer
cfg._tui_renderer = tui_renderer
cfg._tui_queue = tui_queue
cfg._tui_early_log_handler = _early_log_handler
try:
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
finally:
# If the TUI renderer started early but train() didn't get to stop it
# (e.g., error during data loading), clean up here
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
try:
if tui_queue is not None:
tui_queue.put_nowait({"type": "done"})
except queue.Full:
pass
tui_renderer.stop()
# Remove early log handler from both root and axolotl loggers
if hasattr(cfg, "_tui_early_log_handler"):
import logging
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

View File

@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.layer_offloading:
training_args_kwargs["layer_offloading"] = True
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False

View File

@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (

View File

@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
LayerOffloadingMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
@@ -67,6 +67,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
LayerOffloadingMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,

View File

@@ -0,0 +1 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -2,7 +2,8 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from trl import DPOConfig
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -4,6 +4,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .layer_offloading import LayerOffloadingMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -0,0 +1,304 @@
"""
Trainer mixin for layer-wise parameter offloading to CPU.
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
transfer stream), post-hook offloads layer N-1's frozen params.
Backward: same in reverse order.
"""
import contextlib
import torch
import torch.nn as nn
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
"""Recursively search the model for the decoder layer ModuleList.
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
where layers are at model.language_model.layers).
"""
# BFS to find the first ModuleList containing decoder layers
queue = [model]
while queue:
m = queue.pop(0)
for _name, child in m.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
first_type = type(child[0]).__name__
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
layer_types = list({type(layer).__name__ for layer in child})
return child, layer_types
else:
queue.append(child)
return None, []
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
"""Get all non-trainable parameters in a layer."""
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
class LayerOffloadManager:
"""Manages offloading frozen decoder layer params to CPU and streaming
them back during forward/backward with CUDA stream overlap.
Only frozen (requires_grad=False) parameters are offloaded.
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
"""
def __init__(
self,
model: nn.Module,
num_prefetch: int = 1,
):
self.model = model
self.num_prefetch = num_prefetch
self._hooks: list = []
self._device = None
# Find decoder layers
self.layers, layer_types = _find_decoder_layers(model)
if self.layers is None:
LOG.warning(
"LayerOffloadManager: no decoder layers found, offloading disabled"
)
self.enabled = False
return
self.enabled = True
self.n_layers = len(self.layers)
LOG.info(
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
)
# Determine GPU device
for p in model.parameters():
if p.device.type == "cuda":
self._device = p.device
break
if self._device is None:
LOG.warning("LayerOffloadManager: no CUDA parameters found")
self.enabled = False
return
# Transfer stream for async prefetch
self._transfer_stream = torch.cuda.Stream(device=self._device)
# Track which layers have their frozen params on GPU
self._on_gpu: set[int] = set(range(self.n_layers))
# Cache: frozen param references per layer (list of (name, param) tuples)
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
]
# CPU storage: pinned tensors for each layer's frozen params
# Populated on first offload
self._cpu_data: list[dict[str, torch.Tensor]] = [
{} for _ in range(self.n_layers)
]
# Offload all layers upfront
self._offload_all()
# Release cached memory blocks back to the driver
torch.cuda.empty_cache()
def _offload_all(self):
"""Move all frozen params in all decoder layers to CPU."""
mem_before = torch.cuda.memory_allocated(self._device)
for i in range(self.n_layers):
self._offload_layer(i)
mem_after = torch.cuda.memory_allocated(self._device)
freed = (mem_before - mem_after) / 1e6
LOG.info(
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
f"freed {freed:.0f} MB GPU memory"
)
def _offload_layer(self, idx: int):
"""Move frozen params of layer idx to CPU pinned memory."""
if idx not in self._on_gpu:
return
for name, param in self._frozen_params[idx]:
if param.device.type != "cuda":
continue
# Allocate pinned CPU tensor on first offload
if name not in self._cpu_data[idx]:
self._cpu_data[idx][name] = torch.empty_like(
param.data, device="cpu", pin_memory=True
)
cpu_buf = self._cpu_data[idx][name]
# Async copy GPU -> CPU (on transfer stream for overlap)
cpu_buf.copy_(param.data, non_blocking=True)
# Point parameter at a dummy CPU tensor to free GPU memory
param.data = cpu_buf
self._on_gpu.discard(idx)
def _load_layer(self, idx: int, stream=None):
"""Move frozen params of layer idx back to GPU."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
ctx = (
torch.cuda.stream(stream)
if stream is not None
else contextlib.nullcontext()
)
with ctx:
for _name, param in self._frozen_params[idx]:
if param.device.type == "cuda":
continue
gpu_data = param.data.to(self._device, non_blocking=True)
param.data = gpu_data
self._on_gpu.add(idx)
def _prefetch_layer(self, idx: int):
"""Async prefetch layer idx on the transfer stream."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
self._load_layer(idx, stream=self._transfer_stream)
def _wait_transfer(self):
"""Make default stream wait for any in-flight transfers."""
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
def setup_hooks(self):
"""Register forward and backward hooks on each decoder layer."""
if not self.enabled:
return
for idx in range(self.n_layers):
layer = self.layers[idx]
def make_pre_fwd(i):
def hook(module, args):
# Ensure this layer is on GPU
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch next layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i + offset)
return hook
def make_post_fwd(i):
def hook(module, args, output):
# Offload previous layer (no longer needed in forward)
if i > 0:
self._offload_layer(i - 1)
# Offload last layer after forward
if i == self.n_layers - 1:
self._offload_layer(i)
return hook
def make_pre_bwd(i):
def hook(module, grad_output):
# Load this layer for backward
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch previous layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i - offset)
return hook
def make_post_bwd(i):
def hook(module, grad_input, grad_output):
# Offload the layer above
if i < self.n_layers - 1:
self._offload_layer(i + 1)
# Offload first layer after backward
if i == 0:
self._offload_layer(i)
return hook
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
h2 = layer.register_forward_hook(make_post_fwd(idx))
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
self._hooks.extend([h1, h2, h3, h4])
def remove_hooks(self):
"""Remove all hooks and restore layers to GPU."""
for h in self._hooks:
h.remove()
self._hooks.clear()
if self.enabled:
for i in range(self.n_layers):
if i not in self._on_gpu:
self._load_layer(i)
def pre_step(self):
"""Called before each training step — ensure layers start offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for forward
self._prefetch_layer(0)
def post_step(self):
"""Called after each training step — ensure layers are offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for next step
self._prefetch_layer(0)
class _LayerOffloadContext:
"""Context manager wrapping pre_step / post_step around a training step."""
def __init__(self, manager: LayerOffloadManager):
self.manager = manager
def __enter__(self):
self.manager.pre_step()
return self
def __exit__(self, *args):
self.manager.post_step()
class LayerOffloadingMixin(Trainer):
"""
Trainer mixin class for layer-wise parameter offloading to CPU.
Offloads frozen decoder layer params to CPU at init, then streams them
on/off GPU one layer at a time during each training step.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if getattr(self.args, "layer_offloading", False):
LOG.info("Layer parameter offloading enabled")
self._layer_offload_manager = LayerOffloadManager(
model=self.model,
num_prefetch=1,
)
self._layer_offload_manager.setup_hooks()
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
else:
self._layer_offload_manager = None
self._layer_offload_ctx = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self._layer_offload_ctx:
return super().training_step(*args, **kwargs)

View File

@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
layer_offloading: bool | None = field(
default=None,
metadata={
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
)

View File

@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
@@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
if model_type.endswith("_text"):
parent_type = model_type.removesuffix("_text")
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
module = importlib.import_module(module_path)
else:
raise
classes = []
for cls_name in cls_names:

View File

@@ -195,6 +195,36 @@ def _estimate_smem_usage(
_SMEM_SLACK = 10_000
def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# =============================================================================
# Forward Kernel: scatter2scatter with fused LoRA
# =============================================================================
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
) # [BLOCK_N, BLOCK_R]
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
# Both operands must match; cast to float32 (accumulator type) for precision.
b_f32 = b.to(tl.float32)
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
b_inp = b.to(INPUT_DTYPE)
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
acc += scaling * lora_out
return acc
@@ -327,28 +356,29 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
[32, 64, 128, 256], # BLOCK_N
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
def _prune_fwd_configs(configs, named_args, **kwargs):
"""Prune forward configs based on SMEM capacity.
"""Prune forward configs based on SMEM capacity and register pressure.
The forward kernel inner loop loads three tiles per pipeline stage:
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
smem_lora_epilogue = block_n * block_r * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_n), # acc
(block_m, block_r), # xa_acc
(block_m, block_k), # x tile
(block_k, block_n), # w tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
),
)
]
@triton.autotune(
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
X: torch.Tensor,
W: torch.Tensor,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
k: int,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
b: Optional[torch.Tensor] = None,
x_grouped: bool = False,
y_grouped: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
because the base kernel runs at full speed without LoRA SMEM overhead,
and the LoRA matmuls (R=16) are tiny separate passes.
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
scatter2scatter,
)
E = W.size(0)
R = lora_A.size(0) // E
K = W.size(1)
N = W.size(2)
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
output.add_(Y_lora, alpha=scaling)
return output
# Threshold for switching from fused to split LoRA forward.
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
def scatter2scatter_lora(
X: torch.Tensor,
W: torch.Tensor,
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args:
X: Input [M, K] or [M*k, K] if x_grouped
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
Returns:
Y: Output [M*k, N]
"""
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0)
K = W.size(1)
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
b_ptr,
stride_be,
stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A,
lora_A.stride(0),
lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B,
lora_B.stride(0),
lora_B.stride(1),
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
K=K,
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
allow_tf32=ALLOW_TF32,
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
+ (A_expert_offset + R_block)[:, None] * stride_ar
+ K_block[None, :] * stride_ak
)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
# Cast to float32 for precision
a_f32 = a_e.to(tl.float32)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
INPUT_DTYPE
)
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
acc += scaling * lora_dx
return acc
@@ -779,25 +939,26 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128} (output tile)
BLOCK_N: {32, 64} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K (output dimension)
[32, 64], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
def _prune_dX_configs(configs, named_args, **kwargs):
"""Prune backward dX configs based on SMEM capacity.
"""Prune backward dX configs based on SMEM capacity and register pressure.
The dX kernel inner loop loads three tiles per pipeline stage:
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
smem_lora_epilogue = block_r * block_k * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_k), # acc
(block_m, block_r), # dy_b_acc
(block_m, block_n), # dy tile
(block_n, block_k), # wt tile
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -1091,9 +1278,9 @@ def _group_bwd_lora_configs():
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space:
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
BLOCK_K: {32, 64, 128, 256}
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128} (token-loop tile)
BLOCK_K: {32, 64, 128}
BLOCK_N: {32, 64}
num_warps: {4, 8}
num_stages: {3, 4, 5}
@@ -1102,9 +1289,9 @@ def _group_bwd_lora_configs():
"""
configs = []
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K
[32, 64], # BLOCK_N
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
"""Prune backward configs based on SMEM capacity.
"""Prune backward configs based on SMEM capacity and register pressure.
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
smem_lora = (block_r * block_k + block_n * block_r) * 2
smem = smem_base + smem_lora
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_k), # dA_acc
(block_n, block_r), # dB_acc
(block_m, block_k), # x tile
(block_m, block_n), # dy tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity and register pressure."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
# Register pressure check
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_dim), # acc
(block_m, BLOCK_INNER), # input tile
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)

View File

@@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,

View File

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================
Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer
This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast
The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
"""Get sorted unique expert indices from the routing output.
Args:
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
E: Total number of experts
Returns:
active: Sorted unique expert indices [num_active]
"""
return torch.unique(sorted_expert_idxs)
def remap_expert_indices(
sorted_expert_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remap global expert indices to compact indices.
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
sort order. Also compacts expert_offsets to only active experts.
Args:
sorted_expert_idxs: [T*k] expert ids in sorted order
expert_offsets: [E] cumulative token counts (original)
active_experts: [num_active] sorted unique expert ids
E: Total number of experts
Returns:
remapped_idxs: [T*k] expert ids in [0..num_active-1]
compact_offsets: [num_active] cumulative token counts
"""
# Build remap table: global_id -> compact_id
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
remap[active_experts] = torch.arange(
len(active_experts), device=sorted_expert_idxs.device
)
remapped_idxs = remap[sorted_expert_idxs]
# Compact the expert_offsets: only keep active experts' cumulative counts
compact_offsets = expert_offsets[active_experts]
return remapped_idxs, compact_offsets
def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
The raw parameter is a flattened 4-bit packed tensor. Each expert's
data is contiguous (stored in expert-major order), so we can gather
the packed data and absmax blocks for active experts, then dequantize
as one contiguous block.
Args:
raw_param: Flattened uint8 tensor of packed 4-bit weights
quant_state: BnB QuantState with absmax, blocksize, code, etc.
active_experts: [num_active] expert indices to dequantize
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
Returns:
Dequantized weights [num_active, dim1, dim2] in original dtype
"""
import bitsandbytes.functional as F # noqa: N812
from bitsandbytes.functional import QuantState
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
blocks_per_expert = expert_numel // quant_state.blocksize
num_active = len(active_experts)
if blocks_per_expert == 0:
# Expert is smaller than one quantization block — blocks span across
# expert boundaries, so per-expert slicing isn't possible.
# Fallback: full dequantize + index.
full = F.dequantize_4bit(raw_param, quant_state)
E_total = full.numel() // expert_numel
return full.reshape(E_total, *expert_shape)[active_experts]
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
selective_dequant_nf4_triton,
)
# Handle nested (double) quantization: dequantize absmax first
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
if quant_state.nested:
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
else:
absmax = quant_state.absmax
return selective_dequant_nf4_triton(
packed_data=raw_param,
absmax=absmax,
active_experts=active_experts,
expert_shape=expert_shape,
blocksize=quant_state.blocksize,
dtype=quant_state.dtype,
codebook=quant_state.code,
)
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
raw_flat = raw_param.reshape(-1)
offsets_qt = (
active_experts.long()[:, None] * packed_per_expert
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
qt_gathered = raw_flat[offsets_qt]
offsets_abs = (
active_experts.long()[:, None] * blocks_per_expert
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
if quant_state.nested:
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
full_absmax += quant_state.offset
if full_absmax.dtype != torch.float32:
full_absmax = full_absmax.float()
absmax_gathered = full_absmax[offsets_abs]
else:
absmax_gathered = quant_state.absmax[offsets_abs]
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
gathered_qs = QuantState(
absmax=absmax_gathered,
shape=torch.Size([num_active * expert_numel]),
blocksize=quant_state.blocksize,
quant_type=quant_state.quant_type,
code=quant_state.code,
dtype=quant_state.dtype,
)
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
return deq.reshape(num_active, *expert_shape)
def _selective_index_dense(
param: torch.Tensor,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Select experts from a dense (bf16/fp32) weight tensor.
Simple indexing — no dequantization needed.
"""
return param[active_experts]
def selective_expert_weights(
experts_module: nn.Module,
param_name: str,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Extract and dequantize only the active experts' weights.
Format-agnostic: dispatches based on whether the parameter is
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
Args:
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
param_name: "gate_up_proj" or "down_proj"
active_experts: [num_active] sorted unique expert indices
Returns:
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
"""
# Check if the parameter is BnB-quantized via parametrize
if (
hasattr(experts_module, "parametrizations")
and param_name in experts_module.parametrizations
):
param_list = experts_module.parametrizations[param_name]
parametrization = param_list[0]
# BnB 4-bit parametrization
if hasattr(parametrization, "quant_state"):
# The raw quantized data is on the ParametrizationList, not the
# individual Bnb4bitParametrization module
raw_param = param_list.original
qs = parametrization.quant_state
# qs.shape is the original tensor shape before flattening.
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
orig_shape = qs.shape
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
expert_shape = (orig_shape[1], orig_shape[2])
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
# Flattened — need to infer from module attributes
E_total = getattr(experts_module, "num_experts", None)
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
full = getattr(experts_module, param_name)
return full[active_experts]
else:
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
if param.dim() == 3:
return param[active_experts]
# Fallback: full access
return param
def selective_lora_weights(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select LoRA A and B weights for only the active experts.
LoRA layout (scattermoe format):
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
Returns compact:
A: [r*num_active, K]
B: [N, r*num_active]
"""
R = lora_A.size(0) // E
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
row_idx = (
active_experts.long()[:, None] * R
+ torch.arange(R, device=lora_A.device)[None, :]
).reshape(-1)
compact_A = lora_A[row_idx] # [r*num_active, K]
compact_B = lora_B[:, row_idx] # [N, r*num_active]
return compact_A, compact_B

View File

@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.
Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel
This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly
This eliminates the intermediate gather buffer entirely.
"""
import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
@triton.jit
def _selective_dequant_nf4_kernel(
# Input: packed NF4 data (flattened, expert-major order)
packed_ptr,
# Input: absmax values (flattened, expert-major order)
absmax_ptr,
# Input: active expert indices
active_experts_ptr,
# Input: NF4 codebook (16 float values)
codebook_ptr,
# Output: dequantized bf16 weights [num_active, expert_numel]
out_ptr,
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
):
"""
Each program processes BLOCK_SIZE elements from one expert.
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
For each output element:
1. Compute which byte in packed data contains this element
2. Extract the 4-bit nibble (high or low)
3. Look up in NF4 codebook
4. Scale by absmax for this block
"""
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
block_id = tl.program_id(1) # which element block
# Load the global expert index
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = elem_offset < expert_numel
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
byte_idx = elem_offset // 2
is_high = (elem_offset % 2) == 1
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
# NF4 codebook lookup
# Load all 16 codebook values (small, fits in registers)
# Use gather from codebook pointer
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
# Load absmax for this element's quantization block
block_idx = elem_offset // blocksize
absmax_global_offset = expert_global * blocks_per_expert + block_idx
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
# Dequantize: value = codebook[nibble] * absmax
result = code_val * absmax_val
# Store to output
out_offset = expert_local_idx * stride_out_e + elem_offset
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
def selective_dequant_nf4_triton(
packed_data: torch.Tensor,
absmax: torch.Tensor,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
blocksize: int,
dtype: torch.dtype = torch.bfloat16,
codebook: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fused selective gather + NF4 dequantization via Triton kernel.
Args:
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
absmax: Per-block scaling factors [total_blocks]
active_experts: Sorted indices of experts to dequantize [num_active]
expert_shape: (dim1, dim2) per expert
blocksize: Quantization block size
dtype: Output dtype (default bf16)
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
Returns:
Dequantized weights [num_active, dim1, dim2]
"""
num_active = active_experts.shape[0]
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2
blocks_per_expert = expert_numel // blocksize
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
# Flatten inputs
packed_flat = packed_data.reshape(-1)
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
_selective_dequant_nf4_kernel[grid](
packed_flat,
absmax_flat,
active_experts,
codebook,
out,
out.stride(0),
num_active=num_active,
packed_per_expert=packed_per_expert,
blocks_per_expert=blocks_per_expert,
blocksize=blocksize,
BLOCK_SIZE=BLOCK_SIZE,
)
return out.reshape(num_active, *expert_shape)

View File

@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()

View File

@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None

View File

@@ -0,0 +1,175 @@
"""
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_5(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
if rms_norm:
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,198 @@
"""
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
load_balancing_loss_func,
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits,
labels,
self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def apply_liger_kernel_to_qwen3_5_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
if rms_norm:
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_mod.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward

View File

@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5":
from axolotl.integrations.liger.models.qwen3_5 import (
apply_liger_kernel_to_qwen3_5,
)
apply_liger_kernel_to_qwen3_5(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5_moe":
from axolotl.integrations.liger.models.qwen3_5_moe import (
apply_liger_kernel_to_qwen3_5_moe,
)
apply_liger_kernel_to_qwen3_5_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite

147
src/axolotl/kernels/dora.py Normal file
View File

@@ -0,0 +1,147 @@
"""
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid
materializing the full [out_features, in_features] combined weight matrix.
The B@A product is computed row-by-row inside the kernel.
"""
import torch
import triton
import triton.language as tl
from .quantize import dequantize
@triton.jit
def _dora_fused_norm_kernel(
# Pointers
W_ptr, # base weight [out, in] (dequantized, row-major)
B_ptr, # LoRA B [out, rank] (row-major)
A_ptr, # LoRA A [rank, in] (row-major)
mag_ptr, # magnitude vector [out]
out_ptr, # output mag_norm_scale [out]
# Shapes
out_features,
in_features,
rank,
# Scaling
lora_scale, # float scaling factor
# Block sizes
BLOCK_IN: tl.constexpr,
BLOCK_R: tl.constexpr, # >= rank, power of 2
):
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
Each program handles one output row. B[row,:] is loaded once (small),
then we tile over in_features computing the dot product with A[:,tile]
and accumulating the squared norm.
This avoids materializing the full [out, in] B@A matrix.
"""
row = tl.program_id(0)
if row >= out_features:
return
# Accumulate squared norm across tiles of in_features
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
for start in range(0, in_features, BLOCK_IN):
cols = start + tl.arange(0, BLOCK_IN)
col_mask = cols < in_features
# Load W[row, cols]
w_vals = tl.load(
W_ptr + row * in_features + cols,
mask=col_mask,
other=0.0,
).to(tl.float32)
# Compute (B[row,:] @ A[:, cols]) for this tile
# Load B[row, r] as scalar and A[r, cols] as vector for each r
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
for r in tl.static_range(BLOCK_R):
# Load scalar B[row, r]
b_val = tl.load(
B_ptr + row * rank + r,
mask=(r < rank),
other=0.0,
).to(tl.float32)
# Load vector A[r, cols]
a_vals = tl.load(
A_ptr + r * in_features + cols,
mask=(col_mask & (r < rank)),
other=0.0,
).to(tl.float32)
ba_vals += b_val * a_vals
# Combined: W + s * (B @ A)
combined = w_vals + lora_scale * ba_vals
# Accumulate squared values
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
# Reduce to scalar norm
norm_sq = tl.sum(norm_sq_acc, axis=0)
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
# Load magnitude and compute scale
mag = tl.load(mag_ptr + row).to(tl.float32)
scale = mag / norm
tl.store(out_ptr + row, scale)
def triton_dora_scale(
W: torch.Tensor,
W_quant,
A: torch.Tensor,
B: torch.Tensor,
s: float,
magnitude: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute DoRA mag_norm_scale using fused Triton kernel.
Computes B@A row-by-row inside the kernel, avoiding the full
[out_features, in_features] materialization.
Args:
W: base weight [out, in] (possibly quantized)
W_quant: quantization state
A: LoRA A [rank, in]
B: LoRA B [out, rank]
s: LoRA scaling factor
magnitude: learned magnitude [out]
dtype: compute dtype
Returns:
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
"""
# Dequantize W to [out, in]
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
out_features, in_features = W_full.shape
rank = A.shape[0]
out = torch.empty(out_features, dtype=dtype, device=W.device)
# Block sizes
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
BLOCK_R = triton.next_power_of_2(rank)
_dora_fused_norm_kernel[(out_features,)](
W_full,
B.contiguous().to(dtype),
A.contiguous().to(dtype),
magnitude.contiguous(),
out,
out_features=out_features,
in_features=in_features,
rank=rank,
lora_scale=s,
BLOCK_IN=BLOCK_IN,
BLOCK_R=BLOCK_R,
)
return out.detach()

File diff suppressed because it is too large Load Diff

View File

@@ -105,6 +105,10 @@ def dequantize(
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
# Non-double-quantized models have offset=None and state2=None
if quant_state.offset is None or quant_state.state2 is None:
# Fall back to bitsandbytes standard dequantize
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype

View File

@@ -0,0 +1,333 @@
"""
Fused RMSNorm + SiLU Gate Triton kernel.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
and silu(G) = G * sigmoid(G)
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_gated_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
BLOCK_SIZE: tl.constexpr,
):
"""
Y = (W + offset) * (X / RMS(X)) * silu(G)
All computation done in fp32 (Gemma-style), result cast to input dtype.
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
# Cast everything to fp32
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
W_fp32 = W_row.to(tl.float32)
# RMS norm
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
X_norm = X_fp32 * rstd
# SiLU gate: silu(G) = G * sigmoid(G)
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# Fused output
Y_row = (offset + W_fp32) * X_norm * silu_G
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_row_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_gated_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
dG_ptr,
dG_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
dW = sum_batch(dY * X_norm * silu(G))
dG = dY * (W + offset) * X_norm * silu'(G)
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
where m = dY * (W + offset) * silu(G)
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row.to(tl.float32) + offset
for row_idx in range(row_start, row_end):
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
)
G_row = tl.load(
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
)
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
# Cast to fp32
dY_fp32 = dY_row.to(tl.float32)
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
# Recompute intermediates
X_norm = X_fp32 * rstd_row
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# dW: accumulate dY * X_norm * silu(G)
dW_acc += dY_fp32 * X_norm * silu_G
# dG: dY * (W + offset) * X_norm * silu'(G)
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
tl.store(
dG_ptr + row_idx * dG_row_stride + col_offsets,
dG_row.to(X_dtype),
mask=mask,
)
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
m = dY_fp32 * W_row * silu_G
dX_row = rstd_row * m
dX_row += rstd_row * (
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_gated_forward(X, G, W, eps, offset):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
G = G.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
assert X.shape[1] == W.shape[0], (
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
)
assert X.shape == G.shape, (
f"X and G must have same shape, got {X.shape} and {G.shape}"
)
_rms_norm_gated_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
dX = torch.empty_like(dY)
dG = torch.empty_like(dY)
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
_rms_norm_gated_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
dG,
dG.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dX.view(*shape)
dG = dG.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dG, dW
class FusedRMSNormGatedFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, G, W, eps, offset=0.0):
"""
X: (B, T, H) or (BxT, H) — input hidden states
G: (B, T, H) or (BxT, H) — gate tensor
W: (H,) — weight parameter
"""
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
X, G, W, eps, offset
)
ctx.offset = offset
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, G, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, G, W, RSTD = ctx.saved_tensors
dX, dG, dW = rms_norm_gated_backward(
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
)
return dX, dG, dW, None, None
class FusedRMSNormGated(torch.nn.Module):
"""
Fused RMSNorm + SiLU Gate.
Computes: Y = W * RMSNorm(X) * silu(G)
Drop-in replacement for Qwen3_5RMSNormGated with matching
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
and forward signature: forward(hidden_states, gate=None)
"""
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.offset = offset
def forward(self, hidden_states, gate=None):
if gate is None:
raise ValueError("FusedRMSNormGated requires a gate tensor")
if hidden_states.device.type != "cuda":
raise ValueError(
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
)
return FusedRMSNormGatedFunction.apply(
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@@ -505,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"

View File

@@ -571,15 +571,6 @@ class PatchManager:
LOG.info("Patching with xformers attention...")
hijack_llama_attention()
def _patch_llama_sample_packing(self):
"""Apply sample packing patches for LLaMA models."""
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
hijack_llama_prepare_4d_mask()
def _patch_llama_derived_model(self):
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
@@ -591,8 +582,6 @@ class PatchManager:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
self._patch_llama_sample_packing()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."

View File

@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()

View File

@@ -78,30 +78,21 @@ def patch_parallelism_config():
def patch_prepare_cp():
import functools
import contextlib
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -1,24 +0,0 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = _expand_mask

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)

View File

@@ -12,6 +12,7 @@ from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_embedding,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
@@ -51,6 +52,29 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -299,6 +323,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(
@@ -345,13 +371,13 @@ def apply_lora_kernel_patches(
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# Log what features are active
if lora_config.lora_dropout > 0:
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
if lora_config.bias != "none":
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
if lora_config.use_dora:
LOG.info("LoRA kernels: DoRA enabled")
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
@@ -394,44 +420,33 @@ def apply_lora_kernel_patches(
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
@@ -439,15 +454,50 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters"
)
# Patch embedding layers (model-level, not per-layer)
if cfg.lora_embedding_kernel:
_patch_embedding_layers(model, cfg)
LOG.setLevel(original_level)
return model
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
"""Patch embedding layers with fused LoRA kernel.
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
"""
pretrained_model = model.model
patched = 0
# Find embedding modules - check common locations
for attr_path in [
("model", "embed_tokens"),
("model", "language_model", "embed_tokens"),
]:
parent = pretrained_model
for attr in attr_path:
parent = getattr(parent, attr, None)
if parent is None:
break
if parent is not None and hasattr(parent, "lora_embedding_A"):
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
parent.forward = types.MethodType(apply_lora_embedding, parent)
patched += 1
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
# when included in target_modules. No special embedding handling needed since
# PEFT wraps it as a Linear (not Embedding) even for tied models.
if not patched:
LOG.debug("No embedding layers with LoRA found to patch")
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching

View File

@@ -59,6 +59,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"ministral3",
"mistral4",
"afmoe",
"nemotron",
]

View File

@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
"""
import re
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)

View File

@@ -0,0 +1,96 @@
"""
Synthetic dataset generator for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
Useful for benchmarking memory usage and speed by sequence length, and for validating
weighted dataset mixes.
YAML configuration example:
datasets:
- path: synthetic
type: _synthetic
length: 1000
sequence_length: 2048
min_input_id: 100
max_input_id: 32000
seed: 42
"""
from typing import Any, Dict, Optional
import numpy as np
from datasets import Dataset
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
def __init__(
self,
sequence_length: int = 2048,
length: int = 1000,
min_input_id: int = 100,
max_input_id: int = 32000,
seed: Optional[int] = None,
):
self.sequence_length = sequence_length
self.length = length
self.min_input_id = min_input_id
self.max_input_id = max_input_id
self.seed = seed
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
LOG.info(
f"Generating synthetic dataset: {self.length} samples, "
f"sequence_length={self.sequence_length}, "
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
)
rng = np.random.default_rng(self.seed)
input_ids = rng.integers(
low=self.min_input_id,
high=self.max_input_id,
size=(self.length, self.sequence_length),
).tolist()
attention_mask = [[1] * self.sequence_length] * self.length
# labels == input_ids means we train on all tokens
labels = [row[:] for row in input_ids]
return Dataset.from_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
length = ds_cfg.get("length", 1000)
min_input_id = ds_cfg.get("min_input_id", 100)
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
seed = ds_cfg.get("seed", None)
return SyntheticDatasetStrategy(
sequence_length=sequence_length,
length=length,
min_input_id=min_input_id,
max_input_id=max_input_id,
seed=seed,
)

View File

@@ -9,7 +9,6 @@ import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -82,7 +78,7 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if model.generation_config is not None:
if getattr(model, "generation_config", None) is not None:
model.generation_config.do_sample = True
model_properties = model.config.to_dict()
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
)
def _is_tui_enabled(cfg: DictDefault) -> bool:
"""Check if TUI is enabled via config or environment variable."""
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
return True
tui = cfg.get("tui")
if tui is None:
return False
if isinstance(tui, bool):
return tui
if isinstance(tui, dict):
return tui.get("enabled", False)
if hasattr(tui, "enabled"):
return tui.enabled
return False
def _get_tui_config(cfg: DictDefault) -> dict:
"""Extract TUI config dict from cfg."""
tui = cfg.get("tui")
if tui is None or isinstance(tui, bool):
return {"enabled": True}
if isinstance(tui, dict):
return {**tui, "enabled": True}
if hasattr(tui, "model_dump"):
d = tui.model_dump()
d["enabled"] = True
return d
return {"enabled": True}
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
@@ -577,6 +603,37 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Register TUI callback if enabled and rank 0
tui_enabled = _is_tui_enabled(cfg)
if tui_enabled and cfg.local_rank == 0:
from axolotl.tui import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
tui_config = _get_tui_config(cfg)
tui_config_obj = (
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
)
# Reuse the early-started renderer if available (started in do_train)
early_renderer = getattr(cfg, "_tui_renderer", None)
early_queue = getattr(cfg, "_tui_queue", None)
tui_callback = AxolotlTUICallback(config=tui_config_obj)
if early_renderer is not None and early_queue is not None:
# Reuse the already-running renderer and queue
tui_callback._renderer = early_renderer
tui_callback._queue = early_queue
tui_callback._renderer_started_early = True
trainer.add_callback(tui_callback)
# Stash model info so on_train_begin can emit a single unified run_info event
tui_callback._pending_run_info = {
"model_name": cfg.base_model or "",
"training_mode": str(cfg.rl) if cfg.rl else "sft",
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
}
LOG.info("TUI dashboard enabled")
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)

View File

@@ -0,0 +1,17 @@
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
from axolotl.tui.callback import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.io_capture import LineParser, register_parser
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
__all__ = [
"AxolotlTUICallback",
"BasePanel",
"LineParser",
"TUIConfig",
"TUIState",
"register_panel",
"register_parser",
]

142
src/axolotl/tui/callback.py Normal file
View File

@@ -0,0 +1,142 @@
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
from __future__ import annotations
import logging
import queue
from transformers.trainer_callback import TrainerCallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
class _TUILogHandler(logging.Handler):
"""Logging handler that pushes log records into the TUI metric queue."""
_LEVEL_MAP = {
logging.DEBUG: "debug",
logging.INFO: "info",
logging.WARNING: "warning",
logging.ERROR: "error",
logging.CRITICAL: "error",
}
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
super().__init__()
level_name = min_level.upper()
self.setLevel(getattr(logging, level_name, logging.INFO))
self._queue = metric_queue
def emit(self, record: logging.LogRecord) -> None:
try:
level = self._LEVEL_MAP.get(record.levelno, "info")
msg = self.format(record)
self._queue.put_nowait(
{
"type": "log_line",
"level": level,
"message": msg,
}
)
except queue.Full:
pass
except Exception:
self.handleError(record)
class AxolotlTUICallback(TrainerCallback):
"""Pushes training metrics into a queue for the TUI renderer.
The callback never blocks on the render thread. The queue is bounded
(maxsize=512) with put_nowait; overflow is silently dropped.
"""
def __init__(self, config: TUIConfig):
self._config = config
self._queue: queue.Queue = queue.Queue(maxsize=4096)
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
self._log_handler: _TUILogHandler | None = None
self._renderer_started_early: bool = False
self._pending_run_info: dict | None = None
def _put(self, event: dict) -> None:
try:
self._queue.put_nowait(event)
except queue.Full:
pass
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Send a single unified run_info event with all fields
run_info = {
"type": "run_info",
"run_name": getattr(args, "run_name", "") or "",
"total_steps": state.max_steps,
"total_epochs": float(args.num_train_epochs)
if args.num_train_epochs
else 1.0,
}
# Merge in model_name/training_mode/world_size if stashed by train.py
if self._pending_run_info:
run_info.update(self._pending_run_info)
self._pending_run_info = None
self._put(run_info)
if not self._renderer_started_early:
# Attach a logging handler to feed log messages into the events panel
self._log_handler = _TUILogHandler(
self._queue, min_level=self._config.log_level
)
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to both root and axolotl loggers (axolotl has propagate=False)
logging.getLogger().addHandler(self._log_handler)
logging.getLogger("axolotl").addHandler(self._log_handler)
# Start the renderer background thread
self._renderer.start()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
# Filter out non-numeric keys and internal keys
filtered = {}
for key, value in logs.items():
if key.startswith("_"):
continue
if isinstance(value, (int, float)):
filtered[key] = value
elif isinstance(value, str):
# HF Trainer sometimes passes string-encoded numbers
try:
filtered[key] = float(value)
except (ValueError, TypeError):
pass
if filtered:
self._put({"type": "metrics", "logs": filtered})
def on_step_end(self, args, state, control, **kwargs):
self._put(
{
"type": "step",
"step": state.global_step,
"total_steps": state.max_steps,
"epoch": state.epoch if state.epoch else 0,
}
)
def on_prediction_step(self, args, state, control, **kwargs):
pass
def on_train_end(self, args, state, control, **kwargs):
self._put({"type": "done"})
# If renderer was started early, do_train's finally block handles stop
if not self._renderer_started_early:
self._renderer.stop()
# Remove the logging handler (only if we added it)
if self._log_handler:
logging.getLogger().removeHandler(self._log_handler)
logging.getLogger("axolotl").removeHandler(self._log_handler)
self._log_handler = None

38
src/axolotl/tui/config.py Normal file
View File

@@ -0,0 +1,38 @@
"""TUI configuration — Pydantic model for TUI settings."""
from __future__ import annotations
from pydantic import BaseModel, Field
class TUIConfig(BaseModel):
"""Configuration for the Axolotl Training TUI dashboard."""
enabled: bool = Field(
default=False,
json_schema_extra={"description": "Enable the TUI dashboard"},
)
refresh_rate: int = Field(
default=4,
json_schema_extra={"description": "Renders per second"},
)
log_level: str = Field(
default="debug",
json_schema_extra={"description": "Minimum log level shown in events panel"},
)
panels: list[str] = Field(
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
json_schema_extra={"description": "Ordered list of panels to display"},
)
hardware_poll_interval: int = Field(
default=2,
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
)
stdout_log_path: str = Field(
default="axolotl_stdout.log",
json_schema_extra={"description": "File path for captured stdout/stderr log"},
)
parser_plugins: list[str] = Field(
default_factory=list,
json_schema_extra={"description": "List of extra parser classes to load"},
)

72
src/axolotl/tui/gpu.py Normal file
View File

@@ -0,0 +1,72 @@
"""GPU polling wrapper around pynvml with graceful fallback."""
from __future__ import annotations
import logging
from axolotl.tui.state import GPUStats
LOG = logging.getLogger(__name__)
_nvml_available = False
try:
import pynvml
pynvml.nvmlInit()
_nvml_available = True
except Exception:
LOG.debug("pynvml unavailable — GPU stats will not be shown")
class GPUPoller:
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
def __init__(self):
self._device_count = 0
if _nvml_available:
try:
self._device_count = pynvml.nvmlDeviceGetCount()
except Exception:
self._device_count = 0
@property
def available(self) -> bool:
return _nvml_available and self._device_count > 0
def poll(self) -> list[GPUStats]:
if not self.available:
return []
stats = []
for i in range(self._device_count):
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle)
if isinstance(name, bytes):
name = name.decode("utf-8")
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
temp = pynvml.nvmlDeviceGetTemperature(
handle, pynvml.NVML_TEMPERATURE_GPU
)
try:
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
except Exception:
power = None
stats.append(
GPUStats(
id=i,
name=name,
util_pct=util.gpu,
vram_used_gb=mem.used / (1024**3),
vram_total_gb=mem.total / (1024**3),
temp_c=temp,
power_w=power,
)
)
except Exception:
LOG.debug("Error polling GPU device %d", i, exc_info=True)
return stats

View File

@@ -0,0 +1,196 @@
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
from __future__ import annotations
import logging
import os
import queue
import sys
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import IO
# ---------------------------------------------------------------------------
# Parser registry
# ---------------------------------------------------------------------------
_parser_registry: list[type[LineParser]] = []
def register_parser(cls: type[LineParser]) -> type[LineParser]:
"""Decorator to register a LineParser subclass."""
if cls not in _parser_registry:
_parser_registry.append(cls)
return cls
def get_registered_parsers() -> list[type[LineParser]]:
return list(_parser_registry)
# ---------------------------------------------------------------------------
# Base LineParser
# ---------------------------------------------------------------------------
class LineParser(ABC):
"""Base class for stdout/stderr line parsers."""
priority: int = 50
name: str = ""
@abstractmethod
def parse(self, line: str, source: str) -> list[dict]:
"""Parse a single captured line.
Args:
line: one line of captured output, trailing newline stripped.
source: "stdout" or "stderr".
Returns:
List of event dicts to push onto the metric queue.
Return [] if this line is not relevant.
"""
...
# ---------------------------------------------------------------------------
# ParserChain
# ---------------------------------------------------------------------------
class ParserChain:
def __init__(self):
self._parsers: list[LineParser] = []
def register(self, parser: LineParser) -> None:
self._parsers.append(parser)
self._parsers.sort(key=lambda p: p.priority)
def parse(self, line: str, source: str = "stdout") -> list[dict]:
events: list[dict] = []
for parser in self._parsers:
events.extend(parser.parse(line, source))
return events
# ---------------------------------------------------------------------------
# IOCapture — OS-level fd redirect to pipe
# ---------------------------------------------------------------------------
class IOCapture:
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
passes lines through a ParserChain, and tees to a log file."""
def __init__(
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
):
self._parser_chain = parser_chain
self._queue = metric_queue
self._log_path = log_path
self._log_file: IO[str] | None = None
self._thread: threading.Thread | None = None
self._read_fd: int | None = None
self._write_fd: int | None = None
self._saved_stdout_fd: int | None = None
self._saved_stderr_fd: int | None = None
def start(self) -> None:
# Write run-start separator
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
self._log_file.write(
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
)
self._log_file.flush()
# OS-level pipe
self._read_fd, self._write_fd = os.pipe()
# Save originals
self._saved_stdout_fd = os.dup(1)
self._saved_stderr_fd = os.dup(2)
# Redirect both stdout and stderr into the write end
os.dup2(self._write_fd, 1)
os.dup2(self._write_fd, 2)
os.close(self._write_fd) # write end now held by fds 1 and 2
# Also redirect Python-level handles
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
# Drain thread
self._thread = threading.Thread(target=self._drain, daemon=True)
self._thread.start()
def stop(self) -> None:
# Restore fds — closes the write end, causing reader to see EOF
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
os.dup2(self._saved_stdout_fd, 1)
os.dup2(self._saved_stderr_fd, 2)
os.close(self._saved_stdout_fd)
os.close(self._saved_stderr_fd)
self._saved_stdout_fd = None
self._saved_stderr_fd = None
if self._thread is not None:
self._thread.join(timeout=2.0)
if self._thread.is_alive():
logging.getLogger(__name__).warning(
"IO capture thread did not exit after 2s"
)
self._thread = None
if self._log_file is not None:
self._log_file.close()
self._log_file = None
def _drain(self) -> None:
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
# which use \r for in-place updates without \n
assert self._read_fd is not None, "_drain called before start()"
with os.fdopen(self._read_fd, "rb") as pipe:
buf = b""
while True:
chunk = pipe.read(4096)
if not chunk:
# EOF — process remaining buffer
if buf:
self._process_line(buf.decode("utf-8", errors="replace"))
break
buf += chunk
# Split on \n or \r
while b"\n" in buf or b"\r" in buf:
# Find the earliest delimiter
idx_n = buf.find(b"\n")
idx_r = buf.find(b"\r")
if idx_n == -1:
idx = idx_r
elif idx_r == -1:
idx = idx_n
else:
idx = min(idx_n, idx_r)
line = buf[:idx].decode("utf-8", errors="replace")
buf = buf[idx + 1 :]
# Handle \r\n as single delimiter
if buf.startswith(b"\n"):
buf = buf[1:]
if line:
self._process_line(line)
def _process_line(self, line: str) -> None:
line = line.rstrip()
if not line:
return
if self._log_file:
self._log_file.write(line + "\n")
self._log_file.flush()
for event in self._parser_chain.parse(line):
try:
self._queue.put_nowait(event)
except queue.Full:
pass

View File

@@ -0,0 +1,63 @@
"""Panel registry and base class for TUI panels."""
from __future__ import annotations
from abc import ABC, abstractmethod
from rich.console import RenderableType
from axolotl.tui.state import TUIState
# ---------------------------------------------------------------------------
# Panel registry
# ---------------------------------------------------------------------------
_panel_registry: dict[str, type[BasePanel]] = {}
def register_panel(position: str = "bottom", weight: int = 50):
"""Decorator to register a panel class with position and weight."""
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
cls.position = position
cls.weight = weight
_panel_registry[cls.name] = cls
return cls
return decorator
def get_registered_panels() -> dict[str, type[BasePanel]]:
return dict(_panel_registry)
# ---------------------------------------------------------------------------
# BasePanel
# ---------------------------------------------------------------------------
class BasePanel(ABC):
name: str = ""
position: str = "bottom"
weight: int = 50
min_height: int = 4
max_height: int | None = None
modes: list[str] = ["*"]
@abstractmethod
def render(self, state: TUIState) -> RenderableType:
"""Return a rich renderable. Called every tick."""
...
def on_event(self, event: dict) -> None: # noqa: B027
"""Optional: react to raw metric events before state is merged."""
pass
# Auto-import built-in panels to trigger registration
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401

View File

@@ -0,0 +1,61 @@
"""CompletionsPanel — shows recent RL/log_completions samples."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _truncate(s: str, maxlen: int = 60) -> str:
return s[:maxlen] + "" if len(s) > maxlen else s
@register_panel(position="bottom", weight=20)
class CompletionsPanel(BasePanel):
name = "completions"
min_height = 6
modes = ["grpo", "dpo"]
def render(self, state: TUIState) -> RenderableType:
if "*" not in self.modes and state.training_mode not in self.modes:
return Text("")
if not state.completions:
return Panel(
Text("No completions yet...", style="dim"),
title="Completions",
border_style="magenta",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("step", justify="right", width=6)
table.add_column("prompt", no_wrap=False, max_width=40)
table.add_column("completion", no_wrap=False, max_width=40)
table.add_column("reward", justify="right", width=8)
table.add_column("adv", justify="right", width=8)
for sample in list(state.completions)[-5:]:
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
adv_str = (
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
)
table.add_row(
str(sample.step),
_truncate(sample.prompt),
_truncate(sample.completion),
reward_str,
adv_str,
)
return Panel(table, title="Completions", border_style="magenta")

View File

@@ -0,0 +1,34 @@
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
@register_panel(position="bottom", weight=30)
class DebugPanel(BasePanel):
name = "debug"
min_height = 6
max_height = 10
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 8 debug-level log lines
debug_lines = [
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
][-8:]
for log_line in debug_lines:
ts = log_line.timestamp.strftime("%H:%M:%S")
lines.append(f"[{ts}] ", style="dim")
lines.append(log_line.message[:200], style="dim")
lines.append("\n")
if not debug_lines:
lines = Text("No debug messages yet...", style="dim")
return Panel(lines, title="Debug", border_style="dim")

View File

@@ -0,0 +1,45 @@
"""EventsPanel — scrolling log of recent events, color-coded by level."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_LEVEL_STYLES = {
"debug": "dim",
"info": "",
"warning": "yellow",
"error": "red bold",
"critical": "red bold",
}
@register_panel(position="bottom", weight=10)
class EventsPanel(BasePanel):
name = "events"
min_height = 8
max_height = 20
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 15 non-debug log lines (debug goes to DebugPanel)
recent = [
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
][-15:]
for log_line in recent:
ts = log_line.timestamp.strftime("%H:%M:%S")
level = log_line.level.upper()
style = _LEVEL_STYLES.get(log_line.level, "")
lines.append(f"[{ts}] ", style="dim")
lines.append(f"[{level}] ", style=style or "")
lines.append(log_line.message[:200], style=style or "")
lines.append("\n")
if not recent:
lines = Text("No events yet...", style="dim")
return Panel(lines, title="Events", border_style="yellow")

View File

@@ -0,0 +1,80 @@
"""HardwarePanel — per-GPU stats via pynvml."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_BAR_FULL = ""
_BAR_EMPTY = ""
def _util_bar(pct: float, width: int = 6) -> Text:
filled = int(pct / 100 * width)
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
return Text.assemble((bar, color), f" {pct:3.0f}%")
@register_panel(position="right", weight=10)
class HardwarePanel(BasePanel):
name = "hardware"
min_height = 6
def render(self, state: TUIState) -> RenderableType:
if not state.gpus:
return Panel(
Text("GPU stats unavailable", style="dim"),
title="Hardware",
border_style="green",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("id", justify="right", width=3)
table.add_column("util", no_wrap=True)
table.add_column("vram", no_wrap=True)
table.add_column("°C", justify="right", width=4)
table.add_column("W", justify="right", width=5)
total_vram_used = 0.0
total_vram_total = 0.0
total_util = 0.0
for gpu in state.gpus:
total_vram_used += gpu.vram_used_gb
total_vram_total += gpu.vram_total_gb
total_util += gpu.util_pct
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
table.add_row(
str(gpu.id),
_util_bar(gpu.util_pct),
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
str(gpu.temp_c),
power_str,
)
# Footer with aggregates
n = len(state.gpus)
if n > 1:
avg_util = total_util / n
table.add_row(
"Σ",
Text(f"avg {avg_util:.0f}%", style="dim"),
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
"",
"",
)
return Panel(table, title="Hardware", border_style="green")

View File

@@ -0,0 +1,73 @@
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
from __future__ import annotations
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, TextColumn
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _fmt_time(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "--:--:--"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
s = int(seconds) % 60
return f"{h}:{m:02d}:{s:02d}"
def _fmt_eta(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "eta --"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
if h > 0:
return f"eta {h}h{m:02d}m"
return f"eta {m}m{int(seconds) % 60:02d}s"
@register_panel(position="top", weight=10)
class ProgressPanel(BasePanel):
name = "progress"
min_height = 3
max_height = 3
def render(self, state: TUIState) -> RenderableType:
pct = (
(state.current_step / state.total_steps * 100)
if state.total_steps > 0
else 0
)
# Header line
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
header = Text.assemble(
("", "bold green"),
("AXOLOTL", "bold cyan"),
f" {mode_upper} · {model_short} ",
(
f"{state.current_step} / {state.total_steps}",
"bold",
),
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
)
# Progress bar
progress = Progress(
TextColumn(""),
BarColumn(bar_width=None),
TextColumn("{task.percentage:>3.0f}%"),
expand=True,
)
task = progress.add_task("", total=state.total_steps or 1)
progress.update(task, completed=state.current_step)
table = Table.grid(expand=True)
table.add_row(header)
table.add_row(progress)
return table

View File

@@ -0,0 +1,97 @@
"""TrainingPanel — live scalar metrics table with loss sparkline."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# Braille sparkline characters (8 levels)
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
def _sparkline(values: list[float] | None, width: int = 20) -> str:
if not values or len(values) < 2:
return ""
vals = list(values)[-width:]
lo, hi = min(vals), max(vals)
rng = hi - lo if hi != lo else 1.0
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
# Known key ordering and formatting
_KNOWN_KEYS: list[tuple[str, str, str]] = [
("loss", "loss", ".4f"),
("grad_norm", "grad norm", ".3f"),
("learning_rate", "lr", ".2e"),
("tokens_per_second", "tok/s", ".1f"),
("samples_per_second", "samples/s", ".1f"),
("mfu", "MFU", ".1f"),
# RL-specific
("rewards_mean", "rewards/mean", ".4f"),
("rewards_std", "rewards/std", ".4f"),
("kl_divergence", "KL", ".4f"),
("clip_ratio", "clip ratio", ".3f"),
("queue_size", "queue", "d"),
]
@register_panel(position="left", weight=10)
class TrainingPanel(BasePanel):
name = "training"
min_height = 8
def render(self, state: TUIState) -> RenderableType:
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("metric", style="cyan", no_wrap=True)
table.add_column("value", justify="right")
table.add_column("trend", justify="left", no_wrap=True)
for attr, label, fmt in _KNOWN_KEYS:
val = getattr(state, attr, None)
if val is None:
# Also check extra dict
val = state.extra.get(attr)
if val is None:
continue
try:
formatted = f"{val:{fmt}}"
except (ValueError, TypeError):
formatted = str(val)
trend = ""
if attr == "loss":
trend = _sparkline(list(state.loss_history))
table.add_row(label, formatted, trend)
# Any extra keys not in _KNOWN_KEYS
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
for key, val in sorted(state.extra.items()):
if key in known_attrs or val is None:
continue
try:
formatted = f"{val:.4f}"
except (ValueError, TypeError):
formatted = str(val)
table.add_row(key, formatted, "")
if table.row_count == 0:
return Panel(
Text("Waiting for first log step...", style="dim"),
title="Training",
border_style="blue",
)
return Panel(table, title="Training", border_style="blue")

View File

@@ -0,0 +1,7 @@
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401

View File

@@ -0,0 +1,29 @@
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class DeepSpeedParser(LineParser):
priority = 20
name = "deepspeed"
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
def parse(self, line: str, source: str) -> list[dict]:
events: list[dict] = []
if m := self._SAMPLES_RE.search(line):
events.append(
{
"type": "metrics",
"logs": {"samples_per_second": float(m.group(1))},
}
)
if m := self._STAGE_RE.search(line):
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
return events

View File

@@ -0,0 +1,27 @@
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class NCCLErrorParser(LineParser):
priority = 10
name = "nccl_error"
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "error",
"message": f"⚠ NCCL: {line}",
},
{"type": "alert", "severity": "error", "message": line},
]
return []

View File

@@ -0,0 +1,37 @@
"""RawLogParser — catches every line as a log_line event."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class RawLogParser(LineParser):
priority = 99
name = "raw_log"
_LOG_RE = re.compile(
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
r"\s*[-]\s*(?P<msg>.+)$",
re.IGNORECASE,
)
# Filter out tqdm progress bar lines and other noisy output
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
_EMPTY_RE = re.compile(r"^\s*$")
def parse(self, line: str, source: str) -> list[dict]:
# Skip empty lines and tqdm progress bar updates
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
return []
m = self._LOG_RE.match(line)
level = (
m.group("level").lower()
if m
else ("error" if source == "stderr" else "info")
)
return [{"type": "log_line", "level": level, "message": line}]

View File

@@ -0,0 +1,26 @@
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TorchCompileParser(LineParser):
priority = 20
name = "torch_compile"
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "warning",
"message": f"⚡ compile: {line}",
}
]
return []

View File

@@ -0,0 +1,86 @@
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TqdmParser(LineParser):
priority = 15
name = "tqdm"
# Match tqdm-style progress lines, e.g.:
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
# 0%| | 0/30 [00:00<?, ?it/s]
_TQDM_RE = re.compile(
r"(?P<desc>.*?)\s*"
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
r"\s*\[(?P<elapsed>[^\]]*)\]"
)
# Also match simpler forms like:
# Fetching 0 files: 0it [00:00, ?it/s]
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
def parse(self, line: str, source: str) -> list[dict]:
m = self._TQDM_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
pct = int(m.group("pct"))
current = int(m.group("current").replace(",", ""))
total = int(m.group("total").replace(",", ""))
events: list[dict] = []
# Surface as a log line with progress info
if pct == 100 or pct == 0 or pct % 25 == 0:
msg = (
f"[{desc}] {pct}% ({current}/{total})"
if desc
else f"{pct}% ({current}/{total})"
)
events.append(
{
"type": "log_line",
"level": "info",
"message": msg,
}
)
# Also emit as a progress metric
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "progress"
events.append(
{
"type": "metrics",
"logs": {
f"progress/{cleaned_desc}": pct / 100.0,
},
}
)
return events
# Fallback: try simpler fetch-style progress lines
m = self._FETCH_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
current = int(m.group("current"))
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "fetch"
return [
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {current}" if desc else f"{current}",
}
]
return []

449
src/axolotl/tui/renderer.py Normal file
View File

@@ -0,0 +1,449 @@
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
from __future__ import annotations
import logging
import queue
import threading
import time
from datetime import datetime
from typing import Any
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller
from axolotl.tui.io_capture import (
IOCapture,
ParserChain,
get_registered_parsers,
)
from axolotl.tui.panels import BasePanel, get_registered_panels
from axolotl.tui.state import CompletionSample, LogLine, TUIState
LOG = logging.getLogger(__name__)
class TUIRenderer:
"""Background thread that renders the TUI dashboard using rich.live.Live."""
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
self._config = config
self._queue = metric_queue
self._state = TUIState()
self._gpu_poller = GPUPoller()
self._panels: list[BasePanel] = []
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._io_capture: IOCapture | None = None
self._parser_chain: ParserChain | None = None
def _init_panels(self) -> None:
registry = get_registered_panels()
for panel_name in self._config.panels:
if panel_name in registry:
self._panels.append(registry[panel_name]())
def _init_parser_chain(self) -> None:
# Ensure built-in parsers are imported so @register_parser decorators fire
import axolotl.tui.parsers # noqa: F401
self._parser_chain = ParserChain()
# Register all built-in parsers
for parser_cls in get_registered_parsers():
self._parser_chain.register(parser_cls())
# Load plugin parsers
for plugin_spec in self._config.parser_plugins:
try:
if "::" in plugin_spec:
# file path :: class name
file_path, class_name = plugin_spec.split("::", 1)
import importlib.util
spec = importlib.util.spec_from_file_location(
"custom_parser", file_path
)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load spec for {file_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
parser_cls = getattr(mod, class_name)
else:
# dotted module path
module_path, class_name = plugin_spec.rsplit(".", 1)
mod = importlib.import_module(module_path)
parser_cls = getattr(mod, class_name)
self._parser_chain.register(parser_cls())
except Exception as exc:
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
def _build_layout(self) -> Layout:
layout = Layout()
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
sections = []
if top_panels:
layout_top = Layout(name="top", size=3)
sections.append(layout_top)
if left_panels or right_panels:
layout_middle = Layout(name="middle", ratio=3)
middle_parts = []
if left_panels:
middle_parts.append(Layout(name="left", ratio=1))
if right_panels:
middle_parts.append(Layout(name="right", ratio=1))
if middle_parts:
layout_middle.split_row(*middle_parts)
sections.append(layout_middle)
if bottom_panels:
layout_bottom = Layout(name="bottom", ratio=2)
if len(bottom_panels) > 1:
layout_bottom.split_row(
*[
Layout(name=f"bottom_{i}", ratio=1)
for i in range(len(bottom_panels))
]
)
sections.append(layout_bottom)
if sections:
layout.split_column(*sections)
return layout
def _update_layout(self, layout: Layout) -> None:
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
if top_panels:
layout["top"].update(top_panels[0].render(self._state))
if left_panels:
layout["left"].update(left_panels[0].render(self._state))
if right_panels:
layout["right"].update(right_panels[0].render(self._state))
if bottom_panels:
if len(bottom_panels) == 1:
layout["bottom"].update(bottom_panels[0].render(self._state))
else:
for i, panel in enumerate(bottom_panels):
layout[f"bottom_{i}"].update(panel.render(self._state))
def _drain_queue(self) -> None:
while True:
try:
event = self._queue.get_nowait()
except queue.Empty:
break
# Dispatch event to panels first
for panel in self._panels:
panel.on_event(event)
event_type = event.get("type")
if event_type == "metrics":
logs = event.get("logs", {})
self._apply_metrics(logs)
elif event_type == "step":
self._state.current_step = event.get("step", self._state.current_step)
self._state.total_steps = event.get(
"total_steps", self._state.total_steps
)
self._state.current_epoch = event.get(
"epoch", self._state.current_epoch
)
now = time.time()
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
if self._state.current_step > 0 and self._state.total_steps > 0:
rate = self._state.elapsed_seconds / self._state.current_step
remaining = self._state.total_steps - self._state.current_step
self._state.eta_seconds = rate * remaining
elif event_type == "log_line":
level = event.get("level", "info")
message = event.get("message", "")
self._state.log_lines.append(
LogLine(
timestamp=datetime.now(),
level=level,
message=message,
)
)
elif event_type == "completion":
self._state.completions.append(
CompletionSample(
step=event.get("step", 0),
prompt=event.get("prompt", ""),
completion=event.get("completion", ""),
reward=event.get("reward"),
advantage=event.get("advantage"),
)
)
elif event_type == "run_info":
if "run_name" in event:
self._state.run_name = event["run_name"]
if "model_name" in event:
self._state.model_name = event["model_name"]
if "training_mode" in event:
self._state.training_mode = event["training_mode"]
if "world_size" in event:
self._state.world_size = event["world_size"]
if "total_steps" in event:
self._state.total_steps = event["total_steps"]
if "total_epochs" in event:
self._state.total_epochs = event["total_epochs"]
if "zero_stage" in event:
self._state.zero_stage = event["zero_stage"]
elif event_type == "done":
self._stop_event.set()
def _apply_metrics(self, logs: dict[str, Any]) -> None:
metric_map = {
"loss": "loss",
"grad_norm": "grad_norm",
"learning_rate": "learning_rate",
"tokens_per_second": "tokens_per_second",
"samples_per_second": "samples_per_second",
"mfu": "mfu",
"rewards/mean": "rewards_mean",
"rewards_mean": "rewards_mean",
"rewards/std": "rewards_std",
"rewards_std": "rewards_std",
"kl": "kl_divergence",
"kl_divergence": "kl_divergence",
"clip_ratio": "clip_ratio",
"queue_size": "queue_size",
}
for key, value in logs.items():
if key in metric_map:
setattr(self._state, metric_map[key], value)
else:
self._state.extra[key] = value
if "loss" in logs and logs["loss"] is not None:
self._state.loss_history.append(logs["loss"])
def start(self) -> None:
self._init_panels()
self._init_parser_chain()
# Set up I/O capture
assert self._parser_chain is not None, "_init_parser_chain must be called first"
self._io_capture = IOCapture(
log_path=self._config.stdout_log_path,
parser_chain=self._parser_chain,
metric_queue=self._queue,
)
# Monkeypatch tqdm to suppress terminal output and route through our queue.
# This prevents tqdm progress bars from flickering through the TUI and
# ensures all progress events appear in the Events panel.
self._install_tqdm_hook()
self._io_capture_ready = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._io_capture_ready.wait(timeout=5.0)
def _install_tqdm_hook(self) -> None:
"""Replace tqdm's display method to route updates through TUI queue."""
try:
import io
import tqdm
import tqdm.auto
q = self._queue
self._tqdm_parser = None
# Find our tqdm parser in the chain
for p in self._parser_chain._parsers if self._parser_chain else []:
if p.name == "tqdm":
self._tqdm_parser = p
break
# Save originals for restore
self._orig_tqdm_class_auto = tqdm.auto.tqdm
self._orig_tqdm_class_tqdm = tqdm.tqdm
self._orig_tqdm_class_std = tqdm.std.tqdm
class TUITqdm(tqdm.tqdm):
"""tqdm subclass that sends progress to TUI instead of terminal."""
def __init__(self, *args, **kwargs):
# Force output to devnull so nothing reaches the terminal
kwargs["file"] = io.StringIO()
kwargs["dynamic_ncols"] = False
kwargs["ncols"] = 80
super().__init__(*args, **kwargs)
def display(self, msg=None, pos=None):
# Build a progress string and push to queue
if self.total and self.total > 0:
pct = self.n / self.total * 100
desc = self.desc.rstrip(": ") if self.desc else ""
# Emit events at milestones or at low frequency
is_milestone = (
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
)
if is_milestone:
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
if desc
else f"{pct:.0f}% ({self.n}/{self.total})",
}
)
except Exception:
pass
try:
metric_key = (
f"progress/{desc.lower().replace(' ', '_')}"
if desc
else "progress/unknown"
)
q.put_nowait(
{
"type": "metrics",
"logs": {metric_key: pct / 100.0},
}
)
except Exception:
pass
def close(self):
# Emit final completion event
if self.total and self.total > 0 and self.n > 0:
desc = self.desc.rstrip(": ") if self.desc else ""
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
if desc
else f"100% ({self.total}/{self.total}) done",
}
)
except Exception:
pass
super().close()
# Replace tqdm globally
tqdm.auto.tqdm = TUITqdm
tqdm.tqdm = TUITqdm
# Also patch tqdm.std which some libraries use directly
tqdm.std.tqdm = TUITqdm
self._tui_tqdm_cls = TUITqdm
except Exception as exc:
LOG.debug(f"Failed to install tqdm hook: {exc}")
def _uninstall_tqdm_hook(self) -> None:
"""Restore original tqdm."""
try:
import tqdm
import tqdm.auto
if hasattr(self, "_orig_tqdm_class_auto"):
tqdm.auto.tqdm = self._orig_tqdm_class_auto
if hasattr(self, "_orig_tqdm_class_tqdm"):
tqdm.tqdm = self._orig_tqdm_class_tqdm
if hasattr(self, "_orig_tqdm_class_std"):
tqdm.std.tqdm = self._orig_tqdm_class_std
except Exception:
pass
def stop(self) -> None:
self._stop_event.set()
self._uninstall_tqdm_hook()
if self._thread is not None:
self._thread.join(timeout=5.0)
def _run(self) -> None:
import os
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
# This ensures rich.live.Live writes to the terminal, not the pipe.
saved_tty_fd = os.dup(1)
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
console = Console(file=tty_file)
layout = self._build_layout()
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
gpu_poll_counter = 0
gpu_poll_ticks = max(
1, int(self._config.hardware_poll_interval / tick_interval)
)
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
if self._io_capture:
self._io_capture.start()
# Signal that IO capture is live so start() can return
if hasattr(self, "_io_capture_ready"):
self._io_capture_ready.set()
try:
with Live(
layout,
console=console,
refresh_per_second=self._config.refresh_rate,
screen=True,
redirect_stdout=False,
redirect_stderr=False,
) as live:
while not self._stop_event.is_set():
self._drain_queue()
# Poll GPU stats periodically
gpu_poll_counter += 1
if gpu_poll_counter >= gpu_poll_ticks:
gpu_poll_counter = 0
if self._gpu_poller.available:
self._state.gpus = self._gpu_poller.poll()
# Update elapsed time
self._state.elapsed_seconds = (
time.time() - self._state.start_time.timestamp()
)
self._update_layout(layout)
live.update(layout)
time.sleep(tick_interval)
# Final drain
self._drain_queue()
self._update_layout(layout)
live.update(layout)
finally:
if self._io_capture:
self._io_capture.stop()
try:
tty_file.close()
except Exception:
pass

88
src/axolotl/tui/state.py Normal file
View File

@@ -0,0 +1,88 @@
"""TUI shared data model — dataclasses for the dashboard state."""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class GPUStats:
id: int
name: str
util_pct: float
vram_used_gb: float
vram_total_gb: float
temp_c: int
power_w: float | None
@dataclass
class LogLine:
timestamp: datetime
level: str # "info" | "debug" | "warning" | "error"
message: str
@dataclass
class CompletionSample:
step: int
prompt: str
completion: str
reward: float | None
advantage: float | None
@dataclass
class TUIState:
# Run metadata
run_name: str = ""
model_name: str = ""
training_mode: str = "sft"
world_size: int = 1
start_time: datetime = field(default_factory=datetime.now)
# Progress
current_step: int = 0
total_steps: int = 0
current_epoch: float = 0.0
total_epochs: float = 1.0
elapsed_seconds: float = 0.0
eta_seconds: float | None = None
# Training metrics (rolling window + current)
loss: float | None = None
grad_norm: float | None = None
learning_rate: float | None = None
tokens_per_second: float | None = None
samples_per_second: float | None = None
mfu: float | None = None
# RL-specific (None for non-RL modes)
rewards_mean: float | None = None
rewards_std: float | None = None
kl_divergence: float | None = None
clip_ratio: float | None = None
queue_size: int | None = None
# Per-GPU hardware (list indexed by local rank)
gpus: list[GPUStats] = field(default_factory=list)
# Recent log lines
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
# Recent completions (GRPO/SFT with log_completions)
completions: deque[CompletionSample] = field(
default_factory=lambda: deque(maxlen=20)
)
# Loss history for sparkline
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
# DeepSpeed zero stage (None if not using DeepSpeed)
zero_stage: int | None = None
# Arbitrary plugin state
extra: dict[str, Any] = field(default_factory=dict)

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
"""
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin(
self,
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.__enter__()
self._profiler = profiler
def on_step_end(
self,
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end(
self,
args: TrainingArguments,
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None

View File

@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
if (
isinstance(mod, FakeQuantizedLinear)
and mod.activation_fake_quantizer is not None
and hasattr(mod.activation_fake_quantizer, "enabled")
):
mod.activation_fake_quantizer.enabled = enable
mod.weight_fake_quantizer.enabled = enable
if hasattr(mod.weight_fake_quantizer, "enabled"):
mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):

View File

@@ -12,12 +12,11 @@ from transformers import (
TrainingArguments,
)
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state.json"
class TokensPerSecondCallback(TrainerCallback):
"""

View File

@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
from axolotl.utils.schemas.datasets import (
DPODataset,
KTODataset,
SFTDataset,
SyntheticDataset,
)
LOG = get_logger(__name__)
@@ -308,6 +313,14 @@ def validate_config(
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
cfg["datasets"][idx] = SyntheticDataset(
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
)
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

View File

@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# For synthetic datasets, create a minimal placeholder instead of loading from path
if dataset_config.type == "_synthetic":
dataset = Dataset.from_dict({"text": [""]})
else:
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)
@@ -173,6 +175,70 @@ def quantize_model(
)
def _make_qat_config(
base_config: AOBaseConfig,
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None,
group_size: int | None,
) -> QATConfig:
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
group_size and other params are properly propagated (torchao's QATConfig(base_config)
does not always map these correctly)."""
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
if isinstance(base_config, MXFakeQuantizeConfig):
return QATConfig(
activation_config=base_config,
weight_config=base_config,
)
# Build explicit weight config
weight_fq_config: (
Int4WeightFakeQuantizeConfig
| IntxFakeQuantizeConfig
| Float8FakeQuantizeConfig
| None
) = None
if weight_dtype == TorchAOQuantDType.int4:
gs = (
group_size
if group_size is not None
else getattr(base_config, "group_size", 128)
)
activation_dt = None
if activation_dtype == TorchAOQuantDType.int8:
activation_dt = torch.bfloat16
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_dt = torch.float8_e4m3fn
kwargs = {"group_size": gs}
if activation_dt is not None:
kwargs["activation_dtype"] = activation_dt
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
# Build explicit activation config
activation_fq_config = None
if activation_dtype == TorchAOQuantDType.int8:
activation_fq_config = IntxFakeQuantizeConfig(
dtype=torch.int8, granularity="per_token", is_symmetric=False
)
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
if weight_fq_config is not None:
return QATConfig(
weight_config=weight_fq_config,
activation_config=activation_fq_config,
)
# Fallback to base_config for unhandled combos
return QATConfig(base_config)
def prepare_model_for_qat(
model,
weight_dtype: TorchAOQuantDType,
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype,
group_size=group_size,
)
if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
qat_config = _make_qat_config(
base_config, weight_dtype, activation_dtype, group_size
)
quantize_(model, qat_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
activation_dtype=None,
group_size=group_size,
)
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
embedding_qat_config = _make_qat_config(
embedding_base_config, weight_dtype, None, group_size
)
quantize_(
model,
embedding_qat_config,

View File

@@ -2,7 +2,7 @@
import math
from functools import partial
from typing import Sequence
from typing import Any, Sequence
from torch import Tensor
from torch.optim import Optimizer
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
return [lr * scale for lr in original]
return original * scale
def state_dict(self) -> dict[str, Any]:
"""Return serializable state, saving inner_schedule as its own state_dict."""
state = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "inner_schedule")
}
state["inner_schedule_state"] = self.inner_schedule.state_dict()
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Restore state, including inner_schedule."""
inner_state = state_dict.pop("inner_schedule_state")
self.__dict__.update(state_dict)
self.inner_schedule.load_state_dict(inner_state)

View File

@@ -13,6 +13,7 @@ from pydantic import (
model_validator,
)
from axolotl.tui.config import TUIConfig
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
@@ -22,6 +23,7 @@ from axolotl.utils.schemas.datasets import (
PretrainingDataset,
SFTDataset,
StepwiseSupervisedDataset,
SyntheticDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
@@ -139,6 +141,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
tui: TUIConfig | None = Field(
default=None,
json_schema_extra={
"description": "TUI dashboard configuration. Set enabled: true to activate."
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(
@@ -185,7 +193,13 @@ class AxolotlInputConfig(
datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None
@@ -198,7 +212,13 @@ class AxolotlInputConfig(
test_datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None
@@ -433,6 +453,12 @@ class AxolotlInputConfig(
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
)
layer_offloading: bool | None = Field(
default=False,
json_schema_extra={
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,
@@ -684,6 +710,12 @@ class AxolotlInputConfig(
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
lora_embedding_kernel: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
@@ -1294,6 +1326,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
or data.get("lora_embedding_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp_config") is not None
@@ -1341,7 +1374,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
"lora_o_kernel",
"lora_embedding_kernel",
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
@@ -1354,10 +1392,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
@@ -1379,6 +1413,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
if data.get("lora_embedding_kernel") is None:
data["lora_embedding_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "

Some files were not shown because too many files have changed in this diff Show More