simplify
This commit is contained in:
@@ -23,7 +23,7 @@ def available() -> bool:
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 9:
|
||||
return False
|
||||
return hasattr(torch.ops, "aten") and hasattr(torch.ops.aten, "_grouped_mm")
|
||||
return hasattr(torch.ops, "_grouped_mm")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -57,7 +57,7 @@ def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
"""
|
||||
Call grouped mm using aten._grouped_mm with packed representation.
|
||||
Call grouped mm with packed representation
|
||||
|
||||
- A_cat: concat As along rows -> [sum_i Mi, K]
|
||||
- B_stk: stack Bs per group -> [G, K, N]
|
||||
@@ -65,35 +65,23 @@ def _call_grouped_mm(
|
||||
Returns list of per-group outputs split from concatenated result.
|
||||
"""
|
||||
global LAST_ERROR
|
||||
try:
|
||||
# Ensure 2D contiguous inputs
|
||||
As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
# Ensure 2D contiguous inputs
|
||||
As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
|
||||
if not As2:
|
||||
return []
|
||||
device = As2[0].device
|
||||
A_cat = torch.cat(As2, dim=0)
|
||||
B_stk = torch.stack(Bs2, dim=0)
|
||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||
|
||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||
try:
|
||||
Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined]
|
||||
outs: List[torch.Tensor] = []
|
||||
start = 0
|
||||
for m in offs.tolist():
|
||||
outs.append(Y_cat[start : start + m, :])
|
||||
start += m
|
||||
return outs
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||
return None
|
||||
LAST_ERROR = "aten._grouped_mm not present"
|
||||
return None
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"call error: {e}"
|
||||
return None
|
||||
if not As2:
|
||||
return []
|
||||
device = As2[0].device
|
||||
A_cat = torch.cat(As2, dim=0).to(torch.bfloat16)
|
||||
B_stk = torch.stack(Bs2, dim=0).to(torch.bfloat16)
|
||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||
Y_cat = torch._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined]
|
||||
outs: List[torch.Tensor] = []
|
||||
start = 0
|
||||
for m in offs.tolist():
|
||||
outs.append(Y_cat[start : start + m, :])
|
||||
start += m
|
||||
return outs
|
||||
|
||||
|
||||
def moe_ffn_forward_grouped(
|
||||
@@ -103,6 +91,9 @@ def moe_ffn_forward_grouped(
|
||||
global LAST_ERROR
|
||||
LAST_ERROR = None
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
compute_dtype = gate_linear.weight.dtype
|
||||
if hidden_states.dtype != compute_dtype:
|
||||
hidden_states = hidden_states.to(dtype=compute_dtype)
|
||||
x = hidden_states.view(-1, hdim)
|
||||
router_logits = gate_linear(x)
|
||||
|
||||
@@ -114,16 +105,7 @@ def moe_ffn_forward_grouped(
|
||||
flat_idx = topk_idx.view(-1)
|
||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||
|
||||
try:
|
||||
E = _num_experts(experts_module)
|
||||
except AttributeError as err:
|
||||
LAST_ERROR = str(err)
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not determine expert count; falling back to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
E = _num_experts(experts_module)
|
||||
dev, dt = x.device, x.dtype
|
||||
first = experts_module[0]
|
||||
|
||||
@@ -161,87 +143,76 @@ def moe_ffn_forward_grouped(
|
||||
raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'")
|
||||
return nested_mod
|
||||
|
||||
try:
|
||||
if is_mixtral:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w1")
|
||||
or experts_module._stacked_w1.device != dev
|
||||
or experts_module._stacked_w1.dtype != dt
|
||||
):
|
||||
mods = [_resolve_expert(i) for i in range(E)]
|
||||
w1 = [mod.w1.weight.t() for mod in mods]
|
||||
w3 = [mod.w3.weight.t() for mod in mods]
|
||||
w2 = [mod.w2.weight.t() for mod in mods]
|
||||
experts_module._stacked_w1 = (
|
||||
torch.stack(w1, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w3 = (
|
||||
torch.stack(w3, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w13 = torch.cat(
|
||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||
).contiguous()
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
else:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w13")
|
||||
or experts_module._stacked_w13.device != dev
|
||||
or experts_module._stacked_w13.dtype != dt
|
||||
):
|
||||
w13 = []
|
||||
w2 = []
|
||||
for i in range(E):
|
||||
mod = _resolve_expert(i)
|
||||
if hasattr(mod, "gate_up_proj"):
|
||||
w13.append(mod.gate_up_proj.weight.t())
|
||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||
w13.append(
|
||||
torch.cat(
|
||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||
if not getattr(
|
||||
experts_module, "_ax_grouped_logged_fail", False
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
w2.append(mod.down_proj.weight.t())
|
||||
experts_module._stacked_w13 = (
|
||||
torch.stack(w13, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
except AttributeError as err:
|
||||
LAST_ERROR = str(err)
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: expert weights missing expected attributes; falling back to naive"
|
||||
if is_mixtral:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w1")
|
||||
or experts_module._stacked_w1.device != dev
|
||||
or experts_module._stacked_w1.dtype != dt
|
||||
):
|
||||
mods = [_resolve_expert(i) for i in range(E)]
|
||||
w1 = [mod.w1.weight.t() for mod in mods]
|
||||
w3 = [mod.w3.weight.t() for mod in mods]
|
||||
w2 = [mod.w2.weight.t() for mod in mods]
|
||||
experts_module._stacked_w1 = (
|
||||
torch.stack(w1, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
experts_module._stacked_w3 = (
|
||||
torch.stack(w3, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w13 = torch.cat(
|
||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||
).contiguous()
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
else:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w13")
|
||||
or experts_module._stacked_w13.device != dev
|
||||
or experts_module._stacked_w13.dtype != dt
|
||||
):
|
||||
w13 = []
|
||||
w2 = []
|
||||
for i in range(E):
|
||||
mod = _resolve_expert(i)
|
||||
if hasattr(mod, "gate_up_proj"):
|
||||
w13.append(mod.gate_up_proj.weight.t())
|
||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||
w13.append(
|
||||
torch.cat(
|
||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
w2.append(mod.down_proj.weight.t())
|
||||
experts_module._stacked_w13 = (
|
||||
torch.stack(w13, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
|
||||
As: List[torch.Tensor] = []
|
||||
Bs: List[torch.Tensor] = []
|
||||
@@ -251,7 +222,7 @@ def moe_ffn_forward_grouped(
|
||||
if sel.any():
|
||||
Xi = x_rep[sel].contiguous()
|
||||
As.append(Xi)
|
||||
Bs.append(W13[i].contiguous())
|
||||
Bs.append(W13[i].reshape(hdim, -1).contiguous())
|
||||
expert_slices.append((i, sel))
|
||||
|
||||
if not As:
|
||||
@@ -264,19 +235,13 @@ def moe_ffn_forward_grouped(
|
||||
target_dtype: torch.dtype,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
global LAST_ERROR
|
||||
try:
|
||||
if target_dtype != dt:
|
||||
a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
|
||||
b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
|
||||
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
||||
if outputs is not None and target_dtype != dt:
|
||||
outputs = [t.to(dt).contiguous() for t in outputs]
|
||||
return outputs
|
||||
except RuntimeError as err:
|
||||
LAST_ERROR = f"grouped_mm cast failure: {err}" # type: ignore[assignment]
|
||||
if torch.cuda.is_available(): # pragma: no cover - defensive
|
||||
torch.cuda.synchronize()
|
||||
return None
|
||||
if target_dtype != dt:
|
||||
a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
|
||||
b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
|
||||
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
||||
if outputs is not None and target_dtype != dt:
|
||||
outputs = [t.to(dt).contiguous() for t in outputs]
|
||||
return outputs
|
||||
|
||||
def _try_grouped_mm(
|
||||
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
|
||||
@@ -312,7 +277,7 @@ def moe_ffn_forward_grouped(
|
||||
I2 = Yi.shape[-1] // 2
|
||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
||||
As2.append(Yi_hidden)
|
||||
Bs2.append(W2[i].contiguous())
|
||||
Bs2.append(W2[i].reshape(I2, hdim).contiguous())
|
||||
|
||||
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
|
||||
if Y2_list is None:
|
||||
|
||||
Reference in New Issue
Block a user