Dan Saunders
2025-09-23 12:14:55 -04:00
parent 1640cd4006
commit ab8fa56b16
2 changed files with 13 additions and 4 deletions


@@ -69,10 +69,10 @@ def _kernel_cg_backward_dx(
     mask_w = mask_n[:, None] & mask_k[None, :]
     go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
-    go = tl.load(go_ptrs, mask=mask_go, other=0.0)
+    go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
     w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
-    w = tl.load(w_ptrs, mask=mask_w, other=0.0)
+    w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
     grad_input += tl.dot(go, w)
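The only change in this kernel is widening the two tile loads to float32 before the dot product, so the operands of tl.dot match the float32 accumulator. A minimal, self-contained Triton sketch of the same pattern; the kernel name, pointer arguments, and block shapes below are illustrative, not the commit's kernel:

import triton
import triton.language as tl


@triton.jit
def _dot_fp32_sketch(
    a_ptr, b_ptr, c_ptr, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Single-tile sketch: assumes M, N and K each fit inside one block.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    mask_a = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    mask_b = (offs_k[:, None] < K) & (offs_n[None, :] < N)
    # Load the (possibly bf16/fp16) tiles and widen them to float32 so both
    # operands of tl.dot share the accumulator dtype, as in the hunk above.
    a = tl.load(
        a_ptr + offs_m[:, None] * K + offs_k[None, :], mask=mask_a, other=0.0
    ).to(tl.float32)
    b = tl.load(
        b_ptr + offs_k[:, None] * N + offs_n[None, :], mask=mask_b, other=0.0
    ).to(tl.float32)
    acc = tl.dot(a, b)
    mask_c = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc, mask=mask_c)

Launched as _dot_fp32_sketch[(1,)](a, b, c, M, N, K, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16) with bf16 inputs a, b and a float32 output buffer c, the products accumulate in float32 even though the tensors in memory stay half precision.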


@@ -17,7 +17,16 @@ LOG = get_logger(__name__)
 def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
-    return hidden_states.is_cuda and hidden_states.shape[0] > 0
+    if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
+        return False
+    major, _ = torch.cuda.get_device_capability(hidden_states.device)
+    if major < 9:
+        LOG.debug(
+            "Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
+            major,
+        )
+        return False
+    return True
 def _ensure_combined_expert_weights(
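The new gate only admits sm_90-class (Hopper or newer) GPUs, and the hunk that follows widens the fallback from RuntimeError to any Exception behind a one-shot warning flag. A self-contained sketch of how the two pieces combine; GuardedForward, fast_path and baseline are illustrative stand-ins, not names from the repo:

import logging
from typing import Callable

import torch

LOG = logging.getLogger(__name__)


def is_triton_eligible(hidden_states: torch.Tensor) -> bool:
    # Empty batches and CPU tensors never take the Triton path.
    if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
        return False
    # get_device_capability returns (major, minor); major >= 9 means sm_90 or newer.
    major, _ = torch.cuda.get_device_capability(hidden_states.device)
    return major >= 9


class GuardedForward:
    """Call a fast path when eligible, fall back to the baseline on any failure."""

    def __init__(self, fast_path: Callable, baseline: Callable):
        self.fast_path = fast_path
        self.baseline = baseline
        self._warned = False  # warn only once, like _axolotl_triton_warned below

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if is_triton_eligible(hidden_states):
            try:
                return self.fast_path(hidden_states)
            except Exception as err:  # Triton compile or runtime failure
                if not self._warned:
                    LOG.warning("Fast path failed; falling back to baseline: %s", err)
                    self._warned = True
        return self.baseline(hidden_states)

Catching Exception rather than RuntimeError also covers Triton's own compilation errors, which are not guaranteed to be RuntimeError subclasses, and the one-shot flag keeps a persistent failure from warning on every forward pass.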
@@ -261,7 +270,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
                 group_size_m,
                 original_moe,
             )
-        except RuntimeError as err:
+        except Exception as err:  # fall back if Triton compilation or runtime fails
             if not getattr(self, "_axolotl_triton_warned", False):
                 LOG.warning(
                     "DeepseekV3MoE Triton path failed; falling back to baseline: %s",