log
This commit is contained in:
@@ -56,9 +56,13 @@ def _stack_weights(
|
|||||||
|
|
||||||
def _call_grouped_mm(
|
def _call_grouped_mm(
|
||||||
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Tuple[Optional[List[torch.Tensor]], Optional[str]]:
|
||||||
if not As or dtype not in (torch.bfloat16, torch.float16):
|
if not As:
|
||||||
return [] if not As else None
|
return [], None
|
||||||
|
if dtype not in (torch.bfloat16, torch.float16):
|
||||||
|
msg = f"unsupported dtype {dtype}"
|
||||||
|
_LOGGER.debug("torch_grouped: %s", msg)
|
||||||
|
return None, msg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
@@ -68,7 +72,7 @@ def _call_grouped_mm(
|
|||||||
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
offsets = torch.cumsum(lengths, dim=0)
|
offsets = torch.cumsum(lengths, dim=0)
|
||||||
Y_cat = torch.ops.aten._grouped_mm(
|
Y_cat = torch._grouped_mm(
|
||||||
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
||||||
)
|
)
|
||||||
outs: List[torch.Tensor] = []
|
outs: List[torch.Tensor] = []
|
||||||
@@ -76,10 +80,11 @@ def _call_grouped_mm(
|
|||||||
for size in lengths.tolist():
|
for size in lengths.tolist():
|
||||||
outs.append(Y_cat[start : start + size])
|
outs.append(Y_cat[start : start + size])
|
||||||
start += size
|
start += size
|
||||||
return outs
|
return outs, None
|
||||||
except RuntimeError as err:
|
except RuntimeError as err:
|
||||||
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
|
message = f"_grouped_mm failed ({err})"
|
||||||
return None
|
_LOGGER.warning("torch_grouped: %s", message)
|
||||||
|
return None, message
|
||||||
|
|
||||||
|
|
||||||
def moe_ffn_forward_grouped(
|
def moe_ffn_forward_grouped(
|
||||||
@@ -154,7 +159,7 @@ def moe_ffn_forward_grouped(
|
|||||||
if not as_list:
|
if not as_list:
|
||||||
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
|
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
|
up_out, reason = _call_grouped_mm(as_list, bs_list, expert_dtype)
|
||||||
if up_out is None:
|
if up_out is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@@ -167,7 +172,7 @@ def moe_ffn_forward_grouped(
|
|||||||
down_inputs.append(hidden)
|
down_inputs.append(hidden)
|
||||||
down_weights.append(w2[i])
|
down_weights.append(w2[i])
|
||||||
|
|
||||||
down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
|
down_out, reason = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
|
||||||
if down_out is None:
|
if down_out is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user