Compare commits
22 Commits
scattermoe
...
scattermoe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
936149380f | ||
|
|
86be9f329e | ||
|
|
0e583efeaa | ||
|
|
b3289fd190 | ||
|
|
a67392c427 | ||
|
|
5b2e3f00ce | ||
|
|
fc3b3d1d4e | ||
|
|
c9df6efdc2 | ||
|
|
0ee98a0309 | ||
|
|
2c05847a5f | ||
|
|
b0294b3427 | ||
|
|
1bcfc08c90 | ||
|
|
5a5cf30b26 | ||
|
|
7ddfb2d8a0 | ||
|
|
c57acef2c7 | ||
|
|
038ffe3f26 | ||
|
|
c13cb7c853 | ||
|
|
b3823cc6b0 | ||
|
|
113d275bd9 | ||
|
|
7920fe74ec | ||
|
|
1fc86d5295 | ||
|
|
bb483ad4c4 |
@@ -128,11 +128,9 @@ quartodoc:
|
||||
- monkeypatch.mistral_attn_hijack_flash
|
||||
- monkeypatch.multipack
|
||||
- monkeypatch.relora
|
||||
- monkeypatch.llama_expand_mask
|
||||
- monkeypatch.lora_kernels
|
||||
- monkeypatch.utils
|
||||
- monkeypatch.btlm_attn_hijack_flash
|
||||
- monkeypatch.llama_patch_multipack
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
|
||||
284
benchmarks/bench_scattermoe_lora.py
Normal file
284
benchmarks/bench_scattermoe_lora.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
12
cicd/cicd.sh
12
cicd/cicd.sh
@@ -3,11 +3,13 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
hf download "NousResearch/Meta-Llama-3-8B"
|
||||
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
hf download "microsoft/Phi-4-reasoning"
|
||||
hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
set -o pipefail
|
||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
# hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
# hf download "microsoft/Phi-3-medium-128k-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
|
||||
@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
|
||||
@@ -37,6 +37,7 @@ coverage:
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
informational: true
|
||||
|
||||
parsers:
|
||||
gcov:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Gradient Checkpointing and Activation Offloading
|
||||
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
|
||||
|
||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
|
||||
### Enabling Layer Offloading
|
||||
|
||||
```yaml
|
||||
layer_offloading: true
|
||||
```
|
||||
|
||||
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
||||
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
||||
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
||||
trainable adapter weights stay on GPU permanently.
|
||||
|
||||
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
||||
|
||||
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
||||
prefetched asynchronously on a separate CUDA stream for overlap.
|
||||
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
||||
previous layer is prefetched.
|
||||
|
||||
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
||||
|
||||
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
||||
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
||||
that is kept on GPU at any given time.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- CUDA GPU (CPU-only training is not supported for this feature)
|
||||
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
||||
- Best combined with LoRA/QLoRA where most parameters are frozen
|
||||
|
||||
@@ -20,6 +20,7 @@ format:
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [Qwen3.5](#sec-qwen3-5)
|
||||
- [GLM-4.6V](#sec-glm-4-6v)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
@@ -191,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3.5 {#sec-qwen3-5}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
|
||||
chat_template: qwen3_5
|
||||
```
|
||||
|
||||
### GLM-4.6V {#sec-glm-4-6v}
|
||||
|
||||
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
|
||||
|
||||
@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Layer Offloading
|
||||
|
||||
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
|
||||
|
||||
- **Config:** `layer_offloading: true`
|
||||
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -27,6 +24,11 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -27,6 +24,11 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
|
||||
@@ -6,9 +6,6 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
Note: Training this model requires weights in BF16 which we will link to later.
|
||||
Users interested in training can convert / descale the existing FP8 weights.
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
|
||||
57
examples/nemotron/nemotron-mini-4b-qlora.yaml
Normal file
57
examples/nemotron/nemotron-mini-4b-qlora.yaml
Normal file
@@ -0,0 +1,57 @@
|
||||
base_model: nvidia/Nemotron-Mini-4B-Instruct
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/nemotron-mini-4b-qlora
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
special_tokens:
|
||||
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
84
examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,84 @@
|
||||
base_model: Qwen/Qwen3.5-122B-A10B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
#lora_target_parameters:
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
@@ -52,7 +56,6 @@ learning_rate: 0.0002
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
59
examples/qwen3.5/27b-fft.yaml
Normal file
59
examples/qwen3.5/27b-fft.yaml
Normal file
@@ -0,0 +1,59 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
# Freeze vision encoder
|
||||
unfrozen_parameters:
|
||||
- model\.language_model\..*
|
||||
- lm_head\..*
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
81
examples/qwen3.5/27b-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
# Uncomment below to also target the linear attention projections.
|
||||
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||
# - linear_attn.in_proj_qkv
|
||||
# - linear_attn.in_proj_z
|
||||
# - linear_attn.out_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,9 +1,7 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
|
||||
# the text-only path. For multimodal (image+text) fine-tuning, add image
|
||||
# columns to your dataset following axolotl's multimodal dataset format.
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
85
examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
base_model: Qwen/Qwen3.5-35B-A3B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -32,7 +32,11 @@ lora_target_modules:
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
#lora_target_parameters:
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
|
||||
49
examples/qwen3.5/9b-fft-vision.yaml
Normal file
49
examples/qwen3.5/9b-fft-vision.yaml
Normal file
@@ -0,0 +1,49 @@
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Required for multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,10 +1,6 @@
|
||||
base_model: Qwen/Qwen3.5-7B
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
|
||||
# Vision and text tokens are processed together by the same transformer layers.
|
||||
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
|
||||
|
||||
# These 3 lines are required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
@@ -30,8 +26,6 @@ lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
# Targets the language model attention and MLP layers.
|
||||
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
|
||||
# the same transformer stack, so standard attention targets work for both modalities.
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
@@ -1,15 +1,6 @@
|
||||
# Finetune Qwen3.5 with Axolotl
|
||||
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type |
|
||||
|---|---|---|
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
|
||||
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -18,35 +9,69 @@ Available configs:
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
4. Pick any config from the table below and run:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3.5/<config>.yaml
|
||||
```
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type | Peak VRAM |
|
||||
|---|---|---|---|
|
||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
|
||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
|
||||
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
|
||||
### Gated DeltaNet Linear Attention
|
||||
|
||||
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
|
||||
|
||||
```yaml
|
||||
lora_target_modules:
|
||||
# ... standard projections ...
|
||||
- linear_attn.in_proj_qkv
|
||||
- linear_attn.in_proj_z
|
||||
- linear_attn.out_proj
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
4. Run a finetuning example:
|
||||
### Routed Experts (MoE)
|
||||
|
||||
```bash
|
||||
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
|
||||
axolotl train examples/qwen3.5/27b-qlora.yaml
|
||||
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
|
||||
|
||||
# MoE 35B-A3B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
# MoE 122B-A10B text-only (QLoRA)
|
||||
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
|
||||
### Shared Experts (MoE)
|
||||
|
||||
# 7B vision+text (LoRA, multimodal dataset)
|
||||
axolotl train examples/qwen3.5/7b-lora-vision.yaml
|
||||
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
|
||||
|
||||
```yaml
|
||||
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- For inference hyp, please see the respective model card details.
|
||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
|
||||
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
|
||||
@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-m 'not slow'"
|
||||
markers = [
|
||||
"slow: marks tests as slow",
|
||||
]
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||
)
|
||||
|
||||
13
setup.py
13
setup.py
@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 9):
|
||||
if (major, minor) >= (2, 10):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.5.0",
|
||||
"fbgemm-gpu-genai==1.5.0",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
||||
elif (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpcore
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except HTTPError:
|
||||
except (HTTPError, httpcore.ConnectError):
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
|
||||
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adamw":
|
||||
from flashoptim import FlashAdamW
|
||||
|
||||
optimizer_cls = FlashAdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adam":
|
||||
from flashoptim import FlashAdam
|
||||
|
||||
optimizer_cls = FlashAdam
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_sgd":
|
||||
from flashoptim import FlashSGD
|
||||
|
||||
optimizer_cls = FlashSGD
|
||||
elif self.cfg.optimizer == "flash_sgdw":
|
||||
from flashoptim import FlashSGDW
|
||||
|
||||
optimizer_cls = FlashSGDW
|
||||
elif self.cfg.optimizer == "flash_lion":
|
||||
from flashoptim import FlashLion
|
||||
|
||||
optimizer_cls = FlashLion
|
||||
if "betas" in adam_kwargs:
|
||||
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.layer_offloading:
|
||||
training_args_kwargs["layer_offloading"] = True
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
|
||||
@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
|
||||
# config reflects this regardless of how the model was instantiated.
|
||||
if (
|
||||
self.cfg.reward_model
|
||||
and getattr(self.model.config, "num_labels", None) != 1
|
||||
):
|
||||
self.model.config.num_labels = 1
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
|
||||
@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and self.peft_config
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
||||
):
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
|
||||
@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||
from trl.experimental.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
LayerOffloadingMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -67,6 +67,7 @@ class AxolotlTrainer(
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
LayerOffloadingMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
|
||||
1
src/axolotl/core/trainers/constants.py
Normal file
1
src/axolotl/core/trainers/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
@@ -2,7 +2,8 @@
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
rpo_alpha: Optional[float] = field(default=None)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .layer_offloading import LayerOffloadingMixin
|
||||
from .distributed_parallel import DistributedParallelMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
|
||||
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
304
src/axolotl/core/trainers/mixins/layer_offloading.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Trainer mixin for layer-wise parameter offloading to CPU.
|
||||
|
||||
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
|
||||
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
|
||||
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
|
||||
|
||||
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
|
||||
transfer stream), post-hook offloads layer N-1's frozen params.
|
||||
Backward: same in reverse order.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
|
||||
"""Recursively search the model for the decoder layer ModuleList.
|
||||
|
||||
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
|
||||
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
|
||||
where layers are at model.language_model.layers).
|
||||
"""
|
||||
# BFS to find the first ModuleList containing decoder layers
|
||||
queue = [model]
|
||||
while queue:
|
||||
m = queue.pop(0)
|
||||
for _name, child in m.named_children():
|
||||
if isinstance(child, nn.ModuleList) and len(child) > 0:
|
||||
first_type = type(child[0]).__name__
|
||||
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
|
||||
layer_types = list({type(layer).__name__ for layer in child})
|
||||
return child, layer_types
|
||||
else:
|
||||
queue.append(child)
|
||||
|
||||
return None, []
|
||||
|
||||
|
||||
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
|
||||
"""Get all non-trainable parameters in a layer."""
|
||||
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
|
||||
|
||||
|
||||
class LayerOffloadManager:
|
||||
"""Manages offloading frozen decoder layer params to CPU and streaming
|
||||
them back during forward/backward with CUDA stream overlap.
|
||||
|
||||
Only frozen (requires_grad=False) parameters are offloaded.
|
||||
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
num_prefetch: int = 1,
|
||||
):
|
||||
self.model = model
|
||||
self.num_prefetch = num_prefetch
|
||||
self._hooks: list = []
|
||||
self._device = None
|
||||
|
||||
# Find decoder layers
|
||||
self.layers, layer_types = _find_decoder_layers(model)
|
||||
if self.layers is None:
|
||||
LOG.warning(
|
||||
"LayerOffloadManager: no decoder layers found, offloading disabled"
|
||||
)
|
||||
self.enabled = False
|
||||
return
|
||||
|
||||
self.enabled = True
|
||||
self.n_layers = len(self.layers)
|
||||
LOG.info(
|
||||
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
|
||||
)
|
||||
|
||||
# Determine GPU device
|
||||
for p in model.parameters():
|
||||
if p.device.type == "cuda":
|
||||
self._device = p.device
|
||||
break
|
||||
if self._device is None:
|
||||
LOG.warning("LayerOffloadManager: no CUDA parameters found")
|
||||
self.enabled = False
|
||||
return
|
||||
|
||||
# Transfer stream for async prefetch
|
||||
self._transfer_stream = torch.cuda.Stream(device=self._device)
|
||||
|
||||
# Track which layers have their frozen params on GPU
|
||||
self._on_gpu: set[int] = set(range(self.n_layers))
|
||||
|
||||
# Cache: frozen param references per layer (list of (name, param) tuples)
|
||||
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
|
||||
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
|
||||
]
|
||||
|
||||
# CPU storage: pinned tensors for each layer's frozen params
|
||||
# Populated on first offload
|
||||
self._cpu_data: list[dict[str, torch.Tensor]] = [
|
||||
{} for _ in range(self.n_layers)
|
||||
]
|
||||
|
||||
# Offload all layers upfront
|
||||
self._offload_all()
|
||||
|
||||
# Release cached memory blocks back to the driver
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _offload_all(self):
|
||||
"""Move all frozen params in all decoder layers to CPU."""
|
||||
mem_before = torch.cuda.memory_allocated(self._device)
|
||||
for i in range(self.n_layers):
|
||||
self._offload_layer(i)
|
||||
mem_after = torch.cuda.memory_allocated(self._device)
|
||||
freed = (mem_before - mem_after) / 1e6
|
||||
LOG.info(
|
||||
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
|
||||
f"freed {freed:.0f} MB GPU memory"
|
||||
)
|
||||
|
||||
def _offload_layer(self, idx: int):
|
||||
"""Move frozen params of layer idx to CPU pinned memory."""
|
||||
if idx not in self._on_gpu:
|
||||
return
|
||||
for name, param in self._frozen_params[idx]:
|
||||
if param.device.type != "cuda":
|
||||
continue
|
||||
# Allocate pinned CPU tensor on first offload
|
||||
if name not in self._cpu_data[idx]:
|
||||
self._cpu_data[idx][name] = torch.empty_like(
|
||||
param.data, device="cpu", pin_memory=True
|
||||
)
|
||||
cpu_buf = self._cpu_data[idx][name]
|
||||
# Async copy GPU -> CPU (on transfer stream for overlap)
|
||||
cpu_buf.copy_(param.data, non_blocking=True)
|
||||
# Point parameter at a dummy CPU tensor to free GPU memory
|
||||
param.data = cpu_buf
|
||||
self._on_gpu.discard(idx)
|
||||
|
||||
def _load_layer(self, idx: int, stream=None):
|
||||
"""Move frozen params of layer idx back to GPU."""
|
||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||
return
|
||||
ctx = (
|
||||
torch.cuda.stream(stream)
|
||||
if stream is not None
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with ctx:
|
||||
for _name, param in self._frozen_params[idx]:
|
||||
if param.device.type == "cuda":
|
||||
continue
|
||||
gpu_data = param.data.to(self._device, non_blocking=True)
|
||||
param.data = gpu_data
|
||||
self._on_gpu.add(idx)
|
||||
|
||||
def _prefetch_layer(self, idx: int):
|
||||
"""Async prefetch layer idx on the transfer stream."""
|
||||
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
|
||||
return
|
||||
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
|
||||
self._load_layer(idx, stream=self._transfer_stream)
|
||||
|
||||
def _wait_transfer(self):
|
||||
"""Make default stream wait for any in-flight transfers."""
|
||||
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
|
||||
|
||||
def setup_hooks(self):
|
||||
"""Register forward and backward hooks on each decoder layer."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
for idx in range(self.n_layers):
|
||||
layer = self.layers[idx]
|
||||
|
||||
def make_pre_fwd(i):
|
||||
def hook(module, args):
|
||||
# Ensure this layer is on GPU
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
self._wait_transfer()
|
||||
# Prefetch next layer(s)
|
||||
for offset in range(1, self.num_prefetch + 1):
|
||||
self._prefetch_layer(i + offset)
|
||||
|
||||
return hook
|
||||
|
||||
def make_post_fwd(i):
|
||||
def hook(module, args, output):
|
||||
# Offload previous layer (no longer needed in forward)
|
||||
if i > 0:
|
||||
self._offload_layer(i - 1)
|
||||
# Offload last layer after forward
|
||||
if i == self.n_layers - 1:
|
||||
self._offload_layer(i)
|
||||
|
||||
return hook
|
||||
|
||||
def make_pre_bwd(i):
|
||||
def hook(module, grad_output):
|
||||
# Load this layer for backward
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
self._wait_transfer()
|
||||
# Prefetch previous layer(s)
|
||||
for offset in range(1, self.num_prefetch + 1):
|
||||
self._prefetch_layer(i - offset)
|
||||
|
||||
return hook
|
||||
|
||||
def make_post_bwd(i):
|
||||
def hook(module, grad_input, grad_output):
|
||||
# Offload the layer above
|
||||
if i < self.n_layers - 1:
|
||||
self._offload_layer(i + 1)
|
||||
# Offload first layer after backward
|
||||
if i == 0:
|
||||
self._offload_layer(i)
|
||||
|
||||
return hook
|
||||
|
||||
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
|
||||
h2 = layer.register_forward_hook(make_post_fwd(idx))
|
||||
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
|
||||
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
|
||||
self._hooks.extend([h1, h2, h3, h4])
|
||||
|
||||
def remove_hooks(self):
|
||||
"""Remove all hooks and restore layers to GPU."""
|
||||
for h in self._hooks:
|
||||
h.remove()
|
||||
self._hooks.clear()
|
||||
if self.enabled:
|
||||
for i in range(self.n_layers):
|
||||
if i not in self._on_gpu:
|
||||
self._load_layer(i)
|
||||
|
||||
def pre_step(self):
|
||||
"""Called before each training step — ensure layers start offloaded."""
|
||||
if not self.enabled:
|
||||
return
|
||||
for i in list(self._on_gpu):
|
||||
self._offload_layer(i)
|
||||
# Prefetch layer 0 for forward
|
||||
self._prefetch_layer(0)
|
||||
|
||||
def post_step(self):
|
||||
"""Called after each training step — ensure layers are offloaded."""
|
||||
if not self.enabled:
|
||||
return
|
||||
for i in list(self._on_gpu):
|
||||
self._offload_layer(i)
|
||||
# Prefetch layer 0 for next step
|
||||
self._prefetch_layer(0)
|
||||
|
||||
|
||||
class _LayerOffloadContext:
|
||||
"""Context manager wrapping pre_step / post_step around a training step."""
|
||||
|
||||
def __init__(self, manager: LayerOffloadManager):
|
||||
self.manager = manager
|
||||
|
||||
def __enter__(self):
|
||||
self.manager.pre_step()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.manager.post_step()
|
||||
|
||||
|
||||
class LayerOffloadingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin class for layer-wise parameter offloading to CPU.
|
||||
|
||||
Offloads frozen decoder layer params to CPU at init, then streams them
|
||||
on/off GPU one layer at a time during each training step.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if getattr(self.args, "layer_offloading", False):
|
||||
LOG.info("Layer parameter offloading enabled")
|
||||
self._layer_offload_manager = LayerOffloadManager(
|
||||
model=self.model,
|
||||
num_prefetch=1,
|
||||
)
|
||||
self._layer_offload_manager.setup_hooks()
|
||||
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
|
||||
else:
|
||||
self._layer_offload_manager = None
|
||||
self._layer_offload_ctx = contextlib.nullcontext()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
with self._layer_offload_ctx:
|
||||
return super().training_step(*args, **kwargs)
|
||||
@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
|
||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||
)
|
||||
|
||||
layer_offloading: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
|
||||
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
||||
@@ -35,6 +36,8 @@ SPARSE_MOE_BLOCK = {
|
||||
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||
# sigmoid -> topk routing (no group selection)
|
||||
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
|
||||
"nemotron_h": "NemotronHMoE",
|
||||
# Models below need custom routing (not yet implemented):
|
||||
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||
@@ -58,7 +61,16 @@ def resolve_moe_block_classes(model_type: str):
|
||||
|
||||
cls_names = entry if isinstance(entry, list) else [entry]
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = importlib.import_module(module_path)
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ModuleNotFoundError:
|
||||
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
|
||||
if model_type.endswith("_text"):
|
||||
parent_type = model_type.removesuffix("_text")
|
||||
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
|
||||
module = importlib.import_module(module_path)
|
||||
else:
|
||||
raise
|
||||
|
||||
classes = []
|
||||
for cls_name in cls_names:
|
||||
|
||||
@@ -195,6 +195,36 @@ def _estimate_smem_usage(
|
||||
_SMEM_SLACK = 10_000
|
||||
|
||||
|
||||
def _estimate_register_pressure(
|
||||
num_warps: int,
|
||||
*tile_sizes: tuple[int, int],
|
||||
) -> float:
|
||||
"""Rough estimate of per-thread register footprint from live tile sizes.
|
||||
|
||||
This is a heuristic, NOT an accurate register count. Triton uses tensor
|
||||
core MMA fragments that pack multiple elements per register, and can spill
|
||||
to local memory when the hardware limit (255 regs/thread) is exceeded.
|
||||
|
||||
The estimate is used to prune only truly extreme configs that would cause
|
||||
excessive spilling or compilation failures. The threshold is set high
|
||||
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
|
||||
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
|
||||
(est ~520) work fine in practice via spilling.
|
||||
|
||||
Returns estimated registers per thread.
|
||||
"""
|
||||
# Each thread in a warp holds ~1/32 of the tile elements
|
||||
tile_regs = sum(r * c for r, c in tile_sizes) / 32
|
||||
scalar_overhead = 40
|
||||
return tile_regs + scalar_overhead
|
||||
|
||||
|
||||
# Soft limit for register pressure pruning. Only prune configs with extreme
|
||||
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
|
||||
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
|
||||
_MAX_REGS_SOFT_LIMIT = 1024
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Forward Kernel: scatter2scatter with fused LoRA
|
||||
# =============================================================================
|
||||
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
|
||||
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
|
||||
) # [BLOCK_N, BLOCK_R]
|
||||
|
||||
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
|
||||
# Both operands must match; cast to float32 (accumulator type) for precision.
|
||||
b_f32 = b.to(tl.float32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
|
||||
b_inp = b.to(INPUT_DTYPE)
|
||||
|
||||
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
|
||||
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
|
||||
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_out
|
||||
return acc
|
||||
@@ -327,28 +356,29 @@ def _compute_expert_block_lora(
|
||||
def _scatter2scatter_lora_configs():
|
||||
"""Generate forward kernel autotune configs.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
Search space includes BLOCK_M to allow trading token-tile size for
|
||||
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
|
||||
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
|
||||
(4× fewer inner-loop iterations).
|
||||
|
||||
Search space:
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_M: {32, 64, 128}
|
||||
BLOCK_N: {32, 64}
|
||||
BLOCK_K: {32, 64, 128}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
|
||||
scatter2scatter pattern).
|
||||
"""
|
||||
configs = []
|
||||
for block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
for block_m, block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
|
||||
|
||||
|
||||
def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
"""Prune forward configs based on SMEM capacity.
|
||||
"""Prune forward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The forward kernel inner loop loads three tiles per pipeline stage:
|
||||
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
|
||||
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_r * block_k * 2
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
|
||||
smem_lora_epilogue = block_n * block_r * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
|
||||
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_n), # acc
|
||||
(block_m, block_r), # xa_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_k, block_n), # w tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _scatter2scatter_lora_split(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
k: int,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
scaling: float,
|
||||
b: Optional[torch.Tensor] = None,
|
||||
x_grouped: bool = False,
|
||||
y_grouped: bool = False,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||
|
||||
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||
|
||||
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||
"""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||
scatter2scatter,
|
||||
)
|
||||
|
||||
E = W.size(0)
|
||||
R = lora_A.size(0) // E
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||
output = scatter2scatter(
|
||||
X=X,
|
||||
W=W,
|
||||
b=b,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
out=out,
|
||||
)
|
||||
|
||||
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||
XA = scatter2scatter(
|
||||
X=X,
|
||||
W=W_A,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=True,
|
||||
)
|
||||
|
||||
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||
# Reshape B: [N, R*E] → [E, R, N]
|
||||
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||
Y_lora = scatter2scatter(
|
||||
X=XA,
|
||||
W=W_B,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
x_grouped=True,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
# 4. Y = Y_base + scaling * Y_lora
|
||||
output.add_(Y_lora, alpha=scaling)
|
||||
return output
|
||||
|
||||
|
||||
# Threshold for switching from fused to split LoRA forward.
|
||||
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||
# loads dominate in the fused kernel's inner loop).
|
||||
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||
|
||||
|
||||
def scatter2scatter_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
|
||||
Automatically selects between:
|
||||
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||
Best for many small experts (E>=64, small K*N).
|
||||
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||
|
||||
Args:
|
||||
X: Input [M, K] or [M*k, K] if x_grouped
|
||||
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
|
||||
Returns:
|
||||
Y: Output [M*k, N]
|
||||
"""
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
E = W.size(0)
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
|
||||
return _scatter2scatter_lora_split(
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Pad R to power of 2 for Triton tile size
|
||||
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
|
||||
b_ptr,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
|
||||
K=K,
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
|
||||
+ (A_expert_offset + R_block)[:, None] * stride_ar
|
||||
+ K_block[None, :] * stride_ak
|
||||
)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
|
||||
|
||||
# Cast to float32 for precision
|
||||
a_f32 = a_e.to(tl.float32)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
|
||||
INPUT_DTYPE
|
||||
)
|
||||
|
||||
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
|
||||
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
|
||||
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_dx
|
||||
return acc
|
||||
@@ -779,25 +939,26 @@ def _scatter2scatter_lora_dX_configs():
|
||||
The inner loop is over N (not K as in forward). The output dimension is K.
|
||||
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
BLOCK_M is now autotunable (was fixed at 128).
|
||||
|
||||
Search space:
|
||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||
BLOCK_M: {32, 64, 128} (token tile)
|
||||
BLOCK_K: {32, 64, 128} (output tile)
|
||||
BLOCK_N: {32, 64} (reduction tile)
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
"""
|
||||
configs = []
|
||||
for block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128], # BLOCK_K (output dimension)
|
||||
[32, 64], # BLOCK_N (reduction dimension)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
|
||||
|
||||
|
||||
def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward dX configs based on SMEM capacity.
|
||||
"""Prune backward dX configs based on SMEM capacity and register pressure.
|
||||
|
||||
The dX kernel inner loop loads three tiles per pipeline stage:
|
||||
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
|
||||
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_n * block_r * 2
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
|
||||
smem_lora_epilogue = block_r * block_k * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
|
||||
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_k), # acc
|
||||
(block_m, block_r), # dy_b_acc
|
||||
(block_m, block_n), # dy tile
|
||||
(block_n, block_k), # wt tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_r, block_k), # a tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_M=BLOCK_M,
|
||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
@@ -1091,9 +1278,9 @@ def _group_bwd_lora_configs():
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128, 256}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_M: {32, 64, 128} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128}
|
||||
BLOCK_N: {32, 64}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
@@ -1102,9 +1289,9 @@ def _group_bwd_lora_configs():
|
||||
"""
|
||||
configs = []
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128, 256], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[32, 64], # BLOCK_N
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
|
||||
|
||||
|
||||
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward configs based on SMEM capacity.
|
||||
"""Prune backward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
|
||||
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
|
||||
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
|
||||
smem_lora = (block_r * block_k + block_n * block_r) * 2
|
||||
smem = smem_base + smem_lora
|
||||
|
||||
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
|
||||
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_k), # dA_acc
|
||||
(block_n, block_r), # dB_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_m, block_n), # dy tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_m, block_r), # xa intermediate
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
|
||||
)
|
||||
|
||||
|
||||
def _group_bwd_split_configs():
|
||||
"""Autotune configs for split dA/dB kernels."""
|
||||
configs = []
|
||||
for block_m, block_dim, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M (token tile)
|
||||
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def _prune_split_configs(configs, named_args, **kwargs):
|
||||
"""Prune split kernel configs based on SMEM capacity and register pressure."""
|
||||
smem_cap = _get_smem_capacity()
|
||||
block_r = named_args.get("BLOCK_R", 64)
|
||||
|
||||
# Fixed inner tile for reduction dimension
|
||||
BLOCK_INNER = 64
|
||||
|
||||
pruned = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_dim = config.kwargs["BLOCK_DIM"]
|
||||
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
|
||||
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
|
||||
# LoRA weights held in registers: [INNER, R] or [R, DIM]
|
||||
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
|
||||
|
||||
# Register pressure check
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_dim), # acc
|
||||
(block_m, BLOCK_INNER), # input tile
|
||||
(block_m, block_dim), # other tile
|
||||
(block_r, BLOCK_INNER), # lora weight
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
if smem <= smem_cap - _SMEM_SLACK:
|
||||
pruned.append(config)
|
||||
|
||||
if pruned:
|
||||
return pruned
|
||||
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
|
||||
return [configs[0]]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_group_bwd_split_configs(),
|
||||
key=["M", "K", "N"],
|
||||
prune_configs_by={"early_config_prune": _prune_split_configs},
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_DIM_MASK": lambda args: (
|
||||
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||
if args["COMPUTE_DA"]
|
||||
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||
),
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _group_bwd_lora_split(
|
||||
# Data tensors (DY and X are always present)
|
||||
DY_ptr,
|
||||
stride_dym,
|
||||
stride_dyn,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
# LoRA weight for the inner reduction (B for dA, A for dB)
|
||||
LW_ptr,
|
||||
stride_lw0,
|
||||
stride_lw1,
|
||||
# Output gradient tensor (dA or dB)
|
||||
OUT_ptr,
|
||||
stride_out0,
|
||||
stride_out1,
|
||||
# Expert offsets
|
||||
expert_offsets_ptr,
|
||||
# Dimensions
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
ACTUAL_R: tl.constexpr,
|
||||
BLOCK_R: tl.constexpr,
|
||||
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
|
||||
scaling,
|
||||
# Mode flag
|
||||
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
|
||||
# Tile sizes
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DIM: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_DIM_MASK: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Unified split kernel for LoRA gradient computation.
|
||||
|
||||
When COMPUTE_DA=True:
|
||||
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
|
||||
Grid: (E, cdiv(K, BLOCK_DIM))
|
||||
- outer_ptr/stride = X (read [M, K_block])
|
||||
- inner reduction over N using DY and B
|
||||
- output shape [BLOCK_R, BLOCK_DIM]
|
||||
|
||||
When COMPUTE_DA=False:
|
||||
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
|
||||
Grid: (E, cdiv(N, BLOCK_DIM))
|
||||
- outer_ptr/stride = DY (read [M, N_block])
|
||||
- inner reduction over K using X and A
|
||||
- output shape [BLOCK_DIM, BLOCK_R]
|
||||
|
||||
No atomic adds — each (E, dim_block) pair is written by exactly one block.
|
||||
"""
|
||||
E_idx = tl.program_id(0)
|
||||
dim_block_id = tl.program_id(1)
|
||||
|
||||
if E_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
# Output dimension tile (K for dA, N for dB)
|
||||
if COMPUTE_DA:
|
||||
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
|
||||
else:
|
||||
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
|
||||
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
dim_mask = dim_block < OUT_DIM
|
||||
R_block = tl.arange(0, BLOCK_R)
|
||||
R_mask = R_block < ACTUAL_R
|
||||
lora_offset = E_idx * ACTUAL_R
|
||||
|
||||
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
|
||||
if COMPUTE_DA:
|
||||
out_blk_ptrs = (
|
||||
OUT_ptr
|
||||
+ (lora_offset + R_block)[:, None] * stride_out0
|
||||
+ dim_block[None, :] * stride_out1
|
||||
)
|
||||
out_mask = R_mask[:, None] & dim_mask[None, :]
|
||||
else:
|
||||
out_blk_ptrs = (
|
||||
OUT_ptr
|
||||
+ dim_block[:, None] * stride_out0
|
||||
+ (lora_offset + R_block)[None, :] * stride_out1
|
||||
)
|
||||
out_mask = dim_mask[:, None] & R_mask[None, :]
|
||||
|
||||
if num_tokens > 0:
|
||||
M_block = tl.arange(0, BLOCK_M)
|
||||
INPUT_DTYPE = X_ptr.dtype.element_ty
|
||||
BLOCK_INNER: tl.constexpr = 64
|
||||
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
|
||||
|
||||
if COMPUTE_DA:
|
||||
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
|
||||
else:
|
||||
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
|
||||
|
||||
M_iters = tl.cdiv(num_tokens, BLOCK_M)
|
||||
for i in range(M_iters):
|
||||
M_idx = start_idx + i * BLOCK_M + M_block
|
||||
M_mask = M_idx < end_idx
|
||||
|
||||
if COMPUTE_DA:
|
||||
# Load X[M, K_block] (the "outer" tensor for dA)
|
||||
outer = tl.load(
|
||||
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
|
||||
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||
inner_range = tl.arange(0, BLOCK_INNER)
|
||||
for j in range(inner_iters):
|
||||
inn_off = j * BLOCK_INNER + inner_range
|
||||
inn_mask = inn_off < N
|
||||
|
||||
dy_tile = tl.load(
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ inn_off[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
|
||||
lw_tile = tl.load(
|
||||
LW_ptr
|
||||
+ inn_off[:, None] * stride_lw0
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
|
||||
acc += tl.dot(
|
||||
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
|
||||
)
|
||||
else:
|
||||
# Load DY[M, N_block] (the "outer" tensor for dB)
|
||||
outer = tl.load(
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ dim_block[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
|
||||
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||
inner_range = tl.arange(0, BLOCK_INNER)
|
||||
for j in range(inner_iters):
|
||||
inn_off = j * BLOCK_INNER + inner_range
|
||||
inn_mask = inn_off < K
|
||||
|
||||
x_tile = tl.load(
|
||||
X_ptr
|
||||
+ M_idx[:, None] * stride_xm
|
||||
+ inn_off[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
|
||||
# We want A[e]^T: [K, R], so load as [K_inner, R]
|
||||
lw_tile = tl.load(
|
||||
LW_ptr
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw0
|
||||
+ inn_off[:, None] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
|
||||
acc += tl.dot(
|
||||
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
|
||||
)
|
||||
|
||||
tl.store(
|
||||
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
|
||||
)
|
||||
else:
|
||||
# Zero out this expert's slice — needed because output uses empty_like
|
||||
if COMPUTE_DA:
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
|
||||
|
||||
def group_bwd_lora(
|
||||
DY: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
|
||||
"""
|
||||
Compute LoRA gradients for A and B on expert-grouped data.
|
||||
|
||||
Uses split dA/dB kernels that eliminate atomic adds by giving each
|
||||
(expert, output_block) pair its own thread block.
|
||||
|
||||
Args:
|
||||
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
||||
X: Input [M_total, K] (grouped by expert)
|
||||
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
|
||||
K = X.size(1)
|
||||
N = DY.size(1)
|
||||
|
||||
# Zero-init for atomic accumulation
|
||||
dA = torch.zeros_like(lora_A)
|
||||
dB = torch.zeros_like(lora_B)
|
||||
# No zero-init needed: the split kernels write zeros for experts with
|
||||
# zero routed tokens directly in the kernel (else branch).
|
||||
dA = torch.empty_like(lora_A)
|
||||
dB = torch.empty_like(lora_B)
|
||||
|
||||
BLOCK_R = _block_r_for_rank(R)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
E * triton.cdiv(K, META["BLOCK_K"]),
|
||||
triton.cdiv(N, META["BLOCK_N"]),
|
||||
)
|
||||
def grid_dA(META):
|
||||
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora[grid](
|
||||
_group_bwd_lora_split[grid_dA](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
expert_offsets,
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=N,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=True,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
def grid_dB(META):
|
||||
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora_split[grid_dB](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
dB,
|
||||
dB.stride(0),
|
||||
dB.stride(1),
|
||||
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R, # True LoRA rank
|
||||
BLOCK_R=BLOCK_R, # Padded tile size
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=K,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=False,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
|
||||
-> base_layer: ParamWrapper(gate_up_proj)
|
||||
-> base_layer: OlmoeExperts (the real module)
|
||||
|
||||
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
|
||||
instead of ``gate_up_proj``.
|
||||
|
||||
This function walks the chain, collects LoRA params keyed by
|
||||
``parameter_name``, and returns the base experts module.
|
||||
|
||||
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
|
||||
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
||||
A/B are already in scattermoe layout.
|
||||
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
|
||||
"""
|
||||
# Collect ParamWrapper layers by their parameter_name
|
||||
wrappers = {}
|
||||
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
|
||||
num_experts = getattr(base_experts, "num_experts", None)
|
||||
if num_experts is None:
|
||||
# Fallback: infer from parameter shape
|
||||
gup = getattr(base_experts, "gate_up_proj", None)
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
for attr in ("gate_up_proj", "up_proj"):
|
||||
param = getattr(base_experts, attr, None)
|
||||
if param is not None:
|
||||
num_experts = param.shape[0]
|
||||
break
|
||||
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
|
||||
if gup_wrapper is not None:
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
||||
if lora_A is not None:
|
||||
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
Supports:
|
||||
|
||||
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
|
||||
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
|
||||
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
|
||||
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
||||
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
|
||||
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
||||
|
||||
# ====================================================================
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
|
||||
# ====================================================================
|
||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||
|
||||
@@ -490,19 +498,86 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||
|
||||
# ====================================================================
|
||||
# Gate + Up projection
|
||||
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
|
||||
# ====================================================================
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
is_gated = hasattr(experts, "gate_up_proj")
|
||||
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
|
||||
# ====================================================================
|
||||
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
|
||||
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
|
||||
|
||||
expert_input = hidden_states_flat
|
||||
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
|
||||
expert_input = fc1_latent_proj(hidden_states_flat)
|
||||
|
||||
# ====================================================================
|
||||
# Selective expert weight dequantization
|
||||
# ====================================================================
|
||||
# When experts are BnB-quantized (quantize_moe_experts), dequantize
|
||||
# only the active experts instead of all E. This saves ~97% memory
|
||||
# for the transient dequant buffer when few experts are active.
|
||||
use_selective = (
|
||||
getattr(self, "_use_selective_dequant", False)
|
||||
and hasattr(experts, "parametrizations")
|
||||
and up_proj_attr in experts.parametrizations
|
||||
)
|
||||
|
||||
if use_selective:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
|
||||
get_active_experts,
|
||||
remap_expert_indices,
|
||||
selective_expert_weights,
|
||||
selective_lora_weights,
|
||||
)
|
||||
|
||||
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
|
||||
remapped_expert_idxs, compact_offsets = remap_expert_indices(
|
||||
sorted_expert_idxs,
|
||||
expert_offsets,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
# Dequantize only active experts' weights
|
||||
up_W = selective_expert_weights(
|
||||
experts,
|
||||
up_proj_attr,
|
||||
active_experts,
|
||||
).transpose(2, 1)
|
||||
|
||||
# Remap LoRA weights to match compact expert indices
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup_A, gup_B = selective_lora_weights(
|
||||
gup_A,
|
||||
gup_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
gup_lora = (gup_A, gup_B, gup_scaling)
|
||||
|
||||
# Use remapped indices for ScatterMoE kernels
|
||||
sei_gup = remapped_expert_idxs
|
||||
eo_gup = compact_offsets
|
||||
else:
|
||||
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
|
||||
sei_gup = sorted_expert_idxs
|
||||
eo_gup = expert_offsets
|
||||
|
||||
# ====================================================================
|
||||
# Up projection (gated: gate_up_proj; non-gated: up_proj)
|
||||
# ====================================================================
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup = parallel_linear_lora(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
up_out = parallel_linear_lora(
|
||||
expert_input,
|
||||
up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_gup,
|
||||
lora_A=gup_A,
|
||||
lora_B=gup_B,
|
||||
scaling=gup_scaling,
|
||||
@@ -512,24 +587,52 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
use_fused_gather=True,
|
||||
)
|
||||
else:
|
||||
gup = parallel_linear(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
up_out = parallel_linear(
|
||||
expert_input,
|
||||
up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_gup,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
)
|
||||
|
||||
gates, h = gup.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
# ====================================================================
|
||||
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
|
||||
# ====================================================================
|
||||
if is_gated:
|
||||
gates, h = up_out.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
else:
|
||||
h = experts.act_fn(up_out)
|
||||
|
||||
# ====================================================================
|
||||
# Down projection
|
||||
# ====================================================================
|
||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||
if use_selective:
|
||||
down_W = selective_expert_weights(
|
||||
experts,
|
||||
"down_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, inter, hidden]
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
down_A, down_B = selective_lora_weights(
|
||||
down_A,
|
||||
down_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
down_lora = (down_A, down_B, down_scaling)
|
||||
|
||||
sei_down = remapped_expert_idxs
|
||||
eo_down = compact_offsets
|
||||
else:
|
||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||
sei_down = sorted_expert_idxs
|
||||
eo_down = expert_offsets
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
@@ -537,9 +640,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sei_down,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_down,
|
||||
lora_A=down_A,
|
||||
lora_B=down_B,
|
||||
scaling=down_scaling,
|
||||
@@ -554,14 +657,20 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sei_down,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_down,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
gates=routing_weights,
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection back to hidden_size (NemotronH)
|
||||
# ====================================================================
|
||||
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
|
||||
expert_output = fc2_latent_proj(expert_output)
|
||||
|
||||
# ====================================================================
|
||||
# Combine with shared expert and reshape
|
||||
# ====================================================================
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Selective Expert Dequantization
|
||||
===============================
|
||||
|
||||
Instead of dequantizing all E expert weight matrices at once (which creates
|
||||
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
|
||||
are actually routed to by the current batch's top-k selection.
|
||||
|
||||
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
|
||||
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
|
||||
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
|
||||
- Savings: ~97% memory reduction per layer
|
||||
|
||||
This module provides format-agnostic selective weight extraction:
|
||||
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
|
||||
- bf16/fp32: direct indexing (no dequant needed)
|
||||
- FP8: slice + cast
|
||||
|
||||
The ScatterMoE kernel itself doesn't change — we remap expert indices
|
||||
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
|
||||
weight tensor.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
|
||||
"""Get sorted unique expert indices from the routing output.
|
||||
|
||||
Args:
|
||||
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
|
||||
E: Total number of experts
|
||||
|
||||
Returns:
|
||||
active: Sorted unique expert indices [num_active]
|
||||
"""
|
||||
return torch.unique(sorted_expert_idxs)
|
||||
|
||||
|
||||
def remap_expert_indices(
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
active_experts: torch.Tensor,
|
||||
E: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Remap global expert indices to compact indices.
|
||||
|
||||
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
|
||||
sort order. Also compacts expert_offsets to only active experts.
|
||||
|
||||
Args:
|
||||
sorted_expert_idxs: [T*k] expert ids in sorted order
|
||||
expert_offsets: [E] cumulative token counts (original)
|
||||
active_experts: [num_active] sorted unique expert ids
|
||||
E: Total number of experts
|
||||
|
||||
Returns:
|
||||
remapped_idxs: [T*k] expert ids in [0..num_active-1]
|
||||
compact_offsets: [num_active] cumulative token counts
|
||||
"""
|
||||
# Build remap table: global_id -> compact_id
|
||||
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
|
||||
remap[active_experts] = torch.arange(
|
||||
len(active_experts), device=sorted_expert_idxs.device
|
||||
)
|
||||
|
||||
remapped_idxs = remap[sorted_expert_idxs]
|
||||
|
||||
# Compact the expert_offsets: only keep active experts' cumulative counts
|
||||
compact_offsets = expert_offsets[active_experts]
|
||||
|
||||
return remapped_idxs, compact_offsets
|
||||
|
||||
|
||||
def _selective_dequant_bnb4(
|
||||
raw_param: torch.Tensor,
|
||||
quant_state,
|
||||
active_experts: torch.Tensor,
|
||||
expert_shape: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""Dequantize only selected experts from BnB 4-bit packed data.
|
||||
|
||||
The raw parameter is a flattened 4-bit packed tensor. Each expert's
|
||||
data is contiguous (stored in expert-major order), so we can gather
|
||||
the packed data and absmax blocks for active experts, then dequantize
|
||||
as one contiguous block.
|
||||
|
||||
Args:
|
||||
raw_param: Flattened uint8 tensor of packed 4-bit weights
|
||||
quant_state: BnB QuantState with absmax, blocksize, code, etc.
|
||||
active_experts: [num_active] expert indices to dequantize
|
||||
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
|
||||
|
||||
Returns:
|
||||
Dequantized weights [num_active, dim1, dim2] in original dtype
|
||||
"""
|
||||
import bitsandbytes.functional as F # noqa: N812
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
expert_numel = expert_shape[0] * expert_shape[1]
|
||||
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
|
||||
blocks_per_expert = expert_numel // quant_state.blocksize
|
||||
num_active = len(active_experts)
|
||||
|
||||
if blocks_per_expert == 0:
|
||||
# Expert is smaller than one quantization block — blocks span across
|
||||
# expert boundaries, so per-expert slicing isn't possible.
|
||||
# Fallback: full dequantize + index.
|
||||
full = F.dequantize_4bit(raw_param, quant_state)
|
||||
E_total = full.numel() // expert_numel
|
||||
return full.reshape(E_total, *expert_shape)[active_experts]
|
||||
|
||||
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
|
||||
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
|
||||
selective_dequant_nf4_triton,
|
||||
)
|
||||
|
||||
# Handle nested (double) quantization: dequantize absmax first
|
||||
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
|
||||
if quant_state.nested:
|
||||
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||
absmax += quant_state.offset
|
||||
if absmax.dtype != torch.float32:
|
||||
absmax = absmax.float()
|
||||
else:
|
||||
absmax = quant_state.absmax
|
||||
|
||||
return selective_dequant_nf4_triton(
|
||||
packed_data=raw_param,
|
||||
absmax=absmax,
|
||||
active_experts=active_experts,
|
||||
expert_shape=expert_shape,
|
||||
blocksize=quant_state.blocksize,
|
||||
dtype=quant_state.dtype,
|
||||
codebook=quant_state.code,
|
||||
)
|
||||
|
||||
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
|
||||
raw_flat = raw_param.reshape(-1)
|
||||
|
||||
offsets_qt = (
|
||||
active_experts.long()[:, None] * packed_per_expert
|
||||
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
|
||||
).reshape(-1)
|
||||
qt_gathered = raw_flat[offsets_qt]
|
||||
|
||||
offsets_abs = (
|
||||
active_experts.long()[:, None] * blocks_per_expert
|
||||
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
|
||||
).reshape(-1)
|
||||
|
||||
if quant_state.nested:
|
||||
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||
full_absmax += quant_state.offset
|
||||
if full_absmax.dtype != torch.float32:
|
||||
full_absmax = full_absmax.float()
|
||||
absmax_gathered = full_absmax[offsets_abs]
|
||||
else:
|
||||
absmax_gathered = quant_state.absmax[offsets_abs]
|
||||
|
||||
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
|
||||
|
||||
gathered_qs = QuantState(
|
||||
absmax=absmax_gathered,
|
||||
shape=torch.Size([num_active * expert_numel]),
|
||||
blocksize=quant_state.blocksize,
|
||||
quant_type=quant_state.quant_type,
|
||||
code=quant_state.code,
|
||||
dtype=quant_state.dtype,
|
||||
)
|
||||
|
||||
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
|
||||
return deq.reshape(num_active, *expert_shape)
|
||||
|
||||
|
||||
def _selective_index_dense(
|
||||
param: torch.Tensor,
|
||||
active_experts: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Select experts from a dense (bf16/fp32) weight tensor.
|
||||
|
||||
Simple indexing — no dequantization needed.
|
||||
"""
|
||||
return param[active_experts]
|
||||
|
||||
|
||||
def selective_expert_weights(
|
||||
experts_module: nn.Module,
|
||||
param_name: str,
|
||||
active_experts: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract and dequantize only the active experts' weights.
|
||||
|
||||
Format-agnostic: dispatches based on whether the parameter is
|
||||
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
|
||||
|
||||
Args:
|
||||
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
|
||||
param_name: "gate_up_proj" or "down_proj"
|
||||
active_experts: [num_active] sorted unique expert indices
|
||||
|
||||
Returns:
|
||||
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
|
||||
"""
|
||||
# Check if the parameter is BnB-quantized via parametrize
|
||||
if (
|
||||
hasattr(experts_module, "parametrizations")
|
||||
and param_name in experts_module.parametrizations
|
||||
):
|
||||
param_list = experts_module.parametrizations[param_name]
|
||||
parametrization = param_list[0]
|
||||
|
||||
# BnB 4-bit parametrization
|
||||
if hasattr(parametrization, "quant_state"):
|
||||
# The raw quantized data is on the ParametrizationList, not the
|
||||
# individual Bnb4bitParametrization module
|
||||
raw_param = param_list.original
|
||||
qs = parametrization.quant_state
|
||||
# qs.shape is the original tensor shape before flattening.
|
||||
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
|
||||
orig_shape = qs.shape
|
||||
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
|
||||
expert_shape = (orig_shape[1], orig_shape[2])
|
||||
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
|
||||
# Flattened — need to infer from module attributes
|
||||
E_total = getattr(experts_module, "num_experts", None)
|
||||
if E_total is None:
|
||||
E_total = int(active_experts.max().item()) + 1
|
||||
expert_numel = orig_shape[0] // E_total
|
||||
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
|
||||
experts_module, "intermediate_dim", None
|
||||
)
|
||||
if d2 and expert_numel % d2 == 0:
|
||||
expert_shape = (expert_numel // d2, d2)
|
||||
else:
|
||||
full = getattr(experts_module, param_name)
|
||||
return full[active_experts]
|
||||
else:
|
||||
full = getattr(experts_module, param_name)
|
||||
return full[active_experts]
|
||||
|
||||
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
|
||||
|
||||
# Dense parameter (bf16/fp32) — direct indexing
|
||||
param = getattr(experts_module, param_name)
|
||||
if param.dim() == 3:
|
||||
return param[active_experts]
|
||||
|
||||
# Fallback: full access
|
||||
return param
|
||||
|
||||
|
||||
def selective_lora_weights(
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
active_experts: torch.Tensor,
|
||||
E: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Select LoRA A and B weights for only the active experts.
|
||||
|
||||
LoRA layout (scattermoe format):
|
||||
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
|
||||
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
|
||||
|
||||
Returns compact:
|
||||
A: [r*num_active, K]
|
||||
B: [N, r*num_active]
|
||||
"""
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
|
||||
row_idx = (
|
||||
active_experts.long()[:, None] * R
|
||||
+ torch.arange(R, device=lora_A.device)[None, :]
|
||||
).reshape(-1)
|
||||
|
||||
compact_A = lora_A[row_idx] # [r*num_active, K]
|
||||
compact_B = lora_B[:, row_idx] # [N, r*num_active]
|
||||
|
||||
return compact_A, compact_B
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Triton kernel for fused selective expert gather + NF4 dequantization.
|
||||
|
||||
Instead of:
|
||||
1. Gather packed uint8 data for active experts (memory copy)
|
||||
2. Gather absmax for active experts (memory copy)
|
||||
3. Call BnB dequantize_4bit CUDA kernel
|
||||
|
||||
This kernel does all three in one pass:
|
||||
- Reads packed NF4 bytes from expert-strided positions
|
||||
- Looks up the NF4 codebook
|
||||
- Multiplies by the per-block absmax
|
||||
- Writes bf16 output directly
|
||||
|
||||
This eliminates the intermediate gather buffer entirely.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# NF4 codebook (16 values, precomputed by BnB)
|
||||
# These are the normalized float4 reconstruction values
|
||||
NF4_CODEBOOK = [
|
||||
-1.0,
|
||||
-0.6961928009986877,
|
||||
-0.5250730514526367,
|
||||
-0.39491748809814453,
|
||||
-0.28444138169288635,
|
||||
-0.18477343022823334,
|
||||
-0.09105003625154495,
|
||||
0.0,
|
||||
0.07958029955625534,
|
||||
0.16093020141124725,
|
||||
0.24611230194568634,
|
||||
0.33791524171829224,
|
||||
0.44070982933044434,
|
||||
0.5626170039176941,
|
||||
0.7229568362236023,
|
||||
1.0,
|
||||
]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _selective_dequant_nf4_kernel(
|
||||
# Input: packed NF4 data (flattened, expert-major order)
|
||||
packed_ptr,
|
||||
# Input: absmax values (flattened, expert-major order)
|
||||
absmax_ptr,
|
||||
# Input: active expert indices
|
||||
active_experts_ptr,
|
||||
# Input: NF4 codebook (16 float values)
|
||||
codebook_ptr,
|
||||
# Output: dequantized bf16 weights [num_active, expert_numel]
|
||||
out_ptr,
|
||||
stride_out_e, # stride for expert dim in output
|
||||
# Dimensions
|
||||
num_active,
|
||||
packed_per_expert, # expert_numel // 2
|
||||
blocks_per_expert, # expert_numel // blocksize
|
||||
blocksize: tl.constexpr,
|
||||
# Tile size
|
||||
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
|
||||
):
|
||||
"""
|
||||
Each program processes BLOCK_SIZE elements from one expert.
|
||||
|
||||
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
|
||||
|
||||
For each output element:
|
||||
1. Compute which byte in packed data contains this element
|
||||
2. Extract the 4-bit nibble (high or low)
|
||||
3. Look up in NF4 codebook
|
||||
4. Scale by absmax for this block
|
||||
"""
|
||||
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
|
||||
block_id = tl.program_id(1) # which element block
|
||||
|
||||
# Load the global expert index
|
||||
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
|
||||
|
||||
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
|
||||
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = elem_offset < expert_numel
|
||||
|
||||
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
|
||||
byte_idx = elem_offset // 2
|
||||
is_high = (elem_offset % 2) == 1
|
||||
|
||||
# Read packed bytes from the global expert's region
|
||||
packed_global_offset = expert_global * packed_per_expert + byte_idx
|
||||
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
|
||||
tl.int32
|
||||
)
|
||||
|
||||
# Extract 4-bit nibble
|
||||
# BnB packing: high nibble = even element, low nibble = odd element
|
||||
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
|
||||
|
||||
# NF4 codebook lookup
|
||||
# Load all 16 codebook values (small, fits in registers)
|
||||
# Use gather from codebook pointer
|
||||
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
|
||||
|
||||
# Load absmax for this element's quantization block
|
||||
block_idx = elem_offset // blocksize
|
||||
absmax_global_offset = expert_global * blocks_per_expert + block_idx
|
||||
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
|
||||
|
||||
# Dequantize: value = codebook[nibble] * absmax
|
||||
result = code_val * absmax_val
|
||||
|
||||
# Store to output
|
||||
out_offset = expert_local_idx * stride_out_e + elem_offset
|
||||
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def selective_dequant_nf4_triton(
|
||||
packed_data: torch.Tensor,
|
||||
absmax: torch.Tensor,
|
||||
active_experts: torch.Tensor,
|
||||
expert_shape: tuple[int, int],
|
||||
blocksize: int,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
codebook: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Fused selective gather + NF4 dequantization via Triton kernel.
|
||||
|
||||
Args:
|
||||
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
|
||||
absmax: Per-block scaling factors [total_blocks]
|
||||
active_experts: Sorted indices of experts to dequantize [num_active]
|
||||
expert_shape: (dim1, dim2) per expert
|
||||
blocksize: Quantization block size
|
||||
dtype: Output dtype (default bf16)
|
||||
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
|
||||
|
||||
Returns:
|
||||
Dequantized weights [num_active, dim1, dim2]
|
||||
"""
|
||||
num_active = active_experts.shape[0]
|
||||
expert_numel = expert_shape[0] * expert_shape[1]
|
||||
packed_per_expert = expert_numel // 2
|
||||
blocks_per_expert = expert_numel // blocksize
|
||||
|
||||
# Prepare codebook on device
|
||||
if codebook is None:
|
||||
codebook = torch.tensor(
|
||||
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
|
||||
)
|
||||
else:
|
||||
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
|
||||
|
||||
# Flatten inputs
|
||||
packed_flat = packed_data.reshape(-1)
|
||||
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
|
||||
|
||||
# Output buffer
|
||||
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
|
||||
|
||||
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
|
||||
|
||||
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
|
||||
|
||||
_selective_dequant_nf4_kernel[grid](
|
||||
packed_flat,
|
||||
absmax_flat,
|
||||
active_experts,
|
||||
codebook,
|
||||
out,
|
||||
out.stride(0),
|
||||
num_active=num_active,
|
||||
packed_per_expert=packed_per_expert,
|
||||
blocks_per_expert=blocks_per_expert,
|
||||
blocksize=blocksize,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return out.reshape(num_active, *expert_shape)
|
||||
@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
|
||||
return "axolotl.integrations.kernels.KernelsArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||
|
||||
# Prefer text backbone type for VLMs, but fall back to base type
|
||||
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
if (
|
||||
moe_model_type not in SPARSE_MOE_BLOCK
|
||||
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||
):
|
||||
moe_model_type = cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
self._register_kernels()
|
||||
|
||||
@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_rms_norm_gated: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
||||
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
||||
)
|
||||
},
|
||||
)
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
|
||||
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
||||
|
||||
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
||||
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits,
|
||||
labels,
|
||||
self.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
||||
|
||||
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
||||
@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5":
|
||||
from axolotl.integrations.liger.models.qwen3_5 import (
|
||||
apply_liger_kernel_to_qwen3_5,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
||||
apply_liger_kernel_to_qwen3_5_moe,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "granitemoe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||
|
||||
|
||||
147
src/axolotl/kernels/dora.py
Normal file
147
src/axolotl/kernels/dora.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||
|
||||
Fuses the weight norm computation and magnitude scaling to avoid
|
||||
materializing the full [out_features, in_features] combined weight matrix.
|
||||
The B@A product is computed row-by-row inside the kernel.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .quantize import dequantize
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dora_fused_norm_kernel(
|
||||
# Pointers
|
||||
W_ptr, # base weight [out, in] (dequantized, row-major)
|
||||
B_ptr, # LoRA B [out, rank] (row-major)
|
||||
A_ptr, # LoRA A [rank, in] (row-major)
|
||||
mag_ptr, # magnitude vector [out]
|
||||
out_ptr, # output mag_norm_scale [out]
|
||||
# Shapes
|
||||
out_features,
|
||||
in_features,
|
||||
rank,
|
||||
# Scaling
|
||||
lora_scale, # float scaling factor
|
||||
# Block sizes
|
||||
BLOCK_IN: tl.constexpr,
|
||||
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
||||
):
|
||||
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
||||
|
||||
Each program handles one output row. B[row,:] is loaded once (small),
|
||||
then we tile over in_features computing the dot product with A[:,tile]
|
||||
and accumulating the squared norm.
|
||||
|
||||
This avoids materializing the full [out, in] B@A matrix.
|
||||
"""
|
||||
row = tl.program_id(0)
|
||||
if row >= out_features:
|
||||
return
|
||||
|
||||
# Accumulate squared norm across tiles of in_features
|
||||
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||
|
||||
for start in range(0, in_features, BLOCK_IN):
|
||||
cols = start + tl.arange(0, BLOCK_IN)
|
||||
col_mask = cols < in_features
|
||||
|
||||
# Load W[row, cols]
|
||||
w_vals = tl.load(
|
||||
W_ptr + row * in_features + cols,
|
||||
mask=col_mask,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# Compute (B[row,:] @ A[:, cols]) for this tile
|
||||
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
||||
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||
for r in tl.static_range(BLOCK_R):
|
||||
# Load scalar B[row, r]
|
||||
b_val = tl.load(
|
||||
B_ptr + row * rank + r,
|
||||
mask=(r < rank),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
# Load vector A[r, cols]
|
||||
a_vals = tl.load(
|
||||
A_ptr + r * in_features + cols,
|
||||
mask=(col_mask & (r < rank)),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
ba_vals += b_val * a_vals
|
||||
|
||||
# Combined: W + s * (B @ A)
|
||||
combined = w_vals + lora_scale * ba_vals
|
||||
|
||||
# Accumulate squared values
|
||||
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
||||
|
||||
# Reduce to scalar norm
|
||||
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
||||
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
||||
|
||||
# Load magnitude and compute scale
|
||||
mag = tl.load(mag_ptr + row).to(tl.float32)
|
||||
scale = mag / norm
|
||||
|
||||
tl.store(out_ptr + row, scale)
|
||||
|
||||
|
||||
def triton_dora_scale(
|
||||
W: torch.Tensor,
|
||||
W_quant,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
magnitude: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
||||
|
||||
Computes B@A row-by-row inside the kernel, avoiding the full
|
||||
[out_features, in_features] materialization.
|
||||
|
||||
Args:
|
||||
W: base weight [out, in] (possibly quantized)
|
||||
W_quant: quantization state
|
||||
A: LoRA A [rank, in]
|
||||
B: LoRA B [out, rank]
|
||||
s: LoRA scaling factor
|
||||
magnitude: learned magnitude [out]
|
||||
dtype: compute dtype
|
||||
|
||||
Returns:
|
||||
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
||||
"""
|
||||
# Dequantize W to [out, in]
|
||||
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
||||
|
||||
out_features, in_features = W_full.shape
|
||||
rank = A.shape[0]
|
||||
|
||||
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
||||
|
||||
# Block sizes
|
||||
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
||||
BLOCK_R = triton.next_power_of_2(rank)
|
||||
|
||||
_dora_fused_norm_kernel[(out_features,)](
|
||||
W_full,
|
||||
B.contiguous().to(dtype),
|
||||
A.contiguous().to(dtype),
|
||||
magnitude.contiguous(),
|
||||
out,
|
||||
out_features=out_features,
|
||||
in_features=in_features,
|
||||
rank=rank,
|
||||
lora_scale=s,
|
||||
BLOCK_IN=BLOCK_IN,
|
||||
BLOCK_R=BLOCK_R,
|
||||
)
|
||||
|
||||
return out.detach()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,10 @@ def dequantize(
|
||||
# Extract quantization state
|
||||
if not isinstance(quant_state, list):
|
||||
# New style quant_state class
|
||||
# Non-double-quantized models have offset=None and state2=None
|
||||
if quant_state.offset is None or quant_state.state2 is None:
|
||||
# Fall back to bitsandbytes standard dequantize
|
||||
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
||||
absmax = quant_state.absmax.to(target_device)
|
||||
shape = quant_state.shape
|
||||
dtype = quant_state.dtype
|
||||
|
||||
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate Triton kernel.
|
||||
|
||||
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
||||
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
||||
and silu(G) = G * sigmoid(G)
|
||||
|
||||
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from liger_kernel.ops.utils import (
|
||||
calculate_settings,
|
||||
compare_version,
|
||||
ensure_contiguous,
|
||||
torch_to_triton_dtype,
|
||||
)
|
||||
from liger_kernel.utils import is_npu_available
|
||||
|
||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||
try:
|
||||
from triton.language.extra.libdevice import rsqrt
|
||||
except ModuleNotFoundError:
|
||||
from triton.language.extra.cuda.libdevice import rsqrt
|
||||
else:
|
||||
from triton.language.math import rsqrt
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_forward_kernel(
|
||||
Y_ptr,
|
||||
Y_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
||||
|
||||
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
||||
|
||||
X_row_dtype = X_row.dtype
|
||||
|
||||
# Cast everything to fp32
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
W_fp32 = W_row.to(tl.float32)
|
||||
|
||||
# RMS norm
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
|
||||
X_norm = X_fp32 * rstd
|
||||
|
||||
# SiLU gate: silu(G) = G * sigmoid(G)
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# Fused output
|
||||
Y_row = (offset + W_fp32) * X_norm * silu_G
|
||||
|
||||
tl.store(
|
||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||
Y_row.to(X_row_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_backward_kernel(
|
||||
dY_ptr,
|
||||
dY_row_stride,
|
||||
dX_ptr,
|
||||
dX_row_stride,
|
||||
dG_ptr,
|
||||
dG_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
X_dtype: tl.constexpr,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
dW_ptr,
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
||||
|
||||
dW = sum_batch(dY * X_norm * silu(G))
|
||||
dG = dY * (W + offset) * X_norm * silu'(G)
|
||||
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
||||
where m = dY * (W + offset) * silu(G)
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
row_start = row_block_id * rows_per_program
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
||||
W_row = W_row.to(tl.float32) + offset
|
||||
|
||||
for row_idx in range(row_start, row_end):
|
||||
dY_row = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
X_row = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
G_row = tl.load(
|
||||
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
# Cast to fp32
|
||||
dY_fp32 = dY_row.to(tl.float32)
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
|
||||
# Recompute intermediates
|
||||
X_norm = X_fp32 * rstd_row
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# dW: accumulate dY * X_norm * silu(G)
|
||||
dW_acc += dY_fp32 * X_norm * silu_G
|
||||
|
||||
# dG: dY * (W + offset) * X_norm * silu'(G)
|
||||
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
||||
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
||||
tl.store(
|
||||
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
||||
dG_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
||||
m = dY_fp32 * W_row * silu_G
|
||||
dX_row = rstd_row * m
|
||||
dX_row += rstd_row * (
|
||||
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
||||
)
|
||||
tl.store(
|
||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||
dX_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
tl.store(
|
||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||
dW_acc,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_gated_forward(X, G, W, eps, offset):
|
||||
shape = X.shape
|
||||
dim = shape[-1]
|
||||
X = X.view(-1, dim)
|
||||
G = G.view(-1, dim)
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||
|
||||
assert X.shape[1] == W.shape[0], (
|
||||
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
||||
)
|
||||
assert X.shape == G.shape, (
|
||||
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
||||
)
|
||||
|
||||
_rms_norm_gated_forward_kernel[(n_rows,)](
|
||||
Y,
|
||||
Y.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
||||
shape = dY.shape
|
||||
dim = shape[-1]
|
||||
dY = dY.view(-1, dim)
|
||||
n_rows, n_cols = dY.shape
|
||||
|
||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||
|
||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
||||
dX = torch.empty_like(dY)
|
||||
dG = torch.empty_like(dY)
|
||||
|
||||
rows_per_program = math.ceil(n_rows / sm_count)
|
||||
grid = (sm_count,)
|
||||
|
||||
_rms_norm_gated_backward_kernel[grid](
|
||||
dY,
|
||||
dY.stride(0),
|
||||
dX,
|
||||
dX.stride(0),
|
||||
dG,
|
||||
dG.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
torch_to_triton_dtype[X.dtype],
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
_dW,
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
dX = dX.view(*shape)
|
||||
dG = dG.view(*shape)
|
||||
dW = _dW.sum(dim=0).to(W.dtype)
|
||||
return dX, dG, dW
|
||||
|
||||
|
||||
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, G, W, eps, offset=0.0):
|
||||
"""
|
||||
X: (B, T, H) or (BxT, H) — input hidden states
|
||||
G: (B, T, H) or (BxT, H) — gate tensor
|
||||
W: (H,) — weight parameter
|
||||
"""
|
||||
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
||||
X, G, W, eps, offset
|
||||
)
|
||||
ctx.offset = offset
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(X, G, W, RSTD)
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def backward(ctx, dY):
|
||||
X, G, W, RSTD = ctx.saved_tensors
|
||||
dX, dG, dW = rms_norm_gated_backward(
|
||||
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
||||
)
|
||||
return dX, dG, dW, None, None
|
||||
|
||||
|
||||
class FusedRMSNormGated(torch.nn.Module):
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate.
|
||||
|
||||
Computes: Y = W * RMSNorm(X) * silu(G)
|
||||
|
||||
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
||||
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
||||
and forward signature: forward(hidden_states, gate=None)
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
if gate is None:
|
||||
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
||||
if hidden_states.device.type != "cuda":
|
||||
raise ValueError(
|
||||
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
||||
)
|
||||
return FusedRMSNormGatedFunction.apply(
|
||||
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
@@ -505,6 +505,20 @@ class ModelLoader:
|
||||
elif not is_ds_zero3:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
|
||||
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
|
||||
# so the actual VRAM usage is much less than bf16 estimates.
|
||||
# When device_map is "auto", accelerate's infer_auto_device_map computes
|
||||
# the device map at bf16 size (before quantization), causing it to offload
|
||||
# layers to CPU, which BnB then rejects. Force single-GPU placement to
|
||||
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
|
||||
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
|
||||
"auto",
|
||||
None,
|
||||
):
|
||||
self.model_kwargs["device_map"] = {
|
||||
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||
}
|
||||
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "mps:0"
|
||||
|
||||
@@ -571,15 +571,6 @@ class PatchManager:
|
||||
LOG.info("Patching with xformers attention...")
|
||||
hijack_llama_attention()
|
||||
|
||||
def _patch_llama_sample_packing(self):
|
||||
"""Apply sample packing patches for LLaMA models."""
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
)
|
||||
|
||||
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
|
||||
def _patch_llama_derived_model(self):
|
||||
"""Modify all llama derived models in one block."""
|
||||
if self.cfg.is_llama_derived_model and not (
|
||||
@@ -591,8 +582,6 @@ class PatchManager:
|
||||
self._patch_llama_flash_attention()
|
||||
elif self.cfg.xformers_attention:
|
||||
self._patch_llama_xformers_attention()
|
||||
elif self.cfg.sample_packing:
|
||||
self._patch_llama_sample_packing()
|
||||
elif self.cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
|
||||
@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
LOG.warning(
|
||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||
tokenizer.eos_token,
|
||||
)
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
|
||||
@@ -78,30 +78,21 @@ def patch_parallelism_config():
|
||||
|
||||
|
||||
def patch_prepare_cp():
|
||||
import functools
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def patched_prepare_cp(self, *args):
|
||||
if self.parallelism_config.cp_backend == "deepspeed":
|
||||
return args
|
||||
|
||||
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(
|
||||
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
@contextlib.contextmanager
|
||||
def _noop_cp_context(
|
||||
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
|
||||
):
|
||||
yield
|
||||
|
||||
self._cp_context = _noop_cp_context
|
||||
return args
|
||||
|
||||
Accelerator._prepare_cp = patched_prepare_cp
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def hijack_expand_mask():
|
||||
import transformers
|
||||
|
||||
transformers.models.llama.modeling_llama._expand_mask = _expand_mask
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
from transformers import modeling_attn_mask_utils
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
modeling_llama._prepare_4d_causal_attention_mask = (
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -12,6 +12,7 @@ from torch import nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_embedding,
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
@@ -51,6 +52,29 @@ QKV_PATCHES = [
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
(
|
||||
"""
|
||||
query_states, gate = torch.chunk(
|
||||
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
"""
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states, gate = torch.chunk(
|
||||
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
]
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
@@ -299,6 +323,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
return pretrained_model.language_model.layers
|
||||
if hasattr(pretrained_model, "model"):
|
||||
if hasattr(pretrained_model.model, "language_model"):
|
||||
return pretrained_model.model.language_model.layers
|
||||
return pretrained_model.model.layers
|
||||
|
||||
raise NotImplementedError(
|
||||
@@ -345,13 +371,13 @@ def apply_lora_kernel_patches(
|
||||
active_adapter = model.active_adapter
|
||||
lora_config = model.model.peft_config[active_adapter]
|
||||
|
||||
# Only patch if conditions are met
|
||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||
return model
|
||||
# Log what features are active
|
||||
if lora_config.lora_dropout > 0:
|
||||
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
||||
if lora_config.bias != "none":
|
||||
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
||||
if lora_config.use_dora:
|
||||
LOG.info("LoRA kernels: DoRA enabled")
|
||||
|
||||
# This needs to be reset after patching
|
||||
original_level = LOG.getEffectiveLevel()
|
||||
@@ -394,44 +420,33 @@ def apply_lora_kernel_patches(
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
hasattr(module, "lora_A") for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_qkv:
|
||||
# Add optimized implementation
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
layer_modules = [
|
||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||
|
||||
if can_patch_o:
|
||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some attention output projection - requires LoRA adapters"
|
||||
)
|
||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
|
||||
if can_patch_mlp:
|
||||
@@ -439,15 +454,50 @@ def apply_lora_kernel_patches(
|
||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||
"lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some MLP layers - requires LoRA adapters"
|
||||
)
|
||||
|
||||
# Patch embedding layers (model-level, not per-layer)
|
||||
if cfg.lora_embedding_kernel:
|
||||
_patch_embedding_layers(model, cfg)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
||||
"""Patch embedding layers with fused LoRA kernel.
|
||||
|
||||
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
||||
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
||||
"""
|
||||
pretrained_model = model.model
|
||||
patched = 0
|
||||
|
||||
# Find embedding modules - check common locations
|
||||
for attr_path in [
|
||||
("model", "embed_tokens"),
|
||||
("model", "language_model", "embed_tokens"),
|
||||
]:
|
||||
parent = pretrained_model
|
||||
for attr in attr_path:
|
||||
parent = getattr(parent, attr, None)
|
||||
if parent is None:
|
||||
break
|
||||
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
||||
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
||||
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
||||
patched += 1
|
||||
|
||||
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
||||
# when included in target_modules. No special embedding handling needed since
|
||||
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
||||
|
||||
if not patched:
|
||||
LOG.debug("No embedding layers with LoRA found to patch")
|
||||
|
||||
|
||||
class FakeMLP(nn.Module):
|
||||
"""
|
||||
placeholder MLP for triton patching
|
||||
|
||||
@@ -59,6 +59,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"ministral3",
|
||||
"mistral4",
|
||||
"afmoe",
|
||||
"nemotron",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1, device=mask.device).to(dtype),
|
||||
torch.tensor(0, device=mask.device).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def detab_code(code: str) -> Tuple[str, str]:
|
||||
try:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
|
||||
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Synthetic dataset generator for benchmarking and testing.
|
||||
|
||||
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
||||
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
||||
weighted dataset mixes.
|
||||
|
||||
YAML configuration example:
|
||||
|
||||
datasets:
|
||||
- path: synthetic
|
||||
type: _synthetic
|
||||
length: 1000
|
||||
sequence_length: 2048
|
||||
min_input_id: 100
|
||||
max_input_id: 32000
|
||||
seed: 42
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
|
||||
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sequence_length: int = 2048,
|
||||
length: int = 1000,
|
||||
min_input_id: int = 100,
|
||||
max_input_id: int = 32000,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.sequence_length = sequence_length
|
||||
self.length = length
|
||||
self.min_input_id = min_input_id
|
||||
self.max_input_id = max_input_id
|
||||
self.seed = seed
|
||||
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
) -> Dataset:
|
||||
LOG.info(
|
||||
f"Generating synthetic dataset: {self.length} samples, "
|
||||
f"sequence_length={self.sequence_length}, "
|
||||
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
|
||||
)
|
||||
|
||||
rng = np.random.default_rng(self.seed)
|
||||
input_ids = rng.integers(
|
||||
low=self.min_input_id,
|
||||
high=self.max_input_id,
|
||||
size=(self.length, self.sequence_length),
|
||||
).tolist()
|
||||
|
||||
attention_mask = [[1] * self.sequence_length] * self.length
|
||||
# labels == input_ids means we train on all tokens
|
||||
labels = [row[:] for row in input_ids]
|
||||
|
||||
return Dataset.from_dict(
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
|
||||
length = ds_cfg.get("length", 1000)
|
||||
min_input_id = ds_cfg.get("min_input_id", 100)
|
||||
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
|
||||
seed = ds_cfg.get("seed", None)
|
||||
|
||||
return SyntheticDatasetStrategy(
|
||||
sequence_length=sequence_length,
|
||||
length=length,
|
||||
min_input_id=min_input_id,
|
||||
max_input_id=max_input_id,
|
||||
seed=seed,
|
||||
)
|
||||
@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
|
||||
|
||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||
model, peft_config = model_loader.load()
|
||||
if model.generation_config is not None:
|
||||
if getattr(model, "generation_config", None) is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_properties = model.config.to_dict()
|
||||
|
||||
@@ -17,6 +17,8 @@ from transformers import (
|
||||
class PytorchProfilerCallback(TrainerCallback):
|
||||
"""
|
||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||
|
||||
Also runs torch.profiler to produce a Chrome trace for timing analysis.
|
||||
"""
|
||||
|
||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
if profiler_steps_start == 0:
|
||||
# start recording memory allocations before everything is allocated, because if we start
|
||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
profiler_steps_start = -1
|
||||
self.profiler_steps_start = profiler_steps_start
|
||||
self._profiler = None
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
if state.global_step == self.profiler_steps_start:
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
|
||||
# Start torch.profiler on the first profiled step
|
||||
if state.global_step == max(self.profiler_steps_start, 0):
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
profiler.__enter__()
|
||||
self._profiler = profiler
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
# Stop and export torch.profiler trace
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
|
||||
if (
|
||||
isinstance(mod, FakeQuantizedLinear)
|
||||
and mod.activation_fake_quantizer is not None
|
||||
and hasattr(mod.activation_fake_quantizer, "enabled")
|
||||
):
|
||||
mod.activation_fake_quantizer.enabled = enable
|
||||
mod.weight_fake_quantizer.enabled = enable
|
||||
if hasattr(mod.weight_fake_quantizer, "enabled"):
|
||||
mod.weight_fake_quantizer.enabled = enable
|
||||
|
||||
|
||||
class QATCallback(TrainerCallback):
|
||||
|
||||
@@ -12,12 +12,11 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
|
||||
|
||||
class TokensPerSecondCallback(TrainerCallback):
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||
from axolotl.utils.schemas.datasets import (
|
||||
DPODataset,
|
||||
KTODataset,
|
||||
SFTDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -308,6 +313,14 @@ def validate_config(
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
elif (
|
||||
ds_cfg.get("type")
|
||||
if isinstance(ds_cfg, dict)
|
||||
else getattr(ds_cfg, "type", None)
|
||||
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
||||
cfg["datasets"][idx] = SyntheticDataset(
|
||||
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
||||
)
|
||||
elif not isinstance(ds_cfg, SFTDataset):
|
||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||
|
||||
|
||||
@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
|
||||
streaming: bool = False,
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Load and process a single dataset based on the passed config."""
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
||||
if dataset_config.type == "_synthetic":
|
||||
dataset = Dataset.from_dict({"text": [""]})
|
||||
else:
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
|
||||
# Parse dataset type
|
||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||
|
||||
@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
)
|
||||
|
||||
@@ -173,6 +175,70 @@ def quantize_model(
|
||||
)
|
||||
|
||||
|
||||
def _make_qat_config(
|
||||
base_config: AOBaseConfig,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
activation_dtype: TorchAOQuantDType | None,
|
||||
group_size: int | None,
|
||||
) -> QATConfig:
|
||||
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
|
||||
group_size and other params are properly propagated (torchao's QATConfig(base_config)
|
||||
does not always map these correctly)."""
|
||||
from torchao.quantization.qat.fake_quantize_config import (
|
||||
Float8FakeQuantizeConfig,
|
||||
IntxFakeQuantizeConfig,
|
||||
)
|
||||
|
||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||
return QATConfig(
|
||||
activation_config=base_config,
|
||||
weight_config=base_config,
|
||||
)
|
||||
|
||||
# Build explicit weight config
|
||||
weight_fq_config: (
|
||||
Int4WeightFakeQuantizeConfig
|
||||
| IntxFakeQuantizeConfig
|
||||
| Float8FakeQuantizeConfig
|
||||
| None
|
||||
) = None
|
||||
if weight_dtype == TorchAOQuantDType.int4:
|
||||
gs = (
|
||||
group_size
|
||||
if group_size is not None
|
||||
else getattr(base_config, "group_size", 128)
|
||||
)
|
||||
activation_dt = None
|
||||
if activation_dtype == TorchAOQuantDType.int8:
|
||||
activation_dt = torch.bfloat16
|
||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
activation_dt = torch.float8_e4m3fn
|
||||
kwargs = {"group_size": gs}
|
||||
if activation_dt is not None:
|
||||
kwargs["activation_dtype"] = activation_dt
|
||||
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
|
||||
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||
|
||||
# Build explicit activation config
|
||||
activation_fq_config = None
|
||||
if activation_dtype == TorchAOQuantDType.int8:
|
||||
activation_fq_config = IntxFakeQuantizeConfig(
|
||||
dtype=torch.int8, granularity="per_token", is_symmetric=False
|
||||
)
|
||||
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
|
||||
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
|
||||
|
||||
if weight_fq_config is not None:
|
||||
return QATConfig(
|
||||
weight_config=weight_fq_config,
|
||||
activation_config=activation_fq_config,
|
||||
)
|
||||
|
||||
# Fallback to base_config for unhandled combos
|
||||
return QATConfig(base_config)
|
||||
|
||||
|
||||
def prepare_model_for_qat(
|
||||
model,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
|
||||
activation_dtype=activation_dtype,
|
||||
group_size=group_size,
|
||||
)
|
||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||
qat_config = QATConfig(
|
||||
activation_config=base_config,
|
||||
weight_config=base_config,
|
||||
)
|
||||
else:
|
||||
qat_config = QATConfig(base_config)
|
||||
qat_config = _make_qat_config(
|
||||
base_config, weight_dtype, activation_dtype, group_size
|
||||
)
|
||||
quantize_(model, qat_config)
|
||||
if quantize_embedding:
|
||||
# activation fake quantization is not supported for embedding layers
|
||||
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
|
||||
activation_dtype=None,
|
||||
group_size=group_size,
|
||||
)
|
||||
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
|
||||
embedding_qat_config = QATConfig(
|
||||
weight_config=embedding_base_config,
|
||||
)
|
||||
else:
|
||||
embedding_qat_config = QATConfig(embedding_base_config)
|
||||
embedding_qat_config = _make_qat_config(
|
||||
embedding_base_config, weight_dtype, None, group_size
|
||||
)
|
||||
quantize_(
|
||||
model,
|
||||
embedding_qat_config,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
|
||||
return [lr * scale for lr in original]
|
||||
|
||||
return original * scale
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""Return serializable state, saving inner_schedule as its own state_dict."""
|
||||
state = {
|
||||
key: value
|
||||
for key, value in self.__dict__.items()
|
||||
if key not in ("optimizer", "inner_schedule")
|
||||
}
|
||||
state["inner_schedule_state"] = self.inner_schedule.state_dict()
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||
"""Restore state, including inner_schedule."""
|
||||
inner_state = state_dict.pop("inner_schedule_state")
|
||||
self.__dict__.update(state_dict)
|
||||
self.inner_schedule.load_state_dict(inner_state)
|
||||
|
||||
@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
|
||||
PretrainingDataset,
|
||||
SFTDataset,
|
||||
StepwiseSupervisedDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
|
||||
|
||||
datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
|
||||
|
||||
test_datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
@@ -433,6 +446,12 @@ class AxolotlInputConfig(
|
||||
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
||||
},
|
||||
)
|
||||
layer_offloading: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Offload model layer parameters to CPU during forward, prefetch back during backward."
|
||||
},
|
||||
)
|
||||
|
||||
unfrozen_parameters: list[str] | None = Field(
|
||||
default=None,
|
||||
@@ -684,6 +703,12 @@ class AxolotlInputConfig(
|
||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||
},
|
||||
)
|
||||
lora_embedding_kernel: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||
},
|
||||
)
|
||||
|
||||
chunked_cross_entropy: bool | None = Field(
|
||||
default=None,
|
||||
@@ -1294,6 +1319,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
or data.get("lora_embedding_kernel")
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp_config") is not None
|
||||
@@ -1341,7 +1367,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("adapter") in ["lora", "qlora"]:
|
||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
kernel_fields = [
|
||||
"lora_mlp_kernel",
|
||||
"lora_qkv_kernel",
|
||||
"lora_o_kernel",
|
||||
"lora_embedding_kernel",
|
||||
]
|
||||
if (
|
||||
any(data.get(k) is not None for k in kernel_fields)
|
||||
or any(data.get(k) for k in unsloth_fields)
|
||||
@@ -1354,9 +1385,38 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||
if data.get("lora_dropout") != 0:
|
||||
return data
|
||||
# Skip auto-enable for MoE models when native grouped_mm is unavailable
|
||||
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
|
||||
# with out= which bypasses autocast and fails on mixed dtypes during eval.
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
|
||||
if not has_grouped_mm:
|
||||
is_moe = False
|
||||
model_type = data.get("model_config_type", "")
|
||||
if model_type and "moe" in model_type.lower():
|
||||
is_moe = True
|
||||
if not is_moe:
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
|
||||
base_model = data.get("base_model")
|
||||
if base_model:
|
||||
auto_cfg = AutoConfig.from_pretrained(
|
||||
base_model, trust_remote_code=False
|
||||
)
|
||||
if getattr(auto_cfg, "num_local_experts", None) or getattr(
|
||||
auto_cfg, "num_experts", None
|
||||
):
|
||||
is_moe = True
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if is_moe:
|
||||
return data
|
||||
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
@@ -1379,6 +1439,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("lora_o_kernel") is None:
|
||||
data["lora_o_kernel"] = True
|
||||
|
||||
if data.get("lora_embedding_kernel") is None:
|
||||
data["lora_embedding_kernel"] = True
|
||||
|
||||
LOG.warning(
|
||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||
|
||||
@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
|
||||
revision: str | None = None
|
||||
|
||||
|
||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||
class SyntheticDataset(BaseModel):
|
||||
"""Synthetic dataset configuration for benchmarking and testing.
|
||||
|
||||
Generates datasets with configurable sequence length, dataset size, and token ID
|
||||
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
|
||||
validating weighted dataset mixes.
|
||||
"""
|
||||
|
||||
path: Literal["synthetic"] = "synthetic"
|
||||
type: Literal["_synthetic"] = "_synthetic"
|
||||
length: int = Field(
|
||||
default=1000,
|
||||
json_schema_extra={"description": "Number of rows to generate"},
|
||||
)
|
||||
sequence_length: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Sequence length per row (defaults to sequence_len from config)"
|
||||
},
|
||||
)
|
||||
min_input_id: int = Field(
|
||||
default=100,
|
||||
json_schema_extra={"description": "Minimum token ID for generation"},
|
||||
)
|
||||
max_input_id: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
|
||||
},
|
||||
)
|
||||
seed: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Random seed for reproducibility"},
|
||||
)
|
||||
|
||||
|
||||
DatasetConfig = (
|
||||
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
||||
)
|
||||
|
||||
@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
dion = "dion"
|
||||
flash_adamw = "flash_adamw"
|
||||
flash_adam = "flash_adam"
|
||||
flash_sgd = "flash_sgd"
|
||||
flash_sgdw = "flash_sgdw"
|
||||
flash_lion = "flash_lion"
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
|
||||
@@ -253,6 +253,23 @@ class TrainingValidationMixin:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_reward_model_defaults(cls, data):
|
||||
if data.get("reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 1
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForSequenceClassification"
|
||||
|
||||
if data.get("process_reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 2
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForTokenClassification"
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
@@ -664,15 +681,7 @@ class LoRAValidationMixin:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_dora(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("peft_use_dora"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with DoRA at the moment."
|
||||
)
|
||||
# DoRA is now supported by lora kernels
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -773,6 +782,14 @@ class OptimizationValidationMixin:
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _resolve_fsdp_version(data):
|
||||
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
return fsdp_version
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
@@ -782,15 +799,32 @@ class OptimizationValidationMixin:
|
||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_flashoptim_deepspeed_fsdp(cls, data):
|
||||
optimizer = data.get("optimizer") or ""
|
||||
if str(optimizer).startswith("flash_"):
|
||||
if data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is incompatible with DeepSpeed. "
|
||||
"Flash optimizers only support DDP and FSDP2."
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is only compatible with FSDP2. "
|
||||
"Set fsdp_version: 2 to use flash optimizers with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
|
||||
@@ -15,6 +15,8 @@ import datasets
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
import transformers.utils as _transformers_utils
|
||||
import transformers.utils.import_utils as _import_utils
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
@@ -29,6 +31,26 @@ from tests.hf_offline_utils import (
|
||||
|
||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
# Shim for deepseek v3
|
||||
if not hasattr(_import_utils, "is_torch_fx_available"):
|
||||
|
||||
def _is_torch_fx_available():
|
||||
try:
|
||||
import torch.fx # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
_import_utils.is_torch_fx_available = _is_torch_fx_available
|
||||
|
||||
if not hasattr(_transformers_utils, "is_flash_attn_greater_or_equal_2_10"):
|
||||
from transformers.utils import is_flash_attn_greater_or_equal as _is_flash_attn_gte
|
||||
|
||||
_transformers_utils.is_flash_attn_greater_or_equal_2_10 = lambda: (
|
||||
_is_flash_attn_gte("2.10")
|
||||
)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
def decorator(func):
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
||||
|
||||
proj.base_layer = base_layer
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||
# quant_state should be None since weight is bf16, not FP8
|
||||
self.assertIsNone(quant_state)
|
||||
|
||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
||||
scale_inv = torch.ones(1)
|
||||
base_layer.weight_scale_inv = scale_inv
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||
self.assertIs(quant_state, scale_inv)
|
||||
|
||||
|
||||
|
||||
@@ -536,7 +536,7 @@ class TestHFCausalTrainerBuilder:
|
||||
"cfg_string",
|
||||
[
|
||||
"sft_cfg",
|
||||
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
|
||||
"rm_cfg",
|
||||
"prm_cfg",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ Test strategy:
|
||||
- Tolerances account for tf32 accumulation in Triton kernels
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -34,6 +35,21 @@ pytestmark = pytest.mark.skipif(
|
||||
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"
|
||||
|
||||
|
||||
def skip_on_out_of_resources(func):
|
||||
"""Skip test if Triton kernel exceeds GPU shared memory limits."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
if "OutOfResources" in type(exc).__name__:
|
||||
pytest.skip(f"GPU shared memory too small: {exc}")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helpers
|
||||
# =============================================================================
|
||||
@@ -209,6 +225,7 @@ def make_test_data(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestForwardPass:
|
||||
"""Test forward pass of fused scatter2scatter_lora kernel."""
|
||||
|
||||
@@ -288,6 +305,7 @@ class TestForwardPass:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestForwardGrouped:
|
||||
"""Test forward pass with grouped_in/grouped_out configurations."""
|
||||
|
||||
@@ -377,6 +395,7 @@ class TestForwardGrouped:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLoRAGradients:
|
||||
"""Test backward LoRA gradient computation (dA, dB)."""
|
||||
|
||||
@@ -452,6 +471,7 @@ class TestLoRAGradients:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestAutograd:
|
||||
"""Test full autograd integration through ScatterMoELoRA."""
|
||||
|
||||
@@ -620,6 +640,7 @@ class TestAutograd:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestBaseEquivalence:
|
||||
"""When scaling=0, fused kernel should match base scatter2scatter."""
|
||||
|
||||
@@ -692,6 +713,7 @@ class TestBaseEquivalence:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLoRAAdditivity:
|
||||
"""Test that the LoRA component is correctly additive."""
|
||||
|
||||
@@ -749,6 +771,7 @@ class TestLoRAAdditivity:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestParallelExpertsModule:
|
||||
"""Test the ParallelExperts module with LoRA."""
|
||||
|
||||
@@ -816,6 +839,7 @@ class TestParallelExpertsModule:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestEdgeCases:
|
||||
"""Edge cases and boundary conditions."""
|
||||
|
||||
@@ -913,6 +937,7 @@ class TestEdgeCases:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestFusedDX:
|
||||
"""Test fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""
|
||||
|
||||
@@ -980,6 +1005,7 @@ class TestFusedDX:
|
||||
def test_basic(self):
|
||||
self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_large(self):
|
||||
self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||
|
||||
@@ -1122,6 +1148,7 @@ class TestFusedDX:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestFusedGatherBackward:
|
||||
"""Test fused gather + backward dA/dB kernel."""
|
||||
|
||||
@@ -1174,6 +1201,7 @@ class TestFusedGatherBackward:
|
||||
def test_basic(self):
|
||||
self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_large(self):
|
||||
self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2)
|
||||
|
||||
@@ -1183,6 +1211,7 @@ class TestFusedGatherBackward:
|
||||
def test_k1(self):
|
||||
self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1)
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_many_experts(self):
|
||||
self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4)
|
||||
|
||||
@@ -1269,6 +1298,8 @@ class TestFusedGatherBackward:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.xfail(reason="flaky", strict=False)
|
||||
class TestTokenRounding:
|
||||
"""Test token rounding utility and its integration with backward kernels."""
|
||||
|
||||
@@ -1315,6 +1346,7 @@ class TestTokenRounding:
|
||||
)
|
||||
prev = padded_offsets[e].item()
|
||||
|
||||
@skip_on_out_of_resources
|
||||
def test_round_with_fused_gather(self):
|
||||
"""Token rounding + fused gather gives same result as plain fused gather."""
|
||||
from importlib import import_module
|
||||
@@ -1414,6 +1446,7 @@ class TestTokenRounding:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestCombinedOptimizations:
|
||||
"""Test all optimizations together."""
|
||||
|
||||
@@ -1583,6 +1616,7 @@ def _make_mock_sigmoid_moe_block(
|
||||
return moe_block, T, H, FF, E, K
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFScatterMoESigmoidRouting:
|
||||
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
||||
|
||||
@@ -1724,6 +1758,7 @@ class TestHFScatterMoESigmoidRouting:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestHFScatterMoESigmoidWithSharedExperts:
|
||||
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
||||
|
||||
|
||||
@@ -933,7 +933,7 @@ class TestKernelizeIntegration:
|
||||
def _get_repo_path():
|
||||
"""Get the path to scattermoe_lora within axolotl's plugin."""
|
||||
return (
|
||||
Path(__file__).parent.parent.parent
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "integrations"
|
||||
@@ -1219,7 +1219,7 @@ class TestSharedExpertHandling:
|
||||
|
||||
# Kernelize
|
||||
repo_path = (
|
||||
Path(__file__).parent.parent.parent
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "integrations"
|
||||
|
||||
@@ -102,7 +102,7 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
@@ -176,24 +176,31 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
X.requires_grad = True
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
None, # down_scale
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True, # inplace
|
||||
@@ -247,24 +254,31 @@ def test_lora_mlp_with_adapters(
|
||||
# Forward pass with adapters
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
scale,
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True,
|
||||
@@ -334,25 +348,32 @@ def test_lora_qkv(sample_tensors):
|
||||
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
)
|
||||
|
||||
assert Q1.shape == K1.shape == V1.shape == X.shape
|
||||
@@ -366,25 +387,32 @@ def test_lora_qkv(sample_tensors):
|
||||
# Test with LoRA adapters
|
||||
Q2, K2, V2 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
None,
|
||||
None, # Q
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
True,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
)
|
||||
|
||||
assert Q2.shape == K2.shape == V2.shape == X.shape
|
||||
@@ -427,7 +455,9 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
output = LoRA_O.apply(
|
||||
X, None, W, b, None, A, B, scale, None, None
|
||||
) # X_drop, ..., lora_bias, magnitude
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -542,6 +572,7 @@ def test_inplace_operations(sample_tensors, apply_function):
|
||||
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
"training": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
1245
tests/e2e/kernels/test_lora_features.py
Normal file
1245
tests/e2e/kernels/test_lora_features.py
Normal file
File diff suppressed because it is too large
Load Diff
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
||||
|
||||
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
||||
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
||||
with bias, dropout, and DoRA enabled.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def _run_training(temp_dir, cfg):
|
||||
"""Write config and launch multi-GPU training."""
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
||||
"""Base config for LoRA + FSDP2 + kernel tests."""
|
||||
cfg = {
|
||||
"base_model": "Qwen/Qwen3-0.6B",
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"bf16": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
# Enable all LoRA kernels
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
"lora_embedding_kernel": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
cfg.update(overrides)
|
||||
return DictDefault(cfg)
|
||||
|
||||
|
||||
class TestFSDP2LoRAKernels:
|
||||
"""Test LoRA kernels under FSDP2."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_basic(self, temp_dir):
|
||||
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dropout(self, temp_dir):
|
||||
"""LoRA kernels + dropout + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dora(self, temp_dir):
|
||||
"""LoRA kernels + DoRA + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
||||
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(
|
||||
temp_dir,
|
||||
peft_use_dora=True,
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
||||
|
||||
|
||||
def test_kernel_patch_conditions():
|
||||
"""Test various conditions that should prevent kernel patching."""
|
||||
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
||||
test_configs = [
|
||||
# Dropout prevents patching
|
||||
# Dropout — kernels now support this
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
},
|
||||
# Bias prevents patching
|
||||
# Bias — kernels now support this
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
|
||||
# Should not patch
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
layer = patched_model.model.model.layers[0].mlp
|
||||
|
||||
# Verify no patches applied
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||
# Verify patches ARE applied (dropout and bias are now supported)
|
||||
assert (
|
||||
layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||
or layer.forward.__func__ is apply_lora_mlp_geglu
|
||||
)
|
||||
|
||||
|
||||
def test_kernel_config_options():
|
||||
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load model
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# We call modelloader as that's where the patches are applied
|
||||
# despite the fact that we're not using it to load the model
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
# Apply patches — should succeed even with dropout > 0
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||
|
||||
# Verify patch was not applied
|
||||
# Verify patches WERE applied (dropout is now supported by kernels)
|
||||
layers = get_layers(model)
|
||||
for layer in layers:
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
assert not hasattr(self_attn, "apply_qkv")
|
||||
assert not hasattr(self_attn, "apply_o")
|
||||
assert hasattr(self_attn, "apply_qkv")
|
||||
assert hasattr(self_attn, "apply_o")
|
||||
|
||||
@@ -4,8 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||
@with_temp_dir
|
||||
def test_lora_gptq_packed(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"gptq": True,
|
||||
"gptq_disable_exllama": True,
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"max_steps": 20,
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -9,8 +9,8 @@ import subprocess
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from axolotl.utils.dict import DictDefault
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="DeepSeek-V3-11M remote model code needs _supports_flash_attn=True for newer transformers"
|
||||
)
|
||||
class TestDeepseekV3:
|
||||
"""
|
||||
Test case for DeepseekV3 models
|
||||
|
||||
@@ -262,6 +262,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip(reason="TRL ORPO trainer has internal zip() length mismatch bug")
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -125,7 +125,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -183,7 +183,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
model.base_model.model.model.layers[0].mlp.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -282,3 +284,60 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"optimizer_name,expected_class,learning_rate",
|
||||
[
|
||||
("flash_adamw", "FlashAdamW", 0.00001),
|
||||
("flash_adam", "FlashAdam", 0.00001),
|
||||
("flash_sgd", "FlashSGD", 0.01),
|
||||
("flash_sgdw", "FlashSGDW", 0.01),
|
||||
("flash_lion", "FlashLion", 0.0001),
|
||||
],
|
||||
)
|
||||
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
|
||||
pytest.importorskip("flashoptim")
|
||||
temp_dir = str(tmp_path)
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": learning_rate,
|
||||
"optimizer": optimizer_name,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert trainer.optimizer.optimizer.__class__.__name__ == expected_class
|
||||
|
||||
@@ -35,6 +35,14 @@ from tests.e2e.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_fake_quant_config_dtype(config):
|
||||
"""Get the weight dtype from a fake quantize config, handling different config types."""
|
||||
if hasattr(config, "dtype"):
|
||||
return config.dtype
|
||||
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
|
||||
return torch.int4
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def model():
|
||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -157,6 +165,18 @@ class TestQuantization:
|
||||
expected_exception,
|
||||
expected_tensor_class,
|
||||
):
|
||||
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
|
||||
# (see https://pypi.org/project/mslk-cuda/)
|
||||
if expected_tensor_class is Int4Tensor and activation_dtype is None:
|
||||
try:
|
||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
|
||||
int4_row_quantize_zp,
|
||||
)
|
||||
|
||||
if int4_row_quantize_zp is None:
|
||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||
except ImportError:
|
||||
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception):
|
||||
quantize_model(
|
||||
@@ -252,28 +272,24 @@ class TestQuantization:
|
||||
if quantize_embedding:
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
||||
== weight_dtype.value
|
||||
)
|
||||
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
|
||||
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
|
||||
if group_size:
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||
== group_size
|
||||
)
|
||||
assert embed_config.group_size == group_size
|
||||
|
||||
for child in list(model.children()):
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
assert isinstance(child, FakeQuantizedLinear)
|
||||
assert hasattr(child, "weight_fake_quantizer")
|
||||
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
||||
w_config = child.weight_fake_quantizer.config
|
||||
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
|
||||
if group_size:
|
||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||
assert w_config.group_size == group_size
|
||||
if activation_dtype:
|
||||
assert hasattr(child, "activation_fake_quantizer")
|
||||
a_config = child.activation_fake_quantizer.config
|
||||
assert (
|
||||
child.activation_fake_quantizer.config.dtype
|
||||
== activation_dtype.value
|
||||
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
|
||||
)
|
||||
else:
|
||||
assert child.activation_fake_quantizer is None
|
||||
@@ -374,9 +390,16 @@ class TestQuantizationCallback:
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
# Only test enable/disable toggling if the fake quantizer supports it
|
||||
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
|
||||
supports_toggle = hasattr(
|
||||
model.model.embed_tokens.weight_fake_quantizer, "enabled"
|
||||
)
|
||||
if supports_toggle:
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
|
||||
@@ -388,9 +411,10 @@ class TestQuantizationCallback:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been disabled
|
||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||
if supports_toggle:
|
||||
# quantization should have been disabled
|
||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
trainer_state.global_step = 100
|
||||
qat_callback.on_step_begin(
|
||||
@@ -400,9 +424,10 @@ class TestQuantizationCallback:
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been enabled
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if supports_toggle:
|
||||
# quantization should have been enabled
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@require_torch_2_8_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||
@@ -424,9 +449,10 @@ class TestQuantizationCallback:
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
# simulate first training step
|
||||
@@ -438,5 +464,6 @@ class TestQuantizationCallback:
|
||||
)
|
||||
|
||||
# quantization should be enabled from the get-go
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@@ -179,7 +179,7 @@ def check_tensorboard(
|
||||
tag: str,
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
rtol: float = 0.05,
|
||||
gt_zero: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Unit tests for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Tests correctness of:
|
||||
- scatter2scatter_lora (forward)
|
||||
- scatter2scatter_lora_dX (backward input gradient)
|
||||
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
|
||||
- ScatterMoELoRA autograd function (full forward + backward)
|
||||
|
||||
Each kernel is tested against a pure PyTorch per-expert-loop reference
|
||||
implementation at multiple model shapes and LoRA ranks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def _requires_cuda():
|
||||
return pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="CUDA not available"
|
||||
)
|
||||
|
||||
|
||||
pytestmark = _requires_cuda()
|
||||
|
||||
|
||||
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _setup(E, K, N, T, top_k, R, seed=42):
|
||||
"""Create synthetic expert weights, LoRA, routing, and grouped inputs."""
|
||||
torch.manual_seed(seed)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, E, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo
|
||||
|
||||
|
||||
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
|
||||
"""Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
|
||||
grouped_x = base_ops.group(x, ssi, fan_out=k)
|
||||
M, N = grouped_x.size(0), W.size(2)
|
||||
R = lora_A.size(0) // E
|
||||
out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||
for e in range(E):
|
||||
s = eo[e - 1].item() if e > 0 else 0
|
||||
end = eo[e].item()
|
||||
if s == end:
|
||||
continue
|
||||
xe = grouped_x[s:end].float()
|
||||
we = W[e].float()
|
||||
ae = lora_A[e * R : (e + 1) * R].float()
|
||||
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||
out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)
|
||||
result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||
result[ssi] = out
|
||||
return result
|
||||
|
||||
|
||||
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
|
||||
"""Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
|
||||
M, K = dy_grouped.size(0), W.size(1)
|
||||
R = lora_A.size(0) // E
|
||||
out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||
for e in range(E):
|
||||
s = eo[e - 1].item() if e > 0 else 0
|
||||
end = eo[e].item()
|
||||
if s == end:
|
||||
continue
|
||||
dye = dy_grouped[s:end].float()
|
||||
we = W[e].float()
|
||||
ae = lora_A[e * R : (e + 1) * R].float()
|
||||
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||
out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)
|
||||
result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||
result[ssi] = out
|
||||
return result
|
||||
|
||||
|
||||
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
|
||||
"""Per-expert loop reference: dA, dB for LoRA weight gradients."""
|
||||
R = lora_A.size(0) // E
|
||||
dA = torch.zeros_like(lora_A)
|
||||
dB = torch.zeros_like(lora_B)
|
||||
for e in range(E):
|
||||
s = eo[e - 1].item() if e > 0 else 0
|
||||
end = eo[e].item()
|
||||
if s == end:
|
||||
continue
|
||||
xe = grouped_x[s:end].float()
|
||||
dye = dy[s:end].float()
|
||||
ae = lora_A[e * R : (e + 1) * R].float()
|
||||
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||
dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)
|
||||
dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
|
||||
return dA, dB
|
||||
|
||||
|
||||
# ─── Model shape configs ────────────────────────────────────────────────────
|
||||
|
||||
# (E, K, N, T, top_k, R, description)
|
||||
CONFIGS_SMALL = [
|
||||
(32, 128, 64, 64, 2, 4, "tiny"),
|
||||
(64, 256, 128, 128, 4, 8, "small"),
|
||||
]
|
||||
|
||||
CONFIGS_REAL = [
|
||||
(256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
|
||||
(256, 512, 2048, 2048, 8, 16, "qwen35_down"),
|
||||
(64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
|
||||
(128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
|
||||
]
|
||||
|
||||
SCALING = 2.0
|
||||
|
||||
|
||||
# ─── Forward tests ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatter2ScatterLoRAForward:
|
||||
"""Test scatter2scatter_lora forward kernel vs reference."""
|
||||
|
||||
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||
def config(self, request):
|
||||
return request.param
|
||||
|
||||
def test_matches_reference(self, config):
|
||||
E, K, N, T, k, R, desc = config
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
kernel_out = lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
|
||||
|
||||
err = (kernel_out.float() - ref_out.float()).abs().max().item()
|
||||
assert err < 1.0, f"[{desc}] fwd max_err={err}"
|
||||
|
||||
def test_output_shape(self, config):
|
||||
E, K, N, T, k, R, desc = config
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
out = lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
assert out.shape == (T * k, N)
|
||||
assert out.dtype == DTYPE
|
||||
|
||||
|
||||
# ─── Backward dX tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatter2ScatterLoRADX:
|
||||
"""Test scatter2scatter_lora_dX backward kernel vs reference."""
|
||||
|
||||
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||
def config(self, request):
|
||||
return request.param
|
||||
|
||||
def test_matches_reference(self, config):
|
||||
E, K, N, T, k, R, desc = config
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
gx = base_ops.group(x, ssi, fan_out=k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
kernel_dx = lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
|
||||
|
||||
err = (kernel_dx.float() - ref_dx.float()).abs().max().item()
|
||||
assert err < 1.0, f"[{desc}] dX max_err={err}"
|
||||
|
||||
|
||||
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGroupBwdLoRA:
|
||||
"""Test group_bwd_lora (split dA/dB kernel) vs reference."""
|
||||
|
||||
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||
def config(self, request):
|
||||
return request.param
|
||||
|
||||
def test_matches_reference(self, config):
|
||||
E, K, N, T, k, R, desc = config
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
gx = base_ops.group(x, ssi, fan_out=k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
kern_dA, kern_dB = lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=E,
|
||||
scaling=SCALING,
|
||||
)
|
||||
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
|
||||
|
||||
# Use norm-relative error: bf16 accumulation order differs between
|
||||
# kernel (tiled + different reduction order) and reference (per-expert
|
||||
# fp32 loop), so max absolute error can be large on individual elements
|
||||
# while the overall tensor is correct.
|
||||
dA_norm_err = (
|
||||
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
|
||||
).item()
|
||||
dB_norm_err = (
|
||||
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
|
||||
).item()
|
||||
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
|
||||
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
|
||||
|
||||
def test_zero_expert_tokens(self):
|
||||
"""Experts with zero routed tokens produce zero gradients."""
|
||||
E, K, N, R = 8, 64, 32, 4
|
||||
torch.manual_seed(42)
|
||||
# Route all tokens to expert 0 only
|
||||
T, k = 16, 1
|
||||
top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||
gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
|
||||
lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
|
||||
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
dA, dB = lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=E,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
# Experts 1..7 should have zero gradients
|
||||
for e in range(1, E):
|
||||
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
|
||||
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
|
||||
f"Expert {e} dB not zero"
|
||||
)
|
||||
|
||||
|
||||
# ─── Full autograd tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatterMoELoRAAutograd:
|
||||
"""Test full forward + backward through ScatterMoELoRA autograd function."""
|
||||
|
||||
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
|
||||
def config(self, request):
|
||||
return request.param
|
||||
|
||||
def test_gradients_exist_and_finite(self, config):
|
||||
E, K, N, T, k, R, desc = config
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
x = x.requires_grad_(True)
|
||||
lA = lA.requires_grad_(True)
|
||||
lB = lB.requires_grad_(True)
|
||||
|
||||
out = ScatterMoELoRA.apply(
|
||||
x,
|
||||
W,
|
||||
k,
|
||||
sei,
|
||||
ssi,
|
||||
eo,
|
||||
lA,
|
||||
lB,
|
||||
SCALING,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
|
||||
assert x.grad is not None, f"[{desc}] x.grad is None"
|
||||
assert lA.grad is not None, f"[{desc}] lA.grad is None"
|
||||
assert lB.grad is not None, f"[{desc}] lB.grad is None"
|
||||
assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
|
||||
assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
|
||||
assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
|
||||
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
|
||||
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||
|
||||
def test_split_matches_fused(self):
|
||||
"""Split dispatch (for few large experts) matches fused kernel."""
|
||||
# Use a shape where split would be dispatched (large K*N, few E)
|
||||
E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
# Force fused path
|
||||
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
|
||||
out_fused = lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
|
||||
# Force split path
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
|
||||
out_split = lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
|
||||
|
||||
norm_err = (
|
||||
(out_fused.float() - out_split.float()).norm()
|
||||
/ (out_fused.float().norm() + 1e-6)
|
||||
).item()
|
||||
assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
|
||||
|
||||
def test_scaling_zero_gives_base_only(self):
|
||||
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
out_lora = ScatterMoELoRA.apply(
|
||||
x,
|
||||
W,
|
||||
k,
|
||||
sei,
|
||||
ssi,
|
||||
eo,
|
||||
lA,
|
||||
lB,
|
||||
0.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out_base = base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
)
|
||||
err = (out_lora.float() - out_base.float()).abs().max().item()
|
||||
assert err < 0.01, f"scaling=0 should match base: err={err}"
|
||||
229
tests/kernels/test_rms_norm_gated.py
Normal file
229
tests/kernels/test_rms_norm_gated.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
||||
|
||||
Tests against the eager Qwen3_5RMSNormGated implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
pytest.importorskip("triton", reason="triton required for fused kernels")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
||||
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
|
||||
class EagerRMSNormGated(torch.nn.Module):
|
||||
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = self.weight * hidden_states.to(input_dtype)
|
||||
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def _sync_weights(eager_mod, fused_mod):
|
||||
"""Copy weights from eager to fused module."""
|
||||
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 128, 256),
|
||||
(4, 64, 512),
|
||||
(1, 32, 1024),
|
||||
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
||||
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
||||
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
||||
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
||||
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedForward:
|
||||
def test_output_matches_eager(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
if dtype == torch.float32:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
||||
else:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_output_shape(self, dtype, shape):
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
y = fused(X, gate=G)
|
||||
assert y.shape == (B, T, H)
|
||||
assert y.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 32, 256),
|
||||
(2, 16, 512),
|
||||
(2, 16, 2560), # Qwen3.5-4B
|
||||
(1, 8, 4096), # Qwen3.5-9B
|
||||
(1, 8, 5120), # Qwen3.5-27B
|
||||
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
||||
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedBackward:
|
||||
def test_grad_x(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_gate(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_weight(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
class TestRMSNormGatedEdgeCases:
|
||||
def test_gate_none_raises(self):
|
||||
fused = FusedRMSNormGated(256).cuda()
|
||||
X = torch.randn(2, 4, 256, device="cuda")
|
||||
with pytest.raises(ValueError, match="requires a gate tensor"):
|
||||
fused(X, gate=None)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
||||
torch.manual_seed(42)
|
||||
H = 512
|
||||
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
|
||||
def test_random_weight_init(self):
|
||||
"""Test with non-default weight values."""
|
||||
torch.manual_seed(123)
|
||||
H = 256
|
||||
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
# Randomize weights
|
||||
eager.weight.data = torch.randn_like(eager.weight.data)
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user