logs + fix: record grouped_mm cast failures in LAST_ERROR and pass contiguous tensors to grouped_mm

This commit is contained in:
Dan Saunders
2025-09-17 14:50:49 -04:00
parent 98dc945838
commit d024048d74

View File

@@ -249,9 +249,9 @@ def moe_ffn_forward_grouped(
for i in range(E):
sel = flat_idx == i
if sel.any():
Xi = x_rep[sel]
Xi = x_rep[sel].contiguous()
As.append(Xi)
Bs.append(W13[i])
Bs.append(W13[i].contiguous())
expert_slices.append((i, sel))
if not As:
@@ -263,13 +263,20 @@ def moe_ffn_forward_grouped(
b_tensors: List[torch.Tensor],
target_dtype: torch.dtype,
) -> Optional[List[torch.Tensor]]:
if target_dtype != dt:
a_tensors = [t.to(target_dtype) for t in a_tensors]
b_tensors = [t.to(target_dtype) for t in b_tensors]
outputs = _call_grouped_mm(a_tensors, b_tensors)
if outputs is not None and target_dtype != dt:
outputs = [t.to(dt) for t in outputs]
return outputs
global LAST_ERROR
try:
if target_dtype != dt:
a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
outputs = _call_grouped_mm(a_tensors, b_tensors)
if outputs is not None and target_dtype != dt:
outputs = [t.to(dt).contiguous() for t in outputs]
return outputs
except RuntimeError as err:
LAST_ERROR = f"grouped_mm cast failure: {err}" # type: ignore[assignment]
if torch.cuda.is_available(): # pragma: no cover - defensive
torch.cuda.synchronize()
return None
def _try_grouped_mm(
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
@@ -305,7 +312,7 @@ def moe_ffn_forward_grouped(
I2 = Yi.shape[-1] // 2
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
As2.append(Yi_hidden)
Bs2.append(W2[i])
Bs2.append(W2[i].contiguous())
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
if Y2_list is None: