fix
This commit is contained in:
@@ -258,28 +258,38 @@ def moe_ffn_forward_grouped(
|
|||||||
out = torch.zeros_like(x)
|
out = torch.zeros_like(x)
|
||||||
return out.view(bsz, seqlen, hdim), router_logits
|
return out.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
comp_dtype = dt
|
def _run_grouped_mm(
|
||||||
if dt == torch.bfloat16:
|
a_tensors: List[torch.Tensor],
|
||||||
comp_dtype = torch.float16
|
b_tensors: List[torch.Tensor],
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_cast", False):
|
target_dtype: torch.dtype,
|
||||||
_LOGGER.info("torch_grouped: casting grouped_mm operands to float16")
|
) -> Optional[List[torch.Tensor]]:
|
||||||
experts_module._ax_grouped_logged_cast = True
|
if target_dtype != dt:
|
||||||
|
a_tensors = [t.to(target_dtype) for t in a_tensors]
|
||||||
|
b_tensors = [t.to(target_dtype) for t in b_tensors]
|
||||||
|
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
||||||
|
if outputs is not None and target_dtype != dt:
|
||||||
|
outputs = [t.to(dt) for t in outputs]
|
||||||
|
return outputs
|
||||||
|
|
||||||
def _maybe_cast(
|
def _try_grouped_mm(
|
||||||
tensors: List[torch.Tensor], *, to_dtype: torch.dtype
|
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
|
||||||
) -> List[torch.Tensor]:
|
) -> Tuple[Optional[List[torch.Tensor]], bool]:
|
||||||
if to_dtype == dt:
|
global LAST_ERROR
|
||||||
return tensors
|
result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=dt)
|
||||||
return [t.to(to_dtype) for t in tensors]
|
cast_used_local = False
|
||||||
|
if result is None and dt == torch.bfloat16:
|
||||||
|
result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=torch.float16)
|
||||||
|
if result is not None:
|
||||||
|
cast_used_local = True
|
||||||
|
LAST_ERROR = None
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_cast", False):
|
||||||
|
_LOGGER.info(
|
||||||
|
"torch_grouped: grouped_mm casting bfloat16 operands to float16"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_cast = True
|
||||||
|
return result, cast_used_local
|
||||||
|
|
||||||
def _restore_dtype(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
|
Y_list, _cast_used_up = _try_grouped_mm(As, Bs)
|
||||||
if comp_dtype == dt:
|
|
||||||
return tensors
|
|
||||||
return [t.to(dt) for t in tensors]
|
|
||||||
|
|
||||||
As_mm = _maybe_cast(As, to_dtype=comp_dtype)
|
|
||||||
Bs_mm = _maybe_cast(Bs, to_dtype=comp_dtype)
|
|
||||||
Y_list = _call_grouped_mm(As_mm, Bs_mm)
|
|
||||||
if Y_list is None:
|
if Y_list is None:
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@@ -287,7 +297,6 @@ def moe_ffn_forward_grouped(
|
|||||||
)
|
)
|
||||||
experts_module._ax_grouped_logged_fail = True
|
experts_module._ax_grouped_logged_fail = True
|
||||||
return None, None
|
return None, None
|
||||||
Y_list = _restore_dtype(Y_list)
|
|
||||||
|
|
||||||
As2: List[torch.Tensor] = []
|
As2: List[torch.Tensor] = []
|
||||||
Bs2: List[torch.Tensor] = []
|
Bs2: List[torch.Tensor] = []
|
||||||
@@ -298,9 +307,7 @@ def moe_ffn_forward_grouped(
|
|||||||
As2.append(Yi_hidden)
|
As2.append(Yi_hidden)
|
||||||
Bs2.append(W2[i])
|
Bs2.append(W2[i])
|
||||||
|
|
||||||
As2_mm = _maybe_cast(As2, to_dtype=comp_dtype)
|
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
|
||||||
Bs2_mm = _maybe_cast(Bs2, to_dtype=comp_dtype)
|
|
||||||
Y2_list = _call_grouped_mm(As2_mm, Bs2_mm)
|
|
||||||
if Y2_list is None:
|
if Y2_list is None:
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@@ -308,7 +315,6 @@ def moe_ffn_forward_grouped(
|
|||||||
)
|
)
|
||||||
experts_module._ax_grouped_logged_fail = True
|
experts_module._ax_grouped_logged_fail = True
|
||||||
return None, None
|
return None, None
|
||||||
Y2_list = _restore_dtype(Y2_list)
|
|
||||||
|
|
||||||
for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
||||||
y_buf[sel] = Out_i
|
y_buf[sel] = Out_i
|
||||||
|
|||||||
Reference in New Issue
Block a user