diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
index eec527ef9..514211d50 100644
--- a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
+++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
@@ -69,10 +69,10 @@ def _kernel_cg_backward_dx(
     mask_w = mask_n[:, None] & mask_k[None, :]
 
     go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
-    go = tl.load(go_ptrs, mask=mask_go, other=0.0)
+    go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
 
     w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
-    w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+    w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
 
     grad_input += tl.dot(go, w)
 
diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
index d613a49f6..94791a1d6 100644
--- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py
+++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
@@ -17,7 +17,16 @@ LOG = get_logger(__name__)
 
 
 def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
-    return hidden_states.is_cuda and hidden_states.shape[0] > 0
+    if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
+        return False
+    major, _ = torch.cuda.get_device_capability(hidden_states.device)
+    if major < 9:
+        LOG.debug(
+            "Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
+            major,
+        )
+        return False
+    return True
 
 
 def _ensure_combined_expert_weights(
@@ -261,7 +270,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
                 group_size_m,
                 original_moe,
             )
-        except RuntimeError as err:
+        except Exception as err:  # fall back if Triton compilation or runtime fails
            if not getattr(self, "_axolotl_triton_warned", False):
                LOG.warning(
                    "DeepseekV3MoE Triton path failed; falling back to baseline: %s",