logs + fix
@@ -249,9 +249,9 @@ def moe_ffn_forward_grouped(
     for i in range(E):
         sel = flat_idx == i
         if sel.any():
-            Xi = x_rep[sel]
+            Xi = x_rep[sel].contiguous()
             As.append(Xi)
-            Bs.append(W13[i])
+            Bs.append(W13[i].contiguous())
             expert_slices.append((i, sel))
 
     if not As:
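This hunk makes the gathered per-expert activations and weight slabs explicitly contiguous before they reach the grouped matmul. A minimal sketch of the gather step, with illustrative shapes and names not taken from the repo: boolean advanced indexing already copies, so x_rep[sel] is contiguous either way, while W13[i] is a view whose contiguity depends on the parent tensor.

import torch

E, T, D = 4, 8, 16
x_rep = torch.randn(T, D)              # token activations, replicated per top-k slot
flat_idx = torch.randint(0, E, (T,))   # expert id assigned to each slot
W13 = torch.randn(E, D, 2 * D)         # fused up/gate projection, one slab per expert

As, Bs = [], []
for i in range(E):
    sel = flat_idx == i
    if sel.any():
        # x_rep[sel] copies via advanced indexing, so .contiguous() is a
        # defensive no-op here; W13[i] is a view and is contiguous only
        # if W13 itself is.
        As.append(x_rep[sel].contiguous())
        Bs.append(W13[i].contiguous())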
@@ -263,13 +263,20 @@ def moe_ffn_forward_grouped(
         b_tensors: List[torch.Tensor],
         target_dtype: torch.dtype,
     ) -> Optional[List[torch.Tensor]]:
-        if target_dtype != dt:
-            a_tensors = [t.to(target_dtype) for t in a_tensors]
-            b_tensors = [t.to(target_dtype) for t in b_tensors]
-        outputs = _call_grouped_mm(a_tensors, b_tensors)
-        if outputs is not None and target_dtype != dt:
-            outputs = [t.to(dt) for t in outputs]
-        return outputs
+        global LAST_ERROR
+        try:
+            if target_dtype != dt:
+                a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
+                b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
+            outputs = _call_grouped_mm(a_tensors, b_tensors)
+            if outputs is not None and target_dtype != dt:
+                outputs = [t.to(dt).contiguous() for t in outputs]
+            return outputs
+        except RuntimeError as err:
+            LAST_ERROR = f"grouped_mm cast failure: {err}"  # type: ignore[assignment]
+            if torch.cuda.is_available():  # pragma: no cover - defensive
+                torch.cuda.synchronize()
+            return None
 
     def _try_grouped_mm(
         a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
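The cast helper now traps RuntimeError instead of letting it propagate, records the message in LAST_ERROR for later logging, and returns None so the caller can fall back; the torch.cuda.synchronize() presumably flushes pending GPU work so the recorded error corresponds to this call. A hypothetical usage sketch (the helper's name is not visible in the hunk, so _cast_and_call and the fallback loop are assumptions, not the repo's code):

import torch

# Grouped path; returns None on failure instead of raising.
Y_list = _cast_and_call(As, Bs, torch.bfloat16)
if Y_list is None:
    # LAST_ERROR now holds the grouped_mm failure message.
    # Fall back to plain per-expert matmuls; the real helper also
    # casts outputs back to the working dtype dt.
    Y_list = [a.to(torch.bfloat16) @ b.to(torch.bfloat16) for a, b in zip(As, Bs)]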
@@ -305,7 +312,7 @@ def moe_ffn_forward_grouped(
             I2 = Yi.shape[-1] // 2
             Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
             As2.append(Yi_hidden)
-            Bs2.append(W2[i])
+            Bs2.append(W2[i].contiguous())
 
     Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
     if Y2_list is None:
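For context, the unchanged lines in this hunk are a SwiGLU-style split: the fused W13 output holds gate and value halves side by side, and the per-expert down projection W2[i] maps the gated hidden back to the model dimension. A self-contained sketch with illustrative shapes (names chosen here, not from the repo):

import torch
import torch.nn.functional as F

D, I = 16, 32
Yi = torch.randn(5, 2 * I)                   # fused up/gate output for one expert
I2 = Yi.shape[-1] // 2
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]  # SiLU(gate) * value -> (5, I)
W2_i = torch.randn(I, D)                     # one expert's down-projection slab (W2[i] above)
Yi_out = Yi_hidden @ W2_i                    # back to model dim -> (5, D)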