dtype
This commit is contained in:
@@ -69,10 +69,10 @@ def _kernel_cg_backward_dx(
|
||||
mask_w = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0)
|
||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
|
||||
|
||||
grad_input += tl.dot(go, w)
|
||||
|
||||
|
||||
@@ -17,7 +17,16 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
||||
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability(hidden_states.device)
|
||||
if major < 9:
|
||||
LOG.debug(
|
||||
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
|
||||
major,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _ensure_combined_expert_weights(
|
||||
@@ -261,7 +270,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
group_size_m,
|
||||
original_moe,
|
||||
)
|
||||
except RuntimeError as err:
|
||||
except Exception as err: # fall back if Triton compilation or runtime fails
|
||||
if not getattr(self, "_axolotl_triton_warned", False):
|
||||
LOG.warning(
|
||||
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
|
||||
|
||||
Reference in New Issue
Block a user