optimize moe + lora

This commit is contained in:
Wing Lian
2026-03-18 21:42:13 +00:00
parent 163bd4dd5a
commit c5db90aa3f
4 changed files with 309 additions and 36 deletions

View File

@@ -1330,6 +1330,224 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
inner_dim = named_args.get("INNER_DIM", 2048)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics({
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
})
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr, stride_dym, stride_dyn,
X_ptr, stride_xm, stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr, stride_lw0, stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr, stride_out0, stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M, K: tl.constexpr, N: tl.constexpr,
ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K
else:
OUT_DIM: tl.constexpr = N
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32)
tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
else:
tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1562,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,46 +1582,45 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
_group_bwd_lora_split[grid_dA](
DY, DY.stride(0), DY.stride(1),
X, X.stride(0), X.stride(1),
lora_B, lora_B.stride(0), lora_B.stride(1),
dA, dA.stride(0), dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
M=DY.size(0), K=K, N=N,
ACTUAL_R=R, BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
COMPUTE_DA=True,
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY, DY.stride(0), DY.stride(1),
X, X.stride(0), X.stride(1),
lora_A, lora_A.stride(0), lora_A.stride(1),
dB, dB.stride(0), dB.stride(1),
expert_offsets,
M=DY.size(0), K=K, N=N,
ACTUAL_R=R, BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
)
return dA, dB

View File

@@ -61,7 +61,13 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK:
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()

View File

@@ -505,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
"""
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,12 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(
enabled="all", stacks="all"
)
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin(
self,
@@ -38,7 +43,22 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(
enabled="all", stacks="all"
)
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
self._profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
self._profiler.__enter__()
def on_step_end(
self,
@@ -55,6 +75,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end(
self,
args: TrainingArguments,
@@ -73,3 +100,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None