From 91393c4dc8315da2d2d7a6e655b75bafef575f80 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 23 Sep 2025 18:20:57 -0400
Subject: [PATCH] allocator

---
 .../kernels/moe/tt_mg_gemm/mg_grouped_gemm.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
index 6149dfc72..e7ade61b6 100644
--- a/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
+++ b/src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
@@ -26,6 +26,21 @@ logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 
+
+_allocator_registered = False
+
+
+def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
+    return torch.empty(size, device="cuda", dtype=torch.int8)
+
+
+def _ensure_triton_allocator() -> None:
+    global _allocator_registered
+    if not _allocator_registered:
+        triton.set_allocator(_torch_allocator)
+        _allocator_registered = True
+
+
 # ============== Start Triton Kernels ===============
 
 
@@ -402,6 +417,7 @@ def grouped_gemm_forward(
     using_fp8: bool = False,
 ) -> torch.Tensor:
     """Grouped GEMM forward using Hopper TMA kernels."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
     if using_fp8:
@@ -549,6 +565,7 @@ def grouped_gemm_dx_tma(
     tma_size: int = 128,
 ) -> torch.Tensor:
     """Compute grad_x using the Hopper grouped GEMM kernel."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise NotImplementedError("Optimized dx computation requires TMA support")
 
@@ -600,6 +617,7 @@ def grouped_gemm_dw_tma(
     tma_size: int = 128,
 ) -> torch.Tensor:
     """Compute grad_w using the Hopper grouped GEMM kernel."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
 
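Note (reviewer sketch, not part of the patch): recent Triton versions require a host-side allocator, registered via triton.set_allocator, whenever a kernel needs global scratch memory; the most common trigger is a device-side TMA tensor descriptor built with tl.make_tensor_descriptor, and launching such a kernel without an allocator raises a runtime error. The patch installs a torch-backed allocator lazily, at the top of each entry point, rather than at import time. Below is a minimal standalone sketch of the same pattern, assuming a Triton build with tensor-descriptor support and a CUDA device; the copy kernel and every name in it are illustrative and not taken from mg_grouped_gemm.py.

    import torch
    import triton
    import triton.language as tl


    def alloc_fn(size: int, alignment: int, stream):
        # Triton calls this hook to obtain scratch buffers; torch.empty
        # allocations are aligned well beyond the alignments Triton requests.
        return torch.empty(size, device="cuda", dtype=torch.int8)


    triton.set_allocator(alloc_fn)


    @triton.jit
    def copy_kernel(in_ptr, out_ptr, M, N,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # Building descriptors inside the kernel is what makes Triton request
        # scratch memory from the registered allocator.
        in_desc = tl.make_tensor_descriptor(
            in_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
        )
        out_desc = tl.make_tensor_descriptor(
            out_ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N]
        )
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        tile = in_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
        out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


    x = torch.randn(256, 256, device="cuda")
    y = torch.empty_like(x)
    copy_kernel[(4, 4)](x, y, 256, 256, BLOCK_M=64, BLOCK_N=64)
    assert torch.equal(x, y)

The once-only guard in the patch matters because set_allocator installs a process-global hook: re-registering on every call would be harmless but redundant, and deferring registration to the first kernel launch keeps module import free of side effects.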