dtype
This commit is contained in:
@@ -69,10 +69,10 @@ def _kernel_cg_backward_dx(
|
|||||||
mask_w = mask_n[:, None] & mask_k[None, :]
|
mask_w = mask_n[:, None] & mask_k[None, :]
|
||||||
|
|
||||||
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
|
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0)
|
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
grad_input += tl.dot(go, w)
|
grad_input += tl.dot(go, w)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,16 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
|
||||||
|
return False
|
||||||
|
major, _ = torch.cuda.get_device_capability(hidden_states.device)
|
||||||
|
if major < 9:
|
||||||
|
LOG.debug(
|
||||||
|
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
|
||||||
|
major,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _ensure_combined_expert_weights(
|
def _ensure_combined_expert_weights(
|
||||||
@@ -261,7 +270,7 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
|||||||
group_size_m,
|
group_size_m,
|
||||||
original_moe,
|
original_moe,
|
||||||
)
|
)
|
||||||
except RuntimeError as err:
|
except Exception as err: # fall back if Triton compilation or runtime fails
|
||||||
if not getattr(self, "_axolotl_triton_warned", False):
|
if not getattr(self, "_axolotl_triton_warned", False):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
|
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
|
||||||
|
|||||||
Reference in New Issue
Block a user