Compare commits

...

20 Commits

Author SHA1 Message Date
Wing Lian
598c965043 use train_loss for sp test 2026-03-22 12:00:55 -04:00
Wing Lian
a96733930e retry and more info on download failure 2026-03-22 11:09:33 -04:00
Wing Lian
6130e40c37 fix flaky tests; should be using train loss from final step rather than final avg train loss 2026-03-22 10:38:46 -04:00
Wing Lian
5b2e3f00ce fix: handle connection errors when checking user whoami (#3529) 2026-03-22 09:11:17 -04:00
Wing Lian
fc3b3d1d4e synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
2026-03-21 22:47:26 -04:00
Wing Lian
c9df6efdc2 support offloading layers to CPU (#3512) [skip ci]
* support offloading layers to CPU

* chore: lint

* revert change

* update docs
2026-03-21 22:47:02 -04:00
Wing Lian
0ee98a0309 fix token state json and mistral tokenizer issue (#3522) [skip ci]
* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
2026-03-21 22:46:10 -04:00
Wing Lian
2c05847a5f reduce autotune search space (#3525) [skip ci]
* reduce autotune search space

* consistent docstrings
2026-03-21 18:30:15 -04:00
Wing Lian
b0294b3427 handle qwen3.5 moe loading (#3523) [skip ci] 2026-03-20 09:25:16 -04:00
Avaya Aggarwal
1bcfc08c90 feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]
* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 08:24:44 -04:00
NanoCode012
5a5cf30b26 fix: add dequant bf16 repo (#3507) [skip ci] 2026-03-20 17:11:46 +07:00
Avaya Aggarwal
7ddfb2d8a0 cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331
2026-03-20 17:10:41 +07:00
Owen Arliawan
c57acef2c7 Qwen3.5-MoE example config with lora_target_modules regex (#3515) [skip ci]
* lora target modules with regex

* updates

* fsdp for non moe

* update wording

* chore: cleanup and lint

* chore: cleanup docs from merge

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:52:46 +07:00
Lorenzo Baraldi
038ffe3f26 fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498) 2026-03-20 16:27:24 +07:00
VED
c13cb7c853 feat: add nemotron config (#3506)
* nemotron config exp

* Update examples/nemotron/nemotron-mini-4b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-03-20 16:23:42 +07:00
VED
b3823cc6b0 fix: gemma3 configs (#3500) [skip ci]
* gemma fft , text fix

* good lint
2026-03-20 16:14:06 +07:00
VED
113d275bd9 qwen docs + new config (#3499) [skip ci]
* qwen docs + new config

* docss lint

* simplify comments

* read me

* lint comments

* Update docs/multimodal.qmd

* Update docs/multimodal.qmd

* Update examples/qwen3.5/9b-fft-vision.yaml

* chore: fix link and incorrect points

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:13:34 +07:00
VED
7920fe74ec fix num_labels= 1 test fail (#3493) [skip ci]
* trl_num_lables=1

* casual num_lables=1,rwd model

* lint
2026-03-20 16:12:23 +07:00
Wing Lian
1fc86d5295 Scattermoe LoRA optimizations (#3513)
* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use udpated cce with zero3 support

* Fix decomposed backward for QKV and O projections

eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
2026-03-19 23:07:42 -04:00
Wing Lian
bb483ad4c4 make the CI fail GitHub Actions on test failures (#3517)
* make the CI fail GitHub Actions on test failures

* use model bundle

* install zstd for compressed model artifact
2026-03-19 08:29:24 -04:00
109 changed files with 3614 additions and 515 deletions

View File

@@ -128,11 +128,9 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils

View File

@@ -0,0 +1,284 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -3,11 +3,13 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \

View File

@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")

View File

@@ -37,6 +37,7 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing and Activation Offloading
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -20,6 +20,7 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -191,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,9 +1,5 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -6,9 +6,6 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
Note: Training this model requires weights in BF16 which we will link to later.
Users interested in training can convert / descale the existing FP8 weights.
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,4 +1,4 @@
base_model: mistralai/Mistral-Small-4-119B-2603
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:

View File

@@ -0,0 +1,57 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -0,0 +1,84 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
@@ -52,7 +56,6 @@ learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

View File

@@ -0,0 +1,59 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,81 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,9 +1,7 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -0,0 +1,85 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj

View File

@@ -0,0 +1,49 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,10 +1,6 @@
base_model: Qwen/Qwen3.5-7B
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
@@ -30,8 +26,6 @@ lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj

View File

@@ -1,15 +1,6 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
## Getting started
@@ -18,35 +9,69 @@ Available configs:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Pick any config from the table below and run:
```bash
axolotl train examples/qwen3.5/<config>.yaml
```
Available configs:
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
### Gated DeltaNet Linear Attention
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
```yaml
lora_target_modules:
# ... standard projections ...
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Run a finetuning example:
### Routed Experts (MoE)
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
### Shared Experts (MoE)
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
```yaml
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
```
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For inference hyp, please see the respective model card details.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
## Optimization Guides

View File

@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
)

View File

@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:

View File

@@ -3,6 +3,7 @@
import os
from pathlib import Path
import httpcore
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:
except (HTTPError, httpcore.ConnectError):
LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
)

View File

@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.layer_offloading:
training_args_kwargs["layer_offloading"] = True
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False

View File

@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (

View File

@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
LayerOffloadingMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
@@ -67,6 +67,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
LayerOffloadingMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,

View File

@@ -0,0 +1 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -2,7 +2,8 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from trl import DPOConfig
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -4,6 +4,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .layer_offloading import LayerOffloadingMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -0,0 +1,304 @@
"""
Trainer mixin for layer-wise parameter offloading to CPU.
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
transfer stream), post-hook offloads layer N-1's frozen params.
Backward: same in reverse order.
"""
import contextlib
import torch
import torch.nn as nn
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
"""Recursively search the model for the decoder layer ModuleList.
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
where layers are at model.language_model.layers).
"""
# BFS to find the first ModuleList containing decoder layers
queue = [model]
while queue:
m = queue.pop(0)
for _name, child in m.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
first_type = type(child[0]).__name__
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
layer_types = list({type(layer).__name__ for layer in child})
return child, layer_types
else:
queue.append(child)
return None, []
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
"""Get all non-trainable parameters in a layer."""
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
class LayerOffloadManager:
"""Manages offloading frozen decoder layer params to CPU and streaming
them back during forward/backward with CUDA stream overlap.
Only frozen (requires_grad=False) parameters are offloaded.
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
"""
def __init__(
self,
model: nn.Module,
num_prefetch: int = 1,
):
self.model = model
self.num_prefetch = num_prefetch
self._hooks: list = []
self._device = None
# Find decoder layers
self.layers, layer_types = _find_decoder_layers(model)
if self.layers is None:
LOG.warning(
"LayerOffloadManager: no decoder layers found, offloading disabled"
)
self.enabled = False
return
self.enabled = True
self.n_layers = len(self.layers)
LOG.info(
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
)
# Determine GPU device
for p in model.parameters():
if p.device.type == "cuda":
self._device = p.device
break
if self._device is None:
LOG.warning("LayerOffloadManager: no CUDA parameters found")
self.enabled = False
return
# Transfer stream for async prefetch
self._transfer_stream = torch.cuda.Stream(device=self._device)
# Track which layers have their frozen params on GPU
self._on_gpu: set[int] = set(range(self.n_layers))
# Cache: frozen param references per layer (list of (name, param) tuples)
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
]
# CPU storage: pinned tensors for each layer's frozen params
# Populated on first offload
self._cpu_data: list[dict[str, torch.Tensor]] = [
{} for _ in range(self.n_layers)
]
# Offload all layers upfront
self._offload_all()
# Release cached memory blocks back to the driver
torch.cuda.empty_cache()
def _offload_all(self):
"""Move all frozen params in all decoder layers to CPU."""
mem_before = torch.cuda.memory_allocated(self._device)
for i in range(self.n_layers):
self._offload_layer(i)
mem_after = torch.cuda.memory_allocated(self._device)
freed = (mem_before - mem_after) / 1e6
LOG.info(
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
f"freed {freed:.0f} MB GPU memory"
)
def _offload_layer(self, idx: int):
"""Move frozen params of layer idx to CPU pinned memory."""
if idx not in self._on_gpu:
return
for name, param in self._frozen_params[idx]:
if param.device.type != "cuda":
continue
# Allocate pinned CPU tensor on first offload
if name not in self._cpu_data[idx]:
self._cpu_data[idx][name] = torch.empty_like(
param.data, device="cpu", pin_memory=True
)
cpu_buf = self._cpu_data[idx][name]
# Async copy GPU -> CPU (on transfer stream for overlap)
cpu_buf.copy_(param.data, non_blocking=True)
# Point parameter at a dummy CPU tensor to free GPU memory
param.data = cpu_buf
self._on_gpu.discard(idx)
def _load_layer(self, idx: int, stream=None):
"""Move frozen params of layer idx back to GPU."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
ctx = (
torch.cuda.stream(stream)
if stream is not None
else contextlib.nullcontext()
)
with ctx:
for _name, param in self._frozen_params[idx]:
if param.device.type == "cuda":
continue
gpu_data = param.data.to(self._device, non_blocking=True)
param.data = gpu_data
self._on_gpu.add(idx)
def _prefetch_layer(self, idx: int):
"""Async prefetch layer idx on the transfer stream."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
self._load_layer(idx, stream=self._transfer_stream)
def _wait_transfer(self):
"""Make default stream wait for any in-flight transfers."""
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
def setup_hooks(self):
"""Register forward and backward hooks on each decoder layer."""
if not self.enabled:
return
for idx in range(self.n_layers):
layer = self.layers[idx]
def make_pre_fwd(i):
def hook(module, args):
# Ensure this layer is on GPU
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch next layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i + offset)
return hook
def make_post_fwd(i):
def hook(module, args, output):
# Offload previous layer (no longer needed in forward)
if i > 0:
self._offload_layer(i - 1)
# Offload last layer after forward
if i == self.n_layers - 1:
self._offload_layer(i)
return hook
def make_pre_bwd(i):
def hook(module, grad_output):
# Load this layer for backward
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch previous layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i - offset)
return hook
def make_post_bwd(i):
def hook(module, grad_input, grad_output):
# Offload the layer above
if i < self.n_layers - 1:
self._offload_layer(i + 1)
# Offload first layer after backward
if i == 0:
self._offload_layer(i)
return hook
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
h2 = layer.register_forward_hook(make_post_fwd(idx))
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
self._hooks.extend([h1, h2, h3, h4])
def remove_hooks(self):
"""Remove all hooks and restore layers to GPU."""
for h in self._hooks:
h.remove()
self._hooks.clear()
if self.enabled:
for i in range(self.n_layers):
if i not in self._on_gpu:
self._load_layer(i)
def pre_step(self):
"""Called before each training step — ensure layers start offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for forward
self._prefetch_layer(0)
def post_step(self):
"""Called after each training step — ensure layers are offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for next step
self._prefetch_layer(0)
class _LayerOffloadContext:
"""Context manager wrapping pre_step / post_step around a training step."""
def __init__(self, manager: LayerOffloadManager):
self.manager = manager
def __enter__(self):
self.manager.pre_step()
return self
def __exit__(self, *args):
self.manager.post_step()
class LayerOffloadingMixin(Trainer):
"""
Trainer mixin class for layer-wise parameter offloading to CPU.
Offloads frozen decoder layer params to CPU at init, then streams them
on/off GPU one layer at a time during each training step.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if getattr(self.args, "layer_offloading", False):
LOG.info("Layer parameter offloading enabled")
self._layer_offload_manager = LayerOffloadManager(
model=self.model,
num_prefetch=1,
)
self._layer_offload_manager.setup_hooks()
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
else:
self._layer_offload_manager = None
self._layer_offload_ctx = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self._layer_offload_ctx:
return super().training_step(*args, **kwargs)

View File

@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
layer_offloading: bool | None = field(
default=None,
metadata={
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
)

View File

@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
@@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
if model_type.endswith("_text"):
parent_type = model_type.removesuffix("_text")
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
module = importlib.import_module(module_path)
else:
raise
classes = []
for cls_name in cls_names:

View File

@@ -195,6 +195,36 @@ def _estimate_smem_usage(
_SMEM_SLACK = 10_000
def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# =============================================================================
# Forward Kernel: scatter2scatter with fused LoRA
# =============================================================================
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
) # [BLOCK_N, BLOCK_R]
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
# Both operands must match; cast to float32 (accumulator type) for precision.
b_f32 = b.to(tl.float32)
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
b_inp = b.to(INPUT_DTYPE)
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
acc += scaling * lora_out
return acc
@@ -327,28 +356,29 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
[32, 64, 128, 256], # BLOCK_N
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
def _prune_fwd_configs(configs, named_args, **kwargs):
"""Prune forward configs based on SMEM capacity.
"""Prune forward configs based on SMEM capacity and register pressure.
The forward kernel inner loop loads three tiles per pipeline stage:
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
smem_lora_epilogue = block_n * block_r * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_n), # acc
(block_m, block_r), # xa_acc
(block_m, block_k), # x tile
(block_k, block_n), # w tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
),
)
]
@triton.autotune(
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
X: torch.Tensor,
W: torch.Tensor,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
k: int,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
b: Optional[torch.Tensor] = None,
x_grouped: bool = False,
y_grouped: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
because the base kernel runs at full speed without LoRA SMEM overhead,
and the LoRA matmuls (R=16) are tiny separate passes.
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
scatter2scatter,
)
E = W.size(0)
R = lora_A.size(0) // E
K = W.size(1)
N = W.size(2)
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
output.add_(Y_lora, alpha=scaling)
return output
# Threshold for switching from fused to split LoRA forward.
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
def scatter2scatter_lora(
X: torch.Tensor,
W: torch.Tensor,
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args:
X: Input [M, K] or [M*k, K] if x_grouped
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
Returns:
Y: Output [M*k, N]
"""
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0)
K = W.size(1)
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
b_ptr,
stride_be,
stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A,
lora_A.stride(0),
lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B,
lora_B.stride(0),
lora_B.stride(1),
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
K=K,
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
allow_tf32=ALLOW_TF32,
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
+ (A_expert_offset + R_block)[:, None] * stride_ar
+ K_block[None, :] * stride_ak
)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
# Cast to float32 for precision
a_f32 = a_e.to(tl.float32)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
INPUT_DTYPE
)
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
acc += scaling * lora_dx
return acc
@@ -779,25 +939,26 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128} (output tile)
BLOCK_N: {32, 64} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K (output dimension)
[32, 64], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
def _prune_dX_configs(configs, named_args, **kwargs):
"""Prune backward dX configs based on SMEM capacity.
"""Prune backward dX configs based on SMEM capacity and register pressure.
The dX kernel inner loop loads three tiles per pipeline stage:
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
smem_lora_epilogue = block_r * block_k * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_k), # acc
(block_m, block_r), # dy_b_acc
(block_m, block_n), # dy tile
(block_n, block_k), # wt tile
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -1091,9 +1278,9 @@ def _group_bwd_lora_configs():
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space:
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
BLOCK_K: {32, 64, 128, 256}
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128} (token-loop tile)
BLOCK_K: {32, 64, 128}
BLOCK_N: {32, 64}
num_warps: {4, 8}
num_stages: {3, 4, 5}
@@ -1102,9 +1289,9 @@ def _group_bwd_lora_configs():
"""
configs = []
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K
[32, 64], # BLOCK_N
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
"""Prune backward configs based on SMEM capacity.
"""Prune backward configs based on SMEM capacity and register pressure.
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
smem_lora = (block_r * block_k + block_n * block_r) * 2
smem = smem_base + smem_lora
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_k), # dA_acc
(block_n, block_r), # dB_acc
(block_m, block_k), # x tile
(block_m, block_n), # dy tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity and register pressure."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
# Register pressure check
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_dim), # acc
(block_m, BLOCK_INNER), # input tile
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)

View File

@@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,

View File

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================
Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer
This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast
The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
"""Get sorted unique expert indices from the routing output.
Args:
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
E: Total number of experts
Returns:
active: Sorted unique expert indices [num_active]
"""
return torch.unique(sorted_expert_idxs)
def remap_expert_indices(
sorted_expert_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remap global expert indices to compact indices.
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
sort order. Also compacts expert_offsets to only active experts.
Args:
sorted_expert_idxs: [T*k] expert ids in sorted order
expert_offsets: [E] cumulative token counts (original)
active_experts: [num_active] sorted unique expert ids
E: Total number of experts
Returns:
remapped_idxs: [T*k] expert ids in [0..num_active-1]
compact_offsets: [num_active] cumulative token counts
"""
# Build remap table: global_id -> compact_id
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
remap[active_experts] = torch.arange(
len(active_experts), device=sorted_expert_idxs.device
)
remapped_idxs = remap[sorted_expert_idxs]
# Compact the expert_offsets: only keep active experts' cumulative counts
compact_offsets = expert_offsets[active_experts]
return remapped_idxs, compact_offsets
def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
The raw parameter is a flattened 4-bit packed tensor. Each expert's
data is contiguous (stored in expert-major order), so we can gather
the packed data and absmax blocks for active experts, then dequantize
as one contiguous block.
Args:
raw_param: Flattened uint8 tensor of packed 4-bit weights
quant_state: BnB QuantState with absmax, blocksize, code, etc.
active_experts: [num_active] expert indices to dequantize
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
Returns:
Dequantized weights [num_active, dim1, dim2] in original dtype
"""
import bitsandbytes.functional as F # noqa: N812
from bitsandbytes.functional import QuantState
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
blocks_per_expert = expert_numel // quant_state.blocksize
num_active = len(active_experts)
if blocks_per_expert == 0:
# Expert is smaller than one quantization block — blocks span across
# expert boundaries, so per-expert slicing isn't possible.
# Fallback: full dequantize + index.
full = F.dequantize_4bit(raw_param, quant_state)
E_total = full.numel() // expert_numel
return full.reshape(E_total, *expert_shape)[active_experts]
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
selective_dequant_nf4_triton,
)
# Handle nested (double) quantization: dequantize absmax first
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
if quant_state.nested:
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
else:
absmax = quant_state.absmax
return selective_dequant_nf4_triton(
packed_data=raw_param,
absmax=absmax,
active_experts=active_experts,
expert_shape=expert_shape,
blocksize=quant_state.blocksize,
dtype=quant_state.dtype,
codebook=quant_state.code,
)
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
raw_flat = raw_param.reshape(-1)
offsets_qt = (
active_experts.long()[:, None] * packed_per_expert
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
qt_gathered = raw_flat[offsets_qt]
offsets_abs = (
active_experts.long()[:, None] * blocks_per_expert
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
if quant_state.nested:
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
full_absmax += quant_state.offset
if full_absmax.dtype != torch.float32:
full_absmax = full_absmax.float()
absmax_gathered = full_absmax[offsets_abs]
else:
absmax_gathered = quant_state.absmax[offsets_abs]
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
gathered_qs = QuantState(
absmax=absmax_gathered,
shape=torch.Size([num_active * expert_numel]),
blocksize=quant_state.blocksize,
quant_type=quant_state.quant_type,
code=quant_state.code,
dtype=quant_state.dtype,
)
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
return deq.reshape(num_active, *expert_shape)
def _selective_index_dense(
param: torch.Tensor,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Select experts from a dense (bf16/fp32) weight tensor.
Simple indexing — no dequantization needed.
"""
return param[active_experts]
def selective_expert_weights(
experts_module: nn.Module,
param_name: str,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Extract and dequantize only the active experts' weights.
Format-agnostic: dispatches based on whether the parameter is
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
Args:
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
param_name: "gate_up_proj" or "down_proj"
active_experts: [num_active] sorted unique expert indices
Returns:
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
"""
# Check if the parameter is BnB-quantized via parametrize
if (
hasattr(experts_module, "parametrizations")
and param_name in experts_module.parametrizations
):
param_list = experts_module.parametrizations[param_name]
parametrization = param_list[0]
# BnB 4-bit parametrization
if hasattr(parametrization, "quant_state"):
# The raw quantized data is on the ParametrizationList, not the
# individual Bnb4bitParametrization module
raw_param = param_list.original
qs = parametrization.quant_state
# qs.shape is the original tensor shape before flattening.
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
orig_shape = qs.shape
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
expert_shape = (orig_shape[1], orig_shape[2])
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
# Flattened — need to infer from module attributes
E_total = getattr(experts_module, "num_experts", None)
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
full = getattr(experts_module, param_name)
return full[active_experts]
else:
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
if param.dim() == 3:
return param[active_experts]
# Fallback: full access
return param
def selective_lora_weights(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select LoRA A and B weights for only the active experts.
LoRA layout (scattermoe format):
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
Returns compact:
A: [r*num_active, K]
B: [N, r*num_active]
"""
R = lora_A.size(0) // E
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
row_idx = (
active_experts.long()[:, None] * R
+ torch.arange(R, device=lora_A.device)[None, :]
).reshape(-1)
compact_A = lora_A[row_idx] # [r*num_active, K]
compact_B = lora_B[:, row_idx] # [N, r*num_active]
return compact_A, compact_B

View File

@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.
Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel
This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly
This eliminates the intermediate gather buffer entirely.
"""
import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
@triton.jit
def _selective_dequant_nf4_kernel(
# Input: packed NF4 data (flattened, expert-major order)
packed_ptr,
# Input: absmax values (flattened, expert-major order)
absmax_ptr,
# Input: active expert indices
active_experts_ptr,
# Input: NF4 codebook (16 float values)
codebook_ptr,
# Output: dequantized bf16 weights [num_active, expert_numel]
out_ptr,
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
):
"""
Each program processes BLOCK_SIZE elements from one expert.
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
For each output element:
1. Compute which byte in packed data contains this element
2. Extract the 4-bit nibble (high or low)
3. Look up in NF4 codebook
4. Scale by absmax for this block
"""
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
block_id = tl.program_id(1) # which element block
# Load the global expert index
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = elem_offset < expert_numel
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
byte_idx = elem_offset // 2
is_high = (elem_offset % 2) == 1
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
# NF4 codebook lookup
# Load all 16 codebook values (small, fits in registers)
# Use gather from codebook pointer
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
# Load absmax for this element's quantization block
block_idx = elem_offset // blocksize
absmax_global_offset = expert_global * blocks_per_expert + block_idx
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
# Dequantize: value = codebook[nibble] * absmax
result = code_val * absmax_val
# Store to output
out_offset = expert_local_idx * stride_out_e + elem_offset
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
def selective_dequant_nf4_triton(
packed_data: torch.Tensor,
absmax: torch.Tensor,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
blocksize: int,
dtype: torch.dtype = torch.bfloat16,
codebook: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fused selective gather + NF4 dequantization via Triton kernel.
Args:
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
absmax: Per-block scaling factors [total_blocks]
active_experts: Sorted indices of experts to dequantize [num_active]
expert_shape: (dim1, dim2) per expert
blocksize: Quantization block size
dtype: Output dtype (default bf16)
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
Returns:
Dequantized weights [num_active, dim1, dim2]
"""
num_active = active_experts.shape[0]
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2
blocks_per_expert = expert_numel // blocksize
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
# Flatten inputs
packed_flat = packed_data.reshape(-1)
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
_selective_dequant_nf4_kernel[grid](
packed_flat,
absmax_flat,
active_experts,
codebook,
out,
out.stride(0),
num_active=num_active,
packed_per_expert=packed_per_expert,
blocks_per_expert=blocks_per_expert,
blocksize=blocksize,
BLOCK_SIZE=BLOCK_SIZE,
)
return out.reshape(num_active, *expert_shape)

View File

@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()

View File

@@ -640,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
# This is 65x fewer FLOPs than materializing B@A into [out, in]
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
# K path
k_weight_t = dequantize(k_weight, k_quant)
@@ -648,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
# V path
v_weight_t = dequantize(v_weight, v_quant)
@@ -656,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
# Transpose gradients if needed
if d_A_q is not None:
@@ -819,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
del W
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
dX.addmm_(torch.mm(dY, B), A, alpha=s)
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None

View File

@@ -505,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"

View File

@@ -571,15 +571,6 @@ class PatchManager:
LOG.info("Patching with xformers attention...")
hijack_llama_attention()
def _patch_llama_sample_packing(self):
"""Apply sample packing patches for LLaMA models."""
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
hijack_llama_prepare_4d_mask()
def _patch_llama_derived_model(self):
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
@@ -591,8 +582,6 @@ class PatchManager:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
self._patch_llama_sample_packing()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."

View File

@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()

View File

@@ -78,30 +78,21 @@ def patch_parallelism_config():
def patch_prepare_cp():
import functools
import contextlib
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -1,24 +0,0 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = _expand_mask

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)

View File

@@ -51,6 +51,29 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(

View File

@@ -59,6 +59,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"ministral3",
"mistral4",
"afmoe",
"nemotron",
]

View File

@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
"""
import re
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)

View File

@@ -0,0 +1,96 @@
"""
Synthetic dataset generator for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
Useful for benchmarking memory usage and speed by sequence length, and for validating
weighted dataset mixes.
YAML configuration example:
datasets:
- path: synthetic
type: _synthetic
length: 1000
sequence_length: 2048
min_input_id: 100
max_input_id: 32000
seed: 42
"""
from typing import Any, Dict, Optional
import numpy as np
from datasets import Dataset
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
def __init__(
self,
sequence_length: int = 2048,
length: int = 1000,
min_input_id: int = 100,
max_input_id: int = 32000,
seed: Optional[int] = None,
):
self.sequence_length = sequence_length
self.length = length
self.min_input_id = min_input_id
self.max_input_id = max_input_id
self.seed = seed
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
LOG.info(
f"Generating synthetic dataset: {self.length} samples, "
f"sequence_length={self.sequence_length}, "
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
)
rng = np.random.default_rng(self.seed)
input_ids = rng.integers(
low=self.min_input_id,
high=self.max_input_id,
size=(self.length, self.sequence_length),
).tolist()
attention_mask = [[1] * self.sequence_length] * self.length
# labels == input_ids means we train on all tokens
labels = [row[:] for row in input_ids]
return Dataset.from_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
length = ds_cfg.get("length", 1000)
min_input_id = ds_cfg.get("min_input_id", 100)
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
seed = ds_cfg.get("seed", None)
return SyntheticDatasetStrategy(
sequence_length=sequence_length,
length=length,
min_input_id=min_input_id,
max_input_id=max_input_id,
seed=seed,
)

View File

@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if model.generation_config is not None:
if getattr(model, "generation_config", None) is not None:
model.generation_config.do_sample = True
model_properties = model.config.to_dict()

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
"""
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin(
self,
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.__enter__()
self._profiler = profiler
def on_step_end(
self,
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end(
self,
args: TrainingArguments,
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None

View File

@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
if (
isinstance(mod, FakeQuantizedLinear)
and mod.activation_fake_quantizer is not None
and hasattr(mod.activation_fake_quantizer, "enabled")
):
mod.activation_fake_quantizer.enabled = enable
mod.weight_fake_quantizer.enabled = enable
if hasattr(mod.weight_fake_quantizer, "enabled"):
mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):

View File

@@ -12,12 +12,11 @@ from transformers import (
TrainingArguments,
)
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state.json"
class TokensPerSecondCallback(TrainerCallback):
"""

View File

@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
from axolotl.utils.schemas.datasets import (
DPODataset,
KTODataset,
SFTDataset,
SyntheticDataset,
)
LOG = get_logger(__name__)
@@ -308,6 +313,14 @@ def validate_config(
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
cfg["datasets"][idx] = SyntheticDataset(
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
)
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

View File

@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# For synthetic datasets, create a minimal placeholder instead of loading from path
if dataset_config.type == "_synthetic":
dataset = Dataset.from_dict({"text": [""]})
else:
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)
@@ -173,6 +175,70 @@ def quantize_model(
)
def _make_qat_config(
base_config: AOBaseConfig,
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None,
group_size: int | None,
) -> QATConfig:
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
group_size and other params are properly propagated (torchao's QATConfig(base_config)
does not always map these correctly)."""
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
if isinstance(base_config, MXFakeQuantizeConfig):
return QATConfig(
activation_config=base_config,
weight_config=base_config,
)
# Build explicit weight config
weight_fq_config: (
Int4WeightFakeQuantizeConfig
| IntxFakeQuantizeConfig
| Float8FakeQuantizeConfig
| None
) = None
if weight_dtype == TorchAOQuantDType.int4:
gs = (
group_size
if group_size is not None
else getattr(base_config, "group_size", 128)
)
activation_dt = None
if activation_dtype == TorchAOQuantDType.int8:
activation_dt = torch.bfloat16
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_dt = torch.float8_e4m3fn
kwargs = {"group_size": gs}
if activation_dt is not None:
kwargs["activation_dtype"] = activation_dt
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
# Build explicit activation config
activation_fq_config = None
if activation_dtype == TorchAOQuantDType.int8:
activation_fq_config = IntxFakeQuantizeConfig(
dtype=torch.int8, granularity="per_token", is_symmetric=False
)
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
if weight_fq_config is not None:
return QATConfig(
weight_config=weight_fq_config,
activation_config=activation_fq_config,
)
# Fallback to base_config for unhandled combos
return QATConfig(base_config)
def prepare_model_for_qat(
model,
weight_dtype: TorchAOQuantDType,
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype,
group_size=group_size,
)
if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
qat_config = _make_qat_config(
base_config, weight_dtype, activation_dtype, group_size
)
quantize_(model, qat_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
activation_dtype=None,
group_size=group_size,
)
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
embedding_qat_config = _make_qat_config(
embedding_base_config, weight_dtype, None, group_size
)
quantize_(
model,
embedding_qat_config,

View File

@@ -2,7 +2,7 @@
import math
from functools import partial
from typing import Sequence
from typing import Any, Sequence
from torch import Tensor
from torch.optim import Optimizer
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
return [lr * scale for lr in original]
return original * scale
def state_dict(self) -> dict[str, Any]:
"""Return serializable state, saving inner_schedule as its own state_dict."""
state = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "inner_schedule")
}
state["inner_schedule_state"] = self.inner_schedule.state_dict()
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Restore state, including inner_schedule."""
inner_state = state_dict.pop("inner_schedule_state")
self.__dict__.update(state_dict)
self.inner_schedule.load_state_dict(inner_state)

View File

@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
PretrainingDataset,
SFTDataset,
StepwiseSupervisedDataset,
SyntheticDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
test_datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None
@@ -433,6 +446,12 @@ class AxolotlInputConfig(
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
)
layer_offloading: bool | None = Field(
default=False,
json_schema_extra={
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,

View File

@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
revision: str | None = None
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
class SyntheticDataset(BaseModel):
"""Synthetic dataset configuration for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
validating weighted dataset mixes.
"""
path: Literal["synthetic"] = "synthetic"
type: Literal["_synthetic"] = "_synthetic"
length: int = Field(
default=1000,
json_schema_extra={"description": "Number of rows to generate"},
)
sequence_length: int | None = Field(
default=None,
json_schema_extra={
"description": "Sequence length per row (defaults to sequence_len from config)"
},
)
min_input_id: int = Field(
default=100,
json_schema_extra={"description": "Minimum token ID for generation"},
)
max_input_id: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
},
)
seed: int | None = Field(
default=None,
json_schema_extra={"description": "Random seed for reproducibility"},
)
DatasetConfig = (
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
)

View File

@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
flash_adamw = "flash_adamw"
flash_adam = "flash_adam"
flash_sgd = "flash_sgd"
flash_sgdw = "flash_sgdw"
flash_lion = "flash_lion"
class RingAttnFunc(str, Enum):

View File

@@ -253,6 +253,23 @@ class TrainingValidationMixin:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@classmethod
def set_reward_model_defaults(cls, data):
if data.get("reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 1
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForSequenceClassification"
if data.get("process_reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 2
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForTokenClassification"
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -773,6 +790,14 @@ class OptimizationValidationMixin:
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@staticmethod
def _resolve_fsdp_version(data):
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
return fsdp_version
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
@@ -782,15 +807,32 @@ class OptimizationValidationMixin:
"Muon optimizer is currently incompatible with DeepSpeed"
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
fsdp_version = cls._resolve_fsdp_version(data)
if str(fsdp_version) != "2":
raise ValueError(
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_flashoptim_deepspeed_fsdp(cls, data):
optimizer = data.get("optimizer") or ""
if str(optimizer).startswith("flash_"):
if data.get("deepspeed"):
raise ValueError(
f"{optimizer} optimizer is incompatible with DeepSpeed. "
"Flash optimizers only support DDP and FSDP2."
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = cls._resolve_fsdp_version(data)
if str(fsdp_version) != "2":
raise ValueError(
f"{optimizer} optimizer is only compatible with FSDP2. "
"Set fsdp_version: 2 to use flash optimizers with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):

View File

@@ -15,6 +15,8 @@ import datasets
import pytest
import requests
import torch
import transformers.utils as _transformers_utils
import transformers.utils.import_utils as _import_utils
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
@@ -29,6 +31,26 @@ from tests.hf_offline_utils import (
logging.getLogger("filelock").setLevel(logging.CRITICAL)
# Shim for deepseek v3
if not hasattr(_import_utils, "is_torch_fx_available"):
def _is_torch_fx_available():
try:
import torch.fx # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
_import_utils.is_torch_fx_available = _is_torch_fx_available
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
_is_flash_attn_gte("2.10")
)
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):

View File

@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
"cfg_string",
[
"sft_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
"rm_cfg",
"prm_cfg",
],
)

View File

@@ -20,6 +20,7 @@ Test strategy:
- Tolerances account for tf32 accumulation in Triton kernels
"""
from functools import wraps
from types import SimpleNamespace
import pytest
@@ -34,6 +35,21 @@ pytestmark = pytest.mark.skipif(
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
def skip_on_out_of_resources(func):
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
if "OutOfResources" in type(exc).__name__:
pytest.skip(f"GPU shared memory too small: {exc}")
raise
return wrapper
# =============================================================================
# Helpers
# =============================================================================
@@ -209,6 +225,7 @@ def make_test_data(
# =============================================================================
@pytest.mark.slow
class TestForwardPass:
"""Test forward pass of fused scatter2scatter_lora kernel."""
@@ -288,6 +305,7 @@ class TestForwardPass:
)
@pytest.mark.slow
class TestForwardGrouped:
"""Test forward pass with grouped_in/grouped_out configurations."""
@@ -377,6 +395,7 @@ class TestForwardGrouped:
# =============================================================================
@pytest.mark.slow
class TestLoRAGradients:
"""Test backward LoRA gradient computation (dA, dB)."""
@@ -452,6 +471,7 @@ class TestLoRAGradients:
# =============================================================================
@pytest.mark.slow
class TestAutograd:
"""Test full autograd integration through ScatterMoELoRA."""
@@ -620,6 +640,7 @@ class TestAutograd:
# =============================================================================
@pytest.mark.slow
class TestBaseEquivalence:
"""When scaling=0, fused kernel should match base scatter2scatter."""
@@ -692,6 +713,7 @@ class TestBaseEquivalence:
# =============================================================================
@pytest.mark.slow
class TestLoRAAdditivity:
"""Test that the LoRA component is correctly additive."""
@@ -749,6 +771,7 @@ class TestLoRAAdditivity:
# =============================================================================
@pytest.mark.slow
class TestParallelExpertsModule:
"""Test the ParallelExperts module with LoRA."""
@@ -816,6 +839,7 @@ class TestParallelExpertsModule:
# =============================================================================
@pytest.mark.slow
class TestEdgeCases:
"""Edge cases and boundary conditions."""
@@ -913,6 +937,7 @@ class TestEdgeCases:
# =============================================================================
@pytest.mark.slow
class TestFusedDX:
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
@@ -980,6 +1005,7 @@ class TestFusedDX:
def test_basic(self):
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
@skip_on_out_of_resources
def test_large(self):
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
@@ -1122,6 +1148,7 @@ class TestFusedDX:
# =============================================================================
@pytest.mark.slow
class TestFusedGatherBackward:
"""Test fused gather + backward dA/dB kernel."""
@@ -1174,6 +1201,7 @@ class TestFusedGatherBackward:
def test_basic(self):
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
@skip_on_out_of_resources
def test_large(self):
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
@@ -1183,6 +1211,7 @@ class TestFusedGatherBackward:
def test_k1(self):
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
@skip_on_out_of_resources
def test_many_experts(self):
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
@@ -1269,6 +1298,8 @@ class TestFusedGatherBackward:
# =============================================================================
@pytest.mark.slow
@pytest.mark.xfail(reason="flaky", strict=False)
class TestTokenRounding:
"""Test token rounding utility and its integration with backward kernels."""
@@ -1315,6 +1346,7 @@ class TestTokenRounding:
)
prev = padded_offsets[e].item()
@skip_on_out_of_resources
def test_round_with_fused_gather(self):
"""Token rounding + fused gather gives same result as plain fused gather."""
from importlib import import_module
@@ -1414,6 +1446,7 @@ class TestTokenRounding:
# =============================================================================
@pytest.mark.slow
class TestCombinedOptimizations:
"""Test all optimizations together."""
@@ -1583,6 +1616,7 @@ def _make_mock_sigmoid_moe_block(
return moe_block, T, H, FF, E, K
@pytest.mark.slow
class TestHFScatterMoESigmoidRouting:
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
@@ -1724,6 +1758,7 @@ class TestHFScatterMoESigmoidRouting:
)
@pytest.mark.slow
class TestHFScatterMoESigmoidWithSharedExperts:
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""

View File

@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
def _get_repo_path():
"""Get the path to scattermoe_lora within axolotl's plugin."""
return (
Path(__file__).parent.parent.parent
Path(__file__).parent.parent.parent.parent
/ "src"
/ "axolotl"
/ "integrations"
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
# Kernelize
repo_path = (
Path(__file__).parent.parent.parent
Path(__file__).parent.parent.parent.parent
/ "src"
/ "axolotl"
/ "integrations"

View File

@@ -86,5 +86,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -37,7 +37,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
train_loss_df = df[df.tag == "train/loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -37,7 +37,7 @@ def verify_fp8_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
train_loss_df = df[df.tag == "train/loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
train_loss_df = df[df.tag == "train/loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
train_loss_df = df[df.tag == "train/loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -94,5 +94,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.8, "Train Loss (%s) is too high"
)

View File

@@ -90,7 +90,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.8, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -156,7 +156,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
@@ -233,7 +233,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
"train/loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -312,7 +312,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
"train/loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -385,7 +385,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -461,7 +461,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_6_0
@@ -543,7 +543,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -623,7 +623,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -708,7 +708,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.45, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -784,7 +784,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -859,7 +859,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
)
@pytest.mark.skip(
@@ -925,5 +925,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 4.0, "Train Loss (%s) is too high"
)

View File

@@ -79,7 +79,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -138,7 +138,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -205,5 +205,5 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -64,5 +64,5 @@ class TestTensorParallel:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss (%s) is too high"
)

View File

@@ -78,5 +78,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -77,5 +77,5 @@ class TestFAFlattening:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -4,8 +4,7 @@ E2E tests for lora llama
import unittest
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
def test_lora_gptq_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"max_steps": 20,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.train import train
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

View File

@@ -73,7 +73,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -124,7 +124,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -180,5 +180,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -14,6 +14,9 @@ from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@pytest.mark.skip(
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
)
class TestDeepseekV3:
"""
Test case for DeepseekV3 models

View File

@@ -262,6 +262,7 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
@with_temp_dir
def test_orpo_lora(self, temp_dir):
cfg = DictDefault(

View File

@@ -57,9 +57,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
@@ -100,6 +98,4 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")

View File

@@ -66,7 +66,7 @@ class TestPretrainLlama:
loss_threshold = 6.5
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
"train/loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -282,3 +284,60 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@require_torch_2_7_0
@pytest.mark.parametrize(
"optimizer_name,expected_class,learning_rate",
[
("flash_adamw", "FlashAdamW", 0.00001),
("flash_adam", "FlashAdam", 0.00001),
("flash_sgd", "FlashSGD", 0.01),
("flash_sgdw", "FlashSGDW", 0.01),
("flash_lion", "FlashLion", 0.0001),
],
)
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
pytest.importorskip("flashoptim")
temp_dir = str(tmp_path)
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": learning_rate,
"optimizer": optimizer_name,
"max_steps": 5,
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class

View File

@@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
)

Some files were not shown because too many files have changed in this diff Show More