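logs + fix: record grouped_mm cast failures in LAST_ERROR and pass contiguous tensors to grouped_mm in moe_ffn_forward_grouped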
This commit is contained in:
@@ -249,9 +249,9 @@ def moe_ffn_forward_grouped(
|
|||||||
for i in range(E):
|
for i in range(E):
|
||||||
sel = flat_idx == i
|
sel = flat_idx == i
|
||||||
if sel.any():
|
if sel.any():
|
||||||
Xi = x_rep[sel]
|
Xi = x_rep[sel].contiguous()
|
||||||
As.append(Xi)
|
As.append(Xi)
|
||||||
Bs.append(W13[i])
|
Bs.append(W13[i].contiguous())
|
||||||
expert_slices.append((i, sel))
|
expert_slices.append((i, sel))
|
||||||
|
|
||||||
if not As:
|
if not As:
|
||||||
@@ -263,13 +263,20 @@ def moe_ffn_forward_grouped(
|
|||||||
b_tensors: List[torch.Tensor],
|
b_tensors: List[torch.Tensor],
|
||||||
target_dtype: torch.dtype,
|
target_dtype: torch.dtype,
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Optional[List[torch.Tensor]]:
|
||||||
if target_dtype != dt:
|
global LAST_ERROR
|
||||||
a_tensors = [t.to(target_dtype) for t in a_tensors]
|
try:
|
||||||
b_tensors = [t.to(target_dtype) for t in b_tensors]
|
if target_dtype != dt:
|
||||||
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
|
||||||
if outputs is not None and target_dtype != dt:
|
b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
|
||||||
outputs = [t.to(dt) for t in outputs]
|
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
||||||
return outputs
|
if outputs is not None and target_dtype != dt:
|
||||||
|
outputs = [t.to(dt).contiguous() for t in outputs]
|
||||||
|
return outputs
|
||||||
|
except RuntimeError as err:
|
||||||
|
LAST_ERROR = f"grouped_mm cast failure: {err}" # type: ignore[assignment]
|
||||||
|
if torch.cuda.is_available(): # pragma: no cover - defensive
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return None
|
||||||
|
|
||||||
def _try_grouped_mm(
|
def _try_grouped_mm(
|
||||||
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
|
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
|
||||||
@@ -305,7 +312,7 @@ def moe_ffn_forward_grouped(
|
|||||||
I2 = Yi.shape[-1] // 2
|
I2 = Yi.shape[-1] // 2
|
||||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
||||||
As2.append(Yi_hidden)
|
As2.append(Yi_hidden)
|
||||||
Bs2.append(W2[i])
|
Bs2.append(W2[i].contiguous())
|
||||||
|
|
||||||
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
|
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
|
||||||
if Y2_list is None:
|
if Y2_list is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user