allocator

This commit is contained in:
Dan Saunders
2025-09-23 18:20:57 -04:00
parent d578c53603
commit 91393c4dc8

View File

@@ -26,6 +26,21 @@ logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
_allocator_registered = False
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
return torch.empty(size, device="cuda", dtype=torch.int8)
def _ensure_triton_allocator() -> None:
global _allocator_registered
if not _allocator_registered:
triton.set_allocator(_torch_allocator)
_allocator_registered = True
# ============== Start Triton Kernels ===============
@@ -402,6 +417,7 @@ def grouped_gemm_forward(
using_fp8: bool = False,
) -> torch.Tensor:
"""Grouped GEMM forward using Hopper TMA kernels."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
if using_fp8:
@@ -549,6 +565,7 @@ def grouped_gemm_dx_tma(
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_x using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Optimized dx computation requires TMA support")
@@ -600,6 +617,7 @@ def grouped_gemm_dw_tma(
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_w using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")