fix
This commit is contained in:
@@ -258,7 +258,28 @@ def moe_ffn_forward_grouped(
|
||||
out = torch.zeros_like(x)
|
||||
return out.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
Y_list = _call_grouped_mm(As, Bs)
|
||||
comp_dtype = dt
|
||||
if dt == torch.bfloat16:
|
||||
comp_dtype = torch.float16
|
||||
if not getattr(experts_module, "_ax_grouped_logged_cast", False):
|
||||
_LOGGER.info("torch_grouped: casting grouped_mm operands to float16")
|
||||
experts_module._ax_grouped_logged_cast = True
|
||||
|
||||
def _maybe_cast(
|
||||
tensors: List[torch.Tensor], *, to_dtype: torch.dtype
|
||||
) -> List[torch.Tensor]:
|
||||
if to_dtype == dt:
|
||||
return tensors
|
||||
return [t.to(to_dtype) for t in tensors]
|
||||
|
||||
def _restore_dtype(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
if comp_dtype == dt:
|
||||
return tensors
|
||||
return [t.to(dt) for t in tensors]
|
||||
|
||||
As_mm = _maybe_cast(As, to_dtype=comp_dtype)
|
||||
Bs_mm = _maybe_cast(Bs, to_dtype=comp_dtype)
|
||||
Y_list = _call_grouped_mm(As_mm, Bs_mm)
|
||||
if Y_list is None:
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
@@ -266,6 +287,7 @@ def moe_ffn_forward_grouped(
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
Y_list = _restore_dtype(Y_list)
|
||||
|
||||
As2: List[torch.Tensor] = []
|
||||
Bs2: List[torch.Tensor] = []
|
||||
@@ -276,7 +298,9 @@ def moe_ffn_forward_grouped(
|
||||
As2.append(Yi_hidden)
|
||||
Bs2.append(W2[i])
|
||||
|
||||
Y2_list = _call_grouped_mm(As2, Bs2)
|
||||
As2_mm = _maybe_cast(As2, to_dtype=comp_dtype)
|
||||
Bs2_mm = _maybe_cast(Bs2, to_dtype=comp_dtype)
|
||||
Y2_list = _call_grouped_mm(As2_mm, Bs2_mm)
|
||||
if Y2_list is None:
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
@@ -284,6 +308,7 @@ def moe_ffn_forward_grouped(
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
Y2_list = _restore_dtype(Y2_list)
|
||||
|
||||
for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
||||
y_buf[sel] = Out_i
|
||||
|
||||
Reference in New Issue
Block a user