allocator: register a Triton allocator for the TMA grouped GEMM kernels
@@ -26,6 +26,21 @@ logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 )
 
+
+_allocator_registered = False
+
+
+def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
+    return torch.empty(size, device="cuda", dtype=torch.int8)
+
+
+def _ensure_triton_allocator() -> None:
+    global _allocator_registered
+    if not _allocator_registered:
+        triton.set_allocator(_torch_allocator)
+        _allocator_registered = True
+
+
 # ============== Start Triton Kernels ===============
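For context, Triton's device-side TMA descriptors draw scratch memory from the hook registered via triton.set_allocator: the callback receives the requested size in bytes, the required alignment, and the target stream, and must return an object exposing data_ptr(). Below is a minimal sketch of that contract, assuming Triton's documented allocator API (demo_allocator is a hypothetical name):

import torch
import triton

def demo_allocator(size: int, alignment: int, stream) -> torch.Tensor:
    # Triton invokes this hook when a kernel needs device scratch
    # memory (e.g. for on-device TMA descriptors) and reads the
    # returned buffer through data_ptr().
    buf = torch.empty(size, device="cuda", dtype=torch.int8)
    # PyTorch's caching allocator hands back generously aligned
    # pointers, so the requested alignment is normally satisfied.
    if alignment:
        assert buf.data_ptr() % alignment == 0
    return buf

triton.set_allocator(demo_allocator)

The module-level _allocator_registered flag keeps registration lazy and idempotent: importing the module never touches CUDA, and repeated entry-point calls do not re-register. Note that triton.set_allocator installs a process-wide hook, so it would replace any allocator another component had registered.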
@@ -402,6 +417,7 @@ def grouped_gemm_forward(
     using_fp8: bool = False,
 ) -> torch.Tensor:
     """Grouped GEMM forward using Hopper TMA kernels."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
     if using_fp8:
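CudaUtils.verify_tma() is this repository's own helper. As a rough sketch of what such a guard typically checks (an assumption about, not a copy of, the actual implementation): TMA is a Hopper feature, so the test reduces to a compute-capability query.

import torch

def verify_tma_sketch() -> bool:
    # Hypothetical stand-in for CudaUtils.verify_tma(). The Tensor
    # Memory Accelerator (TMA) shipped with Hopper, i.e. compute
    # capability 9.0, so the guard is a capability check.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9

Raising NotImplementedError up front gives a clear failure on pre-Hopper devices instead of a kernel-launch error deep inside Triton.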
@@ -549,6 +565,7 @@ def grouped_gemm_dx_tma(
     tma_size: int = 128,
 ) -> torch.Tensor:
     """Compute grad_x using the Hopper grouped GEMM kernel."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise NotImplementedError("Optimized dx computation requires TMA support")
 
@@ -600,6 +617,7 @@ def grouped_gemm_dw_tma(
     tma_size: int = 128,
 ) -> torch.Tensor:
     """Compute grad_w using the Hopper grouped GEMM kernel."""
+    _ensure_triton_allocator()
     if not CudaUtils.verify_tma():
         raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
 
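Taken together, all three entry points now share the same prologue; a condensed sketch of the pattern (the signature and body here are elided placeholders, not part of the diff):

def grouped_gemm_entry_point(*args, **kwargs):
    # 1. Make sure Triton can allocate device scratch before any
    #    TMA kernel is compiled or launched.
    _ensure_triton_allocator()
    # 2. Fail fast on devices without TMA support.
    if not CudaUtils.verify_tma():
        raise NotImplementedError("TMA support required")
    # 3. Dispatch to the Hopper grouped GEMM kernel.
    ...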