optimize moe + lora

This commit is contained in:
Wing Lian
2026-03-18 21:42:13 +00:00
parent 163bd4dd5a
commit c5db90aa3f
4 changed files with 309 additions and 36 deletions

View File

@@ -1330,6 +1330,224 @@ def _group_bwd_lora(
) )
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
inner_dim = named_args.get("INNER_DIM", 2048)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics({
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
})
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr, stride_dym, stride_dyn,
X_ptr, stride_xm, stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr, stride_lw0, stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr, stride_out0, stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M, K: tl.constexpr, N: tl.constexpr,
ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K
else:
OUT_DIM: tl.constexpr = N
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32)
tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
else:
tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
def group_bwd_lora( def group_bwd_lora(
DY: torch.Tensor, DY: torch.Tensor,
X: torch.Tensor, X: torch.Tensor,
@@ -1344,6 +1562,9 @@ def group_bwd_lora(
""" """
Compute LoRA gradients for A and B on expert-grouped data. Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args: Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert) DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert) X: Input [M_total, K] (grouped by expert)
@@ -1361,46 +1582,45 @@ def group_bwd_lora(
K = X.size(1) K = X.size(1)
N = DY.size(1) N = DY.size(1)
# Zero-init for atomic accumulation # No zero-init needed: the split kernels write zeros for experts with
dA = torch.zeros_like(lora_A) # zero routed tokens directly in the kernel (else branch).
dB = torch.zeros_like(lora_B) dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R) BLOCK_R = _block_r_for_rank(R)
def grid(META): def grid_dA(META):
return ( return (E, triton.cdiv(K, META["BLOCK_DIM"]))
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
_group_bwd_lora[grid]( _group_bwd_lora_split[grid_dA](
DY, DY, DY.stride(0), DY.stride(1),
DY.stride(0), X, X.stride(0), X.stride(1),
DY.stride(1), lora_B, lora_B.stride(0), lora_B.stride(1),
X, dA, dA.stride(0), dA.stride(1),
X.stride(0),
X.stride(1),
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
expert_offsets, expert_offsets,
M=DY.size(0), M=DY.size(0), K=K, N=N,
K=K, ACTUAL_R=R, BLOCK_R=BLOCK_R,
N=N, INNER_DIM=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
scaling=scaling, scaling=scaling,
ACC_TYPE=tl.float32, COMPUTE_DA=True,
allow_tf32=ALLOW_TF32, ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY, DY.stride(0), DY.stride(1),
X, X.stride(0), X.stride(1),
lora_A, lora_A.stride(0), lora_A.stride(1),
dB, dB.stride(0), dB.stride(1),
expert_offsets,
M=DY.size(0), K=K, N=N,
ACTUAL_R=R, BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
) )
return dA, dB return dA, dB

View File

@@ -61,7 +61,13 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs" return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK:
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe: if cfg.use_scattermoe:
self._register_kernels() self._register_kernels()

View File

@@ -505,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3: elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type() cur_device = get_device_type()
if "mps" in str(cur_device): if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0" self.model_kwargs["device_map"] = "mps:0"

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback): class PytorchProfilerCallback(TrainerCallback):
""" """
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
""" """
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0): def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,12 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0: if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start # start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces # at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all") torch.cuda.memory._record_memory_history(
enabled="all", stacks="all"
)
profiler_steps_start = -1 profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin( def on_step_begin(
self, self,
@@ -38,7 +43,22 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs, **kwargs,
): ):
if state.global_step == self.profiler_steps_start: if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all") torch.cuda.memory._record_memory_history(
enabled="all", stacks="all"
)
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
self._profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
self._profiler.__enter__()
def on_step_end( def on_step_end(
self, self,
@@ -55,6 +75,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now # tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None) torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end( def on_train_end(
self, self,
args: TrainingArguments, args: TrainingArguments,
@@ -73,3 +100,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now # tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None) torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None