diff --git a/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py b/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py index 2105ba518..f77ecbfeb 100644 --- a/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py +++ b/src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py @@ -77,8 +77,8 @@ class TmaDescriptorHelper: ) self.tma_size = tma_size - self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor - self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor + self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor + self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor self.descriptors: Dict[str, torch.Tensor] = {} def init_tma_descriptor(self, name: str) -> None: