diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 7a48e47ac..1e75973d7 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]: else: # pragma: no cover - defensive guard raise SystemExit("Unable to locate axolotl repository root for imports") -from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe +from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402 ACCURACY_TOLERANCE = 5e-3 @@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace: default=128, help="GROUP_SIZE_M used by the Triton kernel", ) + parser.add_argument( + "--backend", + choices=["cg", "mg"], + default="mg", + help="MoE kernel backend to benchmark", + ) return parser.parse_args() @@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict: baseline_module.moe = MethodType(original_moe, baseline_module) state_dict = baseline_module.state_dict() - patch_deepseek_v3_moe(group_size_m=args.group_size) + patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend) patched_module = build_module(args) patched_module.load_state_dict(state_dict) @@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict: return { "device": device, + "backend": args.backend, "dtype": dtype, "baseline_ms": baseline_ms, "patched_ms": patched_ms, @@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint result = benchmark_deepseek_v3(args) print( - f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}" + f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}" ) print( f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}" diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py index 7b7c68763..de7f9426e 100644 --- a/scripts/benchmarks/deepseek_v3_moe_sweep.py +++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py @@ -21,7 +21,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]: else: # pragma: no cover - defensive guard raise SystemExit("Unable to locate axolotl repository root for imports") -from scripts.benchmarks.deepseek_v3_moe import ( +from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402 ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3, @@ -42,6 +42,12 @@ def parse_args() -> argparse.Namespace: choices=["auto", "cpu", "cuda"], help="Execution device", ) + parser.add_argument( + "--backend", + choices=["cg", "mg"], + default="mg", + help="MoE kernel backend to benchmark", + ) parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations") parser.add_argument("--seed", type=int, default=0, help="Random seed") @@ -105,6 +111,7 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace: { "dtype": args.dtype, "device": args.device, + "backend": args.backend, "warmup": args.warmup, "iters": args.iters, "seed": args.seed, @@ -164,6 +171,7 @@ def main() -> None: # pragma: no cover - utility script "n_experts", "top_k", "groups", + "backend", "baseline_ms", "patched_ms", "speedup", @@ -177,10 +185,10 @@ def main() -> None: # pragma: no cover - utility script rows = [] print( - f"Running sweep on 
device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" + f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}" ) print( - f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}" + f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" ) @@ -206,6 +214,7 @@ def main() -> None: # pragma: no cover - utility script cfg["n_experts"], cfg["top_k"], cfg["groups"], + args.backend, result["baseline_ms"], result["patched_ms"], result["speedup"], @@ -219,7 +228,7 @@ def main() -> None: # pragma: no cover - utility script ) status = "OK" if result["accuracy_ok"] else "FAIL" print( - f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}" + f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" ) diff --git a/src/axolotl/kernels/moe/__init__.py b/src/axolotl/kernels/moe/__init__.py index eb1e5c3cf..92df41765 100644 --- a/src/axolotl/kernels/moe/__init__.py +++ b/src/axolotl/kernels/moe/__init__.py @@ -8,6 +8,7 @@ from .tt_cg_gemm import ( cg_grouped_gemm_forward, cg_grouped_gemm_forward_dynamic, ) +from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm __all__ = [ "cg_grouped_gemm", @@ -16,4 +17,5 @@ __all__ = [ "ContiguousGroupedGEMM", "ContiguousGroupedGEMMForwardOnly", "generate_permute_indices", + "mg_grouped_gemm", ] diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/__init__.py b/src/axolotl/kernels/moe/tt_mg_gemm/__init__.py new file mode 100644 index 000000000..c90da16c2 --- /dev/null +++ b/src/axolotl/kernels/moe/tt_mg_gemm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .mg_grouped_gemm import grouped_gemm_forward +from .tma_autotuning import ALIGN_SIZE_M + +__all__ = [ + "grouped_gemm_forward", + "ALIGN_SIZE_M", +] diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py new file mode 100644 index 000000000..58deaeddb --- /dev/null +++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py @@ -0,0 +1,1291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# credit - flat index forward kernel is derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe +import logging +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from .tma_autotuning import ( + _NV_CONFIGS, + CudaUtils, + TmaDescriptorHelper, + early_config_prune, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# ============== Start Triton Kernels =============== + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_hopper( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_EPILOGUE_SUBTILING: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel for Hopper. + For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty # output dtype + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store + + M_end = 0 + M_start = 0 + processed_tiles = 0 + # Size of individual weight matrix + n_size = N // G + n_start = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + n_start = n_size * g + + if m_size > 0: + # Process this group + + # Acquire hold on c_desc_ptr for TMA Store + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * n_size, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # columnwise + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (n_start + n_offset).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [global_n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c0, [m_offset, n_offset] + ) + c1 = acc1.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c1, 
[m_offset, n_offset + BLOCK_SIZE_N // 2] + ) + else: + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + # move to next tile in group + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_tma( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + a_scale_ptr, + b_scale_ptr, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_FP8: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. + For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # TMA Store prep + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_no_tma( + a_ptr, + b_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + 
BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. + For bc and Ampere, we never use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :] + b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :] + + for _ in range(0, K, BLOCK_SIZE_K): + # Load with bounds checking + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + + # Main matmul + accumulator += tl.dot(a, b.T) + + # Update pointers for next block + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + # Store without TMA + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + c = accumulator.to(c_dtype) + + tl.store( + c_ptr + + (M_start + offs_am[:, None]) * N # Row stride is N + + offs_bn[None, :], # Column offset + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +""" +Backward pass for grouped GEMM with Triton, where grouping is M*G +We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`). +""" + + +# ---- dx flat linear indexed ---- +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dx_tma( + grad_output_desc_ptr, # [MG, N] + w_desc_ptr, # [N, K] + grad_input_ptr, # output grad_x [MG, K] + workspace, # for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + TMA-optimized kernel for computing gradients with respect to input (dx). + For the forward pass Y = X @ W.T, the backward for input is: + grad_X = grad_Y @ W + + This maps to [MG, N] @ [N, K] -> [MG, K] + + Key differences from forward: + 1. W is used directly and not transposed + 2. The reduction dimension is now N (not K) + 3. 
Output is [M, K] instead of [M, N] + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = grad_input_ptr.dtype.element_ty + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups - same as forward + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + # tiles for this group - now producing [M, K] output + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles + + # TMA Store prep for [M, K] output + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=grad_input_ptr + M_start * K, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[m_size, K], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # Different tiling scheme for [M, K] output + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles + + # for grad_input block [M, K] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + # Position in full matrix + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + # reduce along N dimension (instead of K in forward) + for n_offset in range(0, N, BLOCK_SIZE_N): + # grad_output block [M, N] + grad_output = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # weight block [N, K] - no transpose needed + w = tl._experimental_descriptor_load( + w_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + # grad_x = grad_output @ w + # reducing along N dimension + accumulator += tl.dot(grad_output, w) + + # Store using TMA + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, k_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +# ---- dw flat linear indexed ---- + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dw_tma( + x_desc_ptr, # input descriptor [M_total, K] + grad_output_desc_ptr, # grad_output descriptor [M_total, N] + grad_weight_ptr, # output grad_w [N, K] + workspace, # workspace for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension +) -> None: + """ + Improved TMA-optimized kernel for computing gradients with respect to weights (dw). + Uses flat index structure similar to forward. 
+ + For the forward pass Y = X @ W.T, + the backward for weights is: + grad_W = grad_Y.T @ X + + Where: + - grad_Y is [MG, N] + - X is [MG, K] + - grad_W is [N, K] + - we return [N,K] + """ + # Get thread block index l + tbidx = tl.program_id(0) + + # Get output data type + c_dtype = grad_weight_ptr.dtype.element_ty + + # Calculate number of output tiles + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + total_output_tiles = num_n_tiles * num_k_tiles + + # Process tiles in strided manner across SMs + for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): + # Calculate tile indices + tile_n_idx = tile_idx % num_n_tiles + tile_k_idx = tile_idx // num_n_tiles + + # Calculate global offsets + n_offset = tile_n_idx * BLOCK_SIZE_N + k_offset = tile_k_idx * BLOCK_SIZE_K + + # Initialize accumulator for this output tile [N, K] + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + # Process each group + M_end = 0 + for g in range(G): + # Get group boundaries + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + # Only process if group is non-empty + if m_size > 0: + # Process this group in chunks along the M dimension + for m_offset in range(0, m_size, BLOCK_SIZE_M): + # Calculate actual block size (handling boundary) + m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + + # Only process if we have actual work to do + if m_block_size > 0: + # Global offset for this chunk + m_global_offset = M_start + m_offset + + if USE_TMA_LOAD: + # Load input chunk [M_chunk, K] using TMA + x_block = tl._experimental_descriptor_load( + x_desc_ptr, + [m_global_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + + # Load grad_output chunk [M_chunk, N] using TMA + grad_output_block = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_global_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # Apply masks for valid regions + offs_m = tl.arange(0, BLOCK_SIZE_M) + m_mask = offs_m < m_block_size + + # Zero out invalid elements + x_block = tl.where(m_mask[:, None], x_block, 0.0) + grad_output_block = tl.where( + m_mask[:, None], grad_output_block, 0.0 + ) + else: + # Manual load with bounds checking + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks + m_mask = offs_m < m_block_size + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + + # Combined masks + mk_mask = m_mask[:, None] & k_mask[None, :] + mn_mask = m_mask[:, None] & n_mask[None, :] + + # Global offsets for loading + m_global_offs = m_global_offset + offs_m + + # Load x block [M_chunk, K] + x_block = tl.load( + x_desc_ptr + + m_global_offs[:, None] * K + + (k_offset + offs_k)[None, :], + mask=mk_mask, + other=0.0, + ) + + # Load grad_output block [M_chunk, N] + grad_output_block = tl.load( + grad_output_desc_ptr + + m_global_offs[:, None] * N + + (n_offset + offs_n)[None, :], + mask=mn_mask, + other=0.0, + ) + + # Compute partial contribution: grad_W += grad_Y.T @ X + # transpose grad_output for the matmul + contribution = tl.dot( + grad_output_block.to(tl.float32).T, # [N, M_chunk] + x_block.to(tl.float32), # [M_chunk, K] + ) + + # Accumulate + accumulator += contribution + + # Store the result + if USE_TMA_STORE: + # Store using TMA + tl._experimental_descriptor_store( + workspace, # TMA store descriptor + accumulator.to(c_dtype), + [n_offset, k_offset], + ) + else: + # Manual store with bounds checking + offs_n = 
tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks for bounds checking + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + output_mask = n_mask[:, None] & k_mask[None, :] + + # Store the result + tl.store( + grad_weight_ptr + + (n_offset + offs_n)[:, None] * K + + (k_offset + offs_k)[None, :], + accumulator.to(c_dtype), + mask=output_mask, + ) + + +# ======== End Triton kernels ======== + +# ======== Triton wrapper functions ======== + +# ----- main forward pass wrapper ----- + + +def grouped_gemm_forward( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + tma_size: int = 128, + using_fp8: bool = False, +) -> torch.Tensor: + """ + M*G style grouped GEMM with TMA and Float8 support. + # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors. + + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Grouped GEMM without TMA is not supported yet") + + G = m_sizes.shape[0] + + assert x.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + # Total input size is now [M_total, K] where M_total is the sum of all group sizes + M_total, K = x.shape + N = w.shape[0] # N is now the same for all groups + + assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})" + + # Verify that all group sizes are multiples of ALIGN_SIZE_M + # This check is commented out because it will involve a GPU-CPU sync + # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M" + + # Create output tensor with correct shape [M_total, N] + y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype) + + if M_total == 0: + return y + + NUM_SMS = CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + USE_EPILOGUE_SUBTILING = False + + # TMA descriptor helper + desc_helper = None + desc_x = x + desc_w = w + workspace = None + + if USE_TMA_LOAD: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("w") + desc_x = desc_helper.get_tma_descriptor_kernel_param("x") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + if USE_TMA_STORE: + if desc_helper is None: + raise RuntimeError( + "TMA descriptors must be initialized when USE_TMA_STORE is True" + ) + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=x.device, + dtype=torch.uint8, + ) + + def grid(META): + if USE_TMA_LOAD: + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + _kernel_mg_forward_hopper[grid]( + desc_x, + desc_w, + y, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K, + NUM_SMS, + TMA_SIZE=tma_size, + USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING, + ) + + return y + + +# ======== Improved Backward ============= +def grouped_gemm_backward( + grad_output: torch.Tensor, + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_tma: bool = True, + tma_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Unified backward pass for grouped GeMM with M*G grouping. + Uses optimized TMA-based implementations for both dx and dw when available. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + x: Input tensor from forward pass, shape [M_total, K] + w: Weight tensor from forward pass, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + + + Returns: + Tuple of gradients with respect to x and w: (grad_x, grad_w) + """ + logging.info("Starting unified grouped_gemm_backward") + + # do this once, seems expensive + NUM_SMS = CudaUtils.get_num_sms() + + # Basic validation + M_total, K_x = x.shape + M_grad, N = grad_output.shape + N_w, K_w = w.shape + + # Check dimensions + if K_x != K_w: + raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}") + if M_total != M_grad: + raise ValueError( + f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}" + ) + + # Check total M matches sum of group sizes + sum_m_sizes = m_sizes.sum().item() + if M_total != sum_m_sizes: + raise ValueError( + f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + ) + + # Make sure inputs are contiguous + grad_output = grad_output.contiguous() + x = x.contiguous() + w = w.contiguous() + m_sizes = m_sizes.contiguous() + + # Check TMA support + if use_tma and not CudaUtils.verify_tma(): + logging.info("TMA requested but not supported on this device") + use_tma = False + + # Compute grad_x using flat linear implementation + try: + logging.info("Computing grad_x with flat linear kernel") + + # Use TMA-optimized implementation + grad_x = grouped_gemm_dx_tma( + grad_output=grad_output, + w=w, + m_sizes=m_sizes, + num_sms=NUM_SMS, + tma_size=tma_size, + ) + + except Exception as e: + logging.error(f"Error in grad_x computation: {e}") + raise + + # Compute grad_w using flat linear style implementation + try: + logging.info("Computing grad_w with flat linear kernel") + + grad_w = grouped_gemm_dw_tma( + x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size + ) + except Exception as e: + logging.error(f"Error in grad_w computation: {e}") + raise + + return grad_x, grad_w + + +# ----- dx backward pass wrapper ----- + + +def grouped_gemm_dx_tma( + grad_output: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized backward pass wrapper for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. + + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + # using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + """ + Optimized backward pass for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Optimized dx computation requires TMA support") + + G = m_sizes.shape[0] + + assert grad_output.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + M_total, N_grad = grad_output.shape + N_w, K = w.shape + + # Check dimensions + assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert M_total == sum_m_sizes, ( + f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + ) + + # Create output tensor (grad_x) with shape [M_total, K] + grad_x = torch.empty( + (M_total, K), device=grad_output.device, dtype=grad_output.dtype + ) + + NUM_SMS = num_sms # CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + + # Set up TMA descriptors + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("grad_output") + desc_helper.init_tma_descriptor("w") + desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + # Allocate workspace for TMA store + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=grad_output.device, + dtype=torch.uint8, + ) + + def grid(META): + # Fill TMA descriptors with appropriate dimensions + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N_grad, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N_w, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + # Launch the flat linear kernel for computing grad_x + _kernel_mg_dx_tma[grid]( + desc_grad_output, + desc_w, + grad_x, + workspace, + m_sizes, + G, + M_BUCKET, + N_grad, # N dimension is now the reduction dimension + K, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_x + + +# ======== dw wrapper function ========== + + +def grouped_gemm_dw_tma( + x: torch.Tensor, + grad_output: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA. 
+ For the forward pass Y = X @ W.T, the backward for weights is: + grad_W = grad_Y.T @ X + + Args: + x: Input tensor, shape [M_total, K] + grad_output: Gradient of output, shape [M_total, N] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor in bytes + + + Returns: + grad_w: Gradient with respect to weights, shape [N, K] + """ + # Check TMA support + if not CudaUtils.verify_tma(): + raise RuntimeError("TMA grouped GEMM requested on a device without TMA support") + + # Get group count + G = m_sizes.shape[0] + + # Ensure contiguous tensors + x = x.contiguous() + grad_output = grad_output.contiguous() + m_sizes = m_sizes.contiguous() + + # Get dimensions + M_total, K_x = x.shape + M_grad, N = grad_output.shape + + # Check dimensions + assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert sum_m_sizes == M_total, ( + f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + ) + + # Create output tensor (grad_w) with shape [N, K] + grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype) + + NUM_SMS = num_sms + + # TODO - hardcoded for now...but should set TMA flags based on hardware support + USE_TMA_LOAD = True + USE_TMA_STORE = True + + # Set up TMA descriptors or direct pointers + if USE_TMA_LOAD or USE_TMA_STORE: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + + if USE_TMA_LOAD: + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("grad_output") + x_desc = desc_helper.get_tma_descriptor_kernel_param("x") + grad_output_desc = desc_helper.get_tma_descriptor_kernel_param( + "grad_output" + ) + else: + x_desc = x + grad_output_desc = grad_output + + if USE_TMA_STORE: + desc_helper.init_tma_descriptor("grad_w") + workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w") + else: + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + else: + # If not using TMA, just use the tensors directly + x_desc = x + grad_output_desc = grad_output + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + + # M_BUCKET for grid size + M_BUCKET = triton.next_power_of_2(M_total) + + # Define grid for kernel launch + def grid(META): + if USE_TMA_LOAD or USE_TMA_STORE: + if USE_TMA_LOAD: + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K_x, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + if USE_TMA_STORE: + desc_helper.fill_2d_tma_descriptor( + "grad_w", + grad_w.data_ptr(), + N, + K_x, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + grad_w.element_size(), + ) + + # Return grid size - one block per SM for balanced work distribution + return (NUM_SMS,) + + # Launch the optimized kernel + _kernel_mg_dw_tma[grid]( + x_desc, + grad_output_desc, + grad_w, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K_x, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_w + + +# ======== End Backwards Wrapper Functions ============= + +# ======== PyTorch wrapper functions ======== + + +class GroupedGemmMg(torch.autograd.Function): + """ + Autograd function for GroupedGEMM with M*G grouping. + Supports both standard and FP8 quantized operations. 
+    """
+
+    @staticmethod
+    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
+        """
+        Forward pass of GroupedGEMM.
+
+        Args:
+            x: Input tensor, shape [M_total, K]
+            w: Weight tensor, shape [N, K]
+            m_sizes: Tensor of shape [G] containing the size of each group
+            use_tma: Whether to try using TMA acceleration (if available)
+            tma_size: Size of TMA descriptor in bytes
+            using_fp8: Whether to use FP8 quantization
+
+        Returns:
+            Output tensor, shape [M_total, N]
+        """
+
+        # Use regular forward without quantization
+        output = grouped_gemm_forward(
+            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
+        )
+
+        # Save inputs and parameters for backward pass
+        ctx.save_for_backward(x, w, m_sizes)
+        ctx.use_tma = use_tma
+        ctx.tma_size = tma_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        Backward pass of M*G GroupedGEMM.
+
+        Args:
+            grad_output: Gradient of output, shape [M_total, N]
+
+        Returns:
+            Tuple of gradients:
+            - grad_x: Gradient with respect to x, shape [M_total, K]
+            - grad_w: Gradient with respect to w, shape [N, K]
+            - None: Gradient with respect to m_sizes (not differentiable)
+            - None: Gradient with respect to use_tma (not differentiable)
+            - None: Gradient with respect to tma_size (not differentiable)
+            - None: Gradient with respect to using_fp8 (not differentiable)
+
+        """
+        # Retrieve saved tensors and parameters
+        x, w, m_sizes = ctx.saved_tensors
+
+        use_tma = ctx.use_tma
+        tma_size = ctx.tma_size
+
+        # Compute gradients using the unified implementation
+        grad_x, grad_w = grouped_gemm_backward(
+            grad_output=grad_output,
+            x=x,
+            w=w,
+            m_sizes=m_sizes,
+            use_tma=use_tma,
+            tma_size=tma_size,
+        )
+
+        # Return one gradient per forward input (None for non-differentiable parameters)
+        return grad_x, grad_w, None, None, None, None
+
+
+def mg_grouped_gemm(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    use_tma: bool = True,
+    tma_size: int = 128,
+    using_fp8: bool = False,
+) -> torch.Tensor:
+    """
+    Unified differentiable grouped GEMM operation for M*G grouped GEMM.
+    Supports both standard precision and FP8 quantized operations.
+
+    Args:
+        x: Input tensor, shape [M_total, K]
+        w: Weight tensor, shape [N, K]
+        m_sizes: Tensor of shape [G] containing the size of each group
+        use_tma: Whether to try using TMA acceleration (if available)
+        tma_size: Size of TMA descriptor in bytes
+        using_fp8: Whether to use FP8 quantization
+
+    Returns:
+        Output tensor, shape [M_total, N]
+    """
+    return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py b/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
new file mode 100644
index 000000000..2105ba518
--- /dev/null
+++ b/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# credit - TMAHelper class, AutoTuning are derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe + +import os +import sys +from typing import Dict + +import torch +import triton +import triton.language as tl +from triton.runtime import driver # @manual + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + +# ===== Supporting utils, CUDA and TMA ===== + + +class CudaUtils: + @staticmethod + def is_cuda() -> bool: + """Check if Triton is running on CUDA backend.""" + return driver.active.get_current_target().backend == "cuda" + + @staticmethod + def verify_tma() -> bool: + """Check if TMA is supported on the current device.""" + return ( + CudaUtils.is_cuda() + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9 + ) + + @staticmethod + def get_num_sms() -> int: + """Get the number of streaming multiprocessors on the current device.""" + if not CudaUtils.is_cuda(): + raise RuntimeError("Triton is not running on CUDA backend") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +class TmaDescriptorHelper: + """Helper class for managing TMA descriptors in Triton kernels. + + Args: + tma_size: Size of the TMA descriptor in bytes + """ + + class KernelParamWrapper: + """Wrapper to implement the TmaDescKernelParam interface.""" + + def __init__(self, desc: torch.Tensor): + self.desc = desc + + def tma_desc_cpu_ptr(self) -> int: + """Return the CPU pointer to the TMA descriptor.""" + return self.desc.data_ptr() + + def __init__(self, tma_size: int = 128): + if not CudaUtils.verify_tma(): + raise RuntimeError( + "TMA not supported on this device (requires Hopper or newer)" + ) + if "nv_tma_desc_type" not in dir(tl): + raise RuntimeError( + "TMA grid constant descriptors not supported in your Triton version" + ) + + self.tma_size = tma_size + self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor + self.descriptors: Dict[str, torch.Tensor] = {} + + def init_tma_descriptor(self, name: str) -> None: + """Initialize a TMA descriptor with the given name. + + Call this method outside of the lambda function for grid size. + """ + self.descriptors[name] = torch.empty( + self.tma_size, device="cpu", dtype=torch.int8 + ) + + def fill_1d_tma_descriptor( + self, name: str, ptr: int, dim: int, block_dim: int, element_size: int + ) -> None: + """Fill a 1D TMA descriptor. + + Call this method inside the lambda function for grid size. + """ + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + + def fill_2d_tma_descriptor( + self, + name: str, + ptr: int, + dim1: int, + dim0: int, + block_dim1: int, + block_dim0: int, + element_size: int, + ) -> None: + """Fill a 2D TMA descriptor. + + Call this method inside the lambda function for grid size. 
+ """ + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + + def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper: + """Get the TMA descriptor kernel parameter for the given name.""" + if name not in self.descriptors or self.descriptors[name] is None: + raise ValueError(f"TMA descriptor '{name}' not initialized") + return self.KernelParamWrapper(self.descriptors[name]) + + +# ====== Autotuning utilities ====== +ALIGN_SIZE_M = 128 + +_NV_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + }, + num_stages=num_stages, + num_warps=num_warps, + num_ctas=num_ctas, + ) + for block_size_m in [ + ALIGN_SIZE_M, + ] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] + for num_ctas in [1] +] + + +def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): + device = torch.cuda.current_device() + # Check for all possible pointer parameter names + if "grad_input_ptr" in named_args: + ptr_name = "grad_input_ptr" + elif "c_ptr" in named_args: + ptr_name = "c_ptr" + elif "grad_weight_ptr" in named_args: + ptr_name = "grad_weight_ptr" + else: + raise KeyError("No recognized pointer parameter found in kernel arguments") + + if dtsize is None: + dtsize = named_args[ptr_name].element_size() + if dtype is None: + dtype = named_args[ptr_name].dtype + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_SIZE_M"], + kw["BLOCK_SIZE_N"], + kw["BLOCK_SIZE_K"], + config.num_stages, + ) + G, M, N, K = ( + named_args["G"], + named_args["M_BUCKET"], + named_args["N"], + named_args["K"], + ) + + # 1. make sure we have enough smem + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + M_PER_GROUP = M // G + MIN_M_TILES = 64 + # 2. make sure we don't load M tiles that are too big + if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2): + continue + # 3. make sure we don't load N tiles that are too small + if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): + continue + + num_sm = driver.active.utils.get_device_properties(device)[ + "multiprocessor_count" + ] + N_TILES = N // BLOCK_N + MIN_N_TILES = 64 + # 4. make sure we don't load N tiles that are too big + if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm: + continue + # 5. make sure we don't load N tiles that are too small + if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: + continue + # 6. 
make sure K can be evenly divided + if K % BLOCK_K != 0: + continue + + pruned_configs.append(config) + + return pruned_configs + + +# ======== End Autotuning utilities ======== diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 446c73640..99502a30f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -193,7 +193,7 @@ class PatchManager: if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3": from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe - patch_deepseek_v3_moe() + patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend) elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels: LOG.info( "Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`" diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py index 94791a1d6..eceba8e7b 100644 --- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py +++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py @@ -8,6 +8,7 @@ import torch from axolotl.kernels.moe import ContiguousGroupedGEMM from axolotl.kernels.moe.indices import generate_permute_indices +from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm from axolotl.utils.logging import get_logger _GROUP_SIZE_M = 128 @@ -30,28 +31,36 @@ def _is_triton_eligible(hidden_states: torch.Tensor) -> bool: def _ensure_combined_expert_weights( - module, dtype: torch.dtype, device: torch.device + module, dtype: torch.dtype, device: torch.device, backend: str ) -> None: if not hasattr(module, "_axolotl_original_specs"): module._axolotl_original_specs = {} - if getattr(module, "_axolotl_combined_weights", False): - # Move cached combined weights to the working dtype/device if required. 
- for name in _COMBINED_SUBMODULES: - param_name = f"{name}_weight" - param = module.get_parameter(param_name) - if param.device != device or param.dtype != dtype: - module._parameters[param_name] = torch.nn.Parameter( - param.to(device=device, dtype=dtype).contiguous() - ) - module._axolotl_combined_dtype = dtype - module._axolotl_combined_device = device - return + if not hasattr(module, "_axolotl_mg_shapes"): + module._axolotl_mg_shapes = {} - combined = {} + prev_backend = getattr(module, "_axolotl_combined_backend", None) + if getattr(module, "_axolotl_combined_weights", False): + if prev_backend != backend: + _restore_expert_weights(module) + else: + for name in _COMBINED_SUBMODULES: + param_name = f"{name}_weight" + param = module.get_parameter(param_name) + if param.device != device or param.dtype != dtype: + module._parameters[param_name] = torch.nn.Parameter( + param.to(device=device, dtype=dtype).contiguous() + ) + module._axolotl_combined_dtype = dtype + module._axolotl_combined_device = device + module._axolotl_combined_backend = backend + return + + module._axolotl_mg_shapes = {} for name in _COMBINED_SUBMODULES: weights = [] orig_device = None orig_dtype = None + orig_shape = None for expert in module.experts: lin = expert.get_submodule(name) weight_param = lin._parameters.get("weight") @@ -60,19 +69,24 @@ def _ensure_combined_expert_weights( if orig_device is None: orig_device = weight_param.device orig_dtype = weight_param.dtype + orig_shape = tuple(weight_param.shape) weights.append(weight_param.detach().to(device=device, dtype=dtype)) if "weight" in lin._parameters: del lin._parameters["weight"] if "bias" in lin._parameters: - # DeepseekV3 MLP layers are bias-free, but keep this for safety. del lin._parameters["bias"] - combined[name] = torch.stack(weights, dim=0).contiguous() - module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name])) - module._axolotl_original_specs[name] = (orig_device, orig_dtype) + if backend == "cg": + combined_weight = torch.stack(weights, dim=0).contiguous() + else: + combined_weight = torch.cat(weights, dim=0).contiguous() + module._axolotl_mg_shapes[name] = orig_shape + module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight)) + module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape) module._axolotl_combined_weights = True module._axolotl_combined_dtype = dtype module._axolotl_combined_device = device + module._axolotl_combined_backend = backend def _restore_expert_weights(module) -> None: @@ -82,19 +96,111 @@ def _restore_expert_weights(module) -> None: for name in _COMBINED_SUBMODULES: param_name = f"{name}_weight" combined = module._parameters.pop(param_name) - orig_device, orig_dtype = module._axolotl_original_specs.get( - name, (combined.device, combined.dtype) + orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get( + name, (combined.device, combined.dtype, None) ) + rows_per = orig_shape[0] if orig_shape else None for idx, expert in enumerate(module.experts): lin = expert.get_submodule(name) + if combined.dim() == 3: + slice_tensor = combined[idx] + elif rows_per is not None: + start = idx * rows_per + end = start + rows_per + slice_tensor = combined[start:end] + else: + raise RuntimeError( + "Unable to recover expert weight shape during restore" + ) lin._parameters["weight"] = torch.nn.Parameter( - combined[idx].detach().clone().to(orig_device, dtype=orig_dtype) + slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype) ) 
module._axolotl_combined_weights = False
     module._axolotl_combined_dtype = None
     module._axolotl_combined_device = None
+    module._axolotl_combined_backend = None
     module._axolotl_original_specs = {}
+    module._axolotl_mg_shapes = {}
+
+
+def _run_cg_grouped_gemm(
+    module,
+    grouped_hidden: torch.Tensor,
+    m_sizes: torch.Tensor,
+    num_experts: int,
+    group_size_m: int,
+    hidden_dtype: torch.dtype,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
+
+    expert_index_tensor = torch.repeat_interleave(
+        torch.arange(num_experts, device=device, dtype=torch.int32),
+        m_sizes.to(torch.int64),
+    )
+
+    gate_weights = module.get_parameter("gate_proj_weight")
+    if gate_weights.dim() == 2:
+        out_dim = gate_weights.shape[0] // num_experts
+        gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
+
+    up_weights = module.get_parameter("up_proj_weight")
+    if up_weights.dim() == 2:
+        out_dim = up_weights.shape[0] // num_experts
+        up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
+
+    down_weights = module.get_parameter("down_proj_weight")
+    if down_weights.dim() == 2:
+        out_dim = down_weights.shape[0] // num_experts
+        down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
+
+    gate_out = ContiguousGroupedGEMM.apply(
+        grouped_hidden,
+        gate_weights,
+        expert_index_tensor,
+        group_size_m,
+    )
+    up_out = ContiguousGroupedGEMM.apply(
+        grouped_hidden,
+        up_weights,
+        expert_index_tensor,
+        group_size_m,
+    )
+    down_out = ContiguousGroupedGEMM.apply(
+        grouped_hidden,
+        down_weights,
+        expert_index_tensor,
+        group_size_m,
+    )
+
+    return (
+        gate_out.to(hidden_dtype),
+        up_out.to(hidden_dtype),
+        down_out.to(hidden_dtype),
+    )
 
 
 def _moe_triton_forward(
@@ -103,6 +209,7 @@ def _moe_triton_forward(
     topk_indices: torch.Tensor,
     topk_weights: torch.Tensor,
     group_size_m: int,
+    backend: str,
     fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
 ) -> torch.Tensor:
     if not _is_triton_eligible(hidden_states):
@@ -146,7 +253,7 @@
     ) * group_size_m
     max_len = int(aligned_counts.sum().item())
 
-    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
+    permuted_indices, m_sizes, _ = generate_permute_indices(
        counts_int.to(device),
         experts_per_rank=num_experts,
         num_ranks=1,
@@ -155,12 +262,8 @@
         use_cpu=not hidden_states.is_cuda,
     )
 
-    if permuted_indices.device != device:
-        permuted_indices = permuted_indices.to(device)
-    if m_sizes.device != device:
-        m_sizes = m_sizes.to(device)
-    if m_offsets.device != device:
-        m_offsets = m_offsets.to(device)
+    permuted_indices = permuted_indices.to(device)
+    m_sizes = m_sizes.to(device)
 
     permuted_indices_long = permuted_indices.to(torch.int64)
     valid_mask = permuted_indices_long >= 0
@@ -178,34 +281,35 @@
     if valid_positions.numel() < max_len:
         grouped_hidden.index_fill_(0, padded_positions, 0)
 
-    expert_index_tensor = torch.repeat_interleave(
-        torch.arange(num_experts, device=device, 
dtype=torch.int32), - m_sizes.to(torch.int64), - ) + m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32) - _ensure_combined_expert_weights(module, hidden_dtype, device) - - gate_weights = module.get_parameter("gate_proj_weight") - up_weights = module.get_parameter("up_proj_weight") - down_weights = module.get_parameter("down_proj_weight") - - gate_out = ContiguousGroupedGEMM.apply( - grouped_hidden, - gate_weights, - expert_index_tensor, - group_size_m, - ) - up_out = ContiguousGroupedGEMM.apply( - grouped_hidden, - up_weights, - expert_index_tensor, - group_size_m, - ) + if backend == "mg": + _ensure_combined_expert_weights(module, hidden_dtype, device, backend) + gate_out = mg_grouped_gemm( + grouped_hidden, + module.get_parameter("gate_proj_weight"), + m_sizes_tensor, + ).to(hidden_dtype) + up_out = mg_grouped_gemm( + grouped_hidden, + module.get_parameter("up_proj_weight"), + m_sizes_tensor, + ).to(hidden_dtype) + else: + gate_out, up_out, down_out_cg = _run_cg_grouped_gemm( + module, + grouped_hidden, + m_sizes, + num_experts, + group_size_m, + hidden_dtype, + device, + ) act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn if valid_positions.numel() > 0: - gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype) - up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype) + gate_valid = gate_out.index_select(0, valid_positions) + up_valid = up_out.index_select(0, valid_positions) hidden_concat = act_fn(gate_valid) * up_valid else: hidden_concat = torch.empty( @@ -219,15 +323,17 @@ def _moe_triton_forward( if valid_positions.numel() < max_len: hidden_grouped.index_fill_(0, padded_positions, 0) - down_out = ContiguousGroupedGEMM.apply( - hidden_grouped, - down_weights, - expert_index_tensor, - group_size_m, - ) + if backend == "mg": + down_out = mg_grouped_gemm( + hidden_grouped, + module.get_parameter("down_proj_weight"), + m_sizes_tensor, + ).to(hidden_dtype) + else: + down_out = down_out_cg if valid_positions.numel() > 0: - down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype) + down_valid = down_out.index_select(0, valid_positions) else: down_valid = torch.empty( (0, down_out.shape[-1]), device=device, dtype=hidden_dtype @@ -245,11 +351,16 @@ def _moe_triton_forward( return weighted.sum(dim=1) -def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: +def patch_deepseek_v3_moe( + group_size_m: int = _GROUP_SIZE_M, backend: str = "mg" +) -> None: """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.""" from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE + if backend not in {"cg", "mg"}: + raise ValueError(f"Unsupported MoE kernel backend: {backend}") + # Record the unpatched implementation so callers can access a true baseline even # after the Triton patch has been applied (e.g. repeated microbenchmarks). 
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"): @@ -259,26 +370,26 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: return original_moe = DeepseekV3MoE._axolotl_triton_original_moe + DeepseekV3MoE._axolotl_triton_backend = backend + DeepseekV3MoE._axolotl_group_size_m = group_size_m def patched_moe(self, hidden_states, topk_indices, topk_weights): + backend_sel = getattr(self, "_axolotl_triton_backend", backend) + group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m) try: return _moe_triton_forward( self, hidden_states, topk_indices, topk_weights, - group_size_m, + group_size_sel, + backend_sel, original_moe, ) - except Exception as err: # fall back if Triton compilation or runtime fails - if not getattr(self, "_axolotl_triton_warned", False): - LOG.warning( - "DeepseekV3MoE Triton path failed; falling back to baseline: %s", - err, - ) - self._axolotl_triton_warned = True + except Exception as err: # surface Triton failures explicitly _restore_expert_weights(self) - return original_moe(self, hidden_states, topk_indices, topk_weights) + LOG.error("DeepseekV3MoE Triton path failed: %s", err) + raise DeepseekV3MoE.moe = patched_moe DeepseekV3MoE._axolotl_triton_patch = True diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index b2ff154f3..1801a0b8b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -119,6 +119,12 @@ class AxolotlInputConfig( "description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)" }, ) + moe_kernel_backend: Literal["cg", "mg"] | None = Field( + default="mg", + json_schema_extra={ + "description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel." + }, + ) trainer_cls: str | None = Field( default=None,
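
Usage note (editor's addition, not part of the patch): with `moe_kernels: true` and `moe_kernel_backend: mg`, the patched `DeepseekV3MoE.moe` concatenates each expert's `gate_proj`/`up_proj`/`down_proj` weights along the output dimension and calls `grouped_gemm_forward(x, w, m_sizes)`, where `x` is `[M_total, K]`, `w` is the combined `[G * N_g, K]` weight, and `m_sizes` holds per-expert token counts. The sketch below is an eager-mode reference of that contract only; the helper name `mg_grouped_gemm_reference` and the toy shapes are illustrative assumptions, and the actual Triton kernel additionally requires a Hopper-class GPU with TMA support.

```python
import torch


def mg_grouped_gemm_reference(
    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
) -> torch.Tensor:
    """Eager reference for the M*G grouped GEMM: per group g, Y_g = X_g @ W_g.T."""
    num_groups = m_sizes.shape[0]
    n_per_group = w.shape[0] // num_groups  # rows of each expert's weight slice
    out = x.new_empty((x.shape[0], n_per_group))
    row = 0
    for g in range(num_groups):
        rows = int(m_sizes[g])
        w_g = w[g * n_per_group : (g + 1) * n_per_group]  # this group's [N_g, K] slice
        out[row : row + rows] = x[row : row + rows] @ w_g.T
        row += rows
    return out


# Toy shapes: 4 experts, 32 tokens routed to each, K=64, per-expert output dim 128.
m_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32)
x = torch.randn(int(m_sizes.sum()), 64)
w = torch.randn(4 * 128, 64)
print(mg_grouped_gemm_reference(x, w, m_sizes).shape)  # torch.Size([128, 128])
```

On supported hardware, the benchmark scripts above compare the patched path against the unpatched baseline and report accuracy within `ACCURACY_TOLERANCE`.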