allocator
This commit is contained in:
@@ -26,6 +26,21 @@ logging.basicConfig(
|
|||||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Guard flag: set to True by _ensure_triton_allocator() after it has called
# triton.set_allocator(), so that function registers the allocator only once.
_allocator_registered = False
|
||||||
|
|
||||||
|
|
||||||
|
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
    """Return an uninitialized int8 tensor with ``size`` elements on CUDA.

    Passed to ``triton.set_allocator`` by ``_ensure_triton_allocator``.
    ``alignment`` and ``stream`` are accepted as part of the callback
    signature but are not used by this implementation.
    """
    buffer = torch.empty(size, dtype=torch.int8, device="cuda")
    return buffer
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_triton_allocator() -> None:
    """Register ``_torch_allocator`` with Triton on the first call only.

    The module-level ``_allocator_registered`` flag records that
    ``triton.set_allocator`` has already been invoked; once it is set,
    further calls return immediately without touching Triton.
    """
    global _allocator_registered
    if _allocator_registered:
        return
    triton.set_allocator(_torch_allocator)
    # The flag is flipped only after set_allocator returns, so a call that
    # raised leaves it False and the next invocation tries again.
    _allocator_registered = True
|
||||||
|
|
||||||
|
|
||||||
# ============== Start Triton Kernels ===============
|
# ============== Start Triton Kernels ===============
|
||||||
|
|
||||||
|
|
||||||
@@ -402,6 +417,7 @@ def grouped_gemm_forward(
|
|||||||
using_fp8: bool = False,
|
using_fp8: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Grouped GEMM forward using Hopper TMA kernels."""
|
"""Grouped GEMM forward using Hopper TMA kernels."""
|
||||||
|
_ensure_triton_allocator()
|
||||||
if not CudaUtils.verify_tma():
|
if not CudaUtils.verify_tma():
|
||||||
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
||||||
if using_fp8:
|
if using_fp8:
|
||||||
@@ -549,6 +565,7 @@ def grouped_gemm_dx_tma(
|
|||||||
tma_size: int = 128,
|
tma_size: int = 128,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute grad_x using the Hopper grouped GEMM kernel."""
|
"""Compute grad_x using the Hopper grouped GEMM kernel."""
|
||||||
|
_ensure_triton_allocator()
|
||||||
if not CudaUtils.verify_tma():
|
if not CudaUtils.verify_tma():
|
||||||
raise NotImplementedError("Optimized dx computation requires TMA support")
|
raise NotImplementedError("Optimized dx computation requires TMA support")
|
||||||
|
|
||||||
@@ -600,6 +617,7 @@ def grouped_gemm_dw_tma(
|
|||||||
tma_size: int = 128,
|
tma_size: int = 128,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute grad_w using the Hopper grouped GEMM kernel."""
|
"""Compute grad_w using the Hopper grouped GEMM kernel."""
|
||||||
|
_ensure_triton_allocator()
|
||||||
if not CudaUtils.verify_tma():
|
if not CudaUtils.verify_tma():
|
||||||
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
|
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user