yet another refactor

Dan Saunders
2025-09-18 12:47:15 -04:00
parent 0295df5bca
commit 7500641601
2 changed files with 6598 additions and 36 deletions


@@ -45,6 +45,29 @@ class _GroupedWeightStorage:
     device: torch.device
 
 
+def _allocate_fused_gate_up(
+    num_experts: int,
+    gate_shape: torch.Size,
+    up_shape: torch.Size,
+    *,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gate_shape[1] != up_shape[1]:
+        raise RuntimeError(
+            "torch_grouped: gate and up projections must share the hidden dimension"
+        )
+    fused = torch.empty(
+        (num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
+        device=device,
+        dtype=dtype,
+    )
+    gate_view = fused[:, : gate_shape[0]]
+    up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
+    return fused, gate_view, up_view
+
+
 def _ensure_grouped_weights(
     experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
 ) -> _GroupedWeightStorage:
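
Note: a minimal standalone sketch of the fused-buffer-plus-views pattern the new helper uses (sizes and names here are illustrative, not taken from the repo). Writing through the two views fills a single fused allocation in place, which is why the later fused = torch.cat((gate, up), dim=1) copies can be dropped.

import torch

# Illustrative sizes only; the real shapes come from the expert modules.
num_experts, inter_dim, hidden_dim = 4, 128, 64

# One contiguous buffer holds each expert's gate rows followed by its up rows.
fused = torch.empty((num_experts, 2 * inter_dim, hidden_dim))
gate_view = fused[:, :inter_dim]   # view, shares storage with fused
up_view = fused[:, inter_dim:]     # view, shares storage with fused

# Copying into the views writes directly into the fused buffer.
gate_view.copy_(torch.randn(num_experts, inter_dim, hidden_dim))
up_view.copy_(torch.randn(num_experts, inter_dim, hidden_dim))
assert gate_view.data_ptr() == fused.data_ptr()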
@@ -63,28 +86,26 @@ def _ensure_grouped_weights(
         and hasattr(sample_mod, "w2")
     ):
         pattern = "swi_glu"
+        num_experts = len(expert_impls)
+        w1_shape = sample_mod.w1.weight.shape
+        w3_shape = sample_mod.w3.weight.shape
+        w2_shape = sample_mod.w2.weight.shape
         if (
             storage is not None
             and storage.pattern == pattern
             and storage.dtype == sample_mod.w1.weight.dtype
             and storage.device == sample_mod.w1.weight.device
+            and storage.gate.shape[1:] == w1_shape
         ):
             return storage
-        num_experts = len(expert_impls)
-        w1_shape = sample_mod.w1.weight.shape
-        w3_shape = sample_mod.w3.weight.shape
-        w2_shape = sample_mod.w2.weight.shape
-        gate = torch.empty(
-            (num_experts, *w1_shape),
+        fused, gate, up = _allocate_fused_gate_up(
+            num_experts,
+            w1_shape,
+            w3_shape,
             device=sample_mod.w1.weight.device,
             dtype=sample_mod.w1.weight.dtype,
         )
-        up = torch.empty(
-            (num_experts, *w3_shape),
-            device=sample_mod.w3.weight.device,
-            dtype=sample_mod.w3.weight.dtype,
-        )
         down = torch.empty(
             (num_experts, *w2_shape),
             device=sample_mod.w2.weight.device,
@@ -102,7 +123,6 @@ def _ensure_grouped_weights(
                 mod.w2.weight.detach_()
                 mod.w2.weight.set_(down[idx])
-        fused = torch.cat((gate, up), dim=1).contiguous()
         return _store(
             _GroupedWeightStorage(
                 pattern=pattern,
@@ -124,6 +144,8 @@ def _ensure_grouped_weights(
             and storage.pattern == pattern
             and storage.dtype == gate_weight.dtype
             and storage.device == gate_weight.device
+            and storage.gate.shape[1:]
+            == (gate_weight.shape[0] // 2, gate_weight.shape[1])
         ):
             return storage
@@ -168,25 +190,23 @@ def _ensure_grouped_weights(
         and hasattr(sample_mod, "down_proj")
     ):
         pattern = "dual_proj"
+        up_weight = sample_mod.up_proj.weight
+        gate_weight = sample_mod.gate_proj.weight
+        down_weight = sample_mod.down_proj.weight
         if (
             storage is not None
             and storage.pattern == pattern
             and storage.dtype == sample_mod.up_proj.weight.dtype
             and storage.device == sample_mod.up_proj.weight.device
+            and storage.gate.shape[1:] == gate_weight.shape
         ):
             return storage
         num_experts = len(expert_impls)
-        up_weight = sample_mod.up_proj.weight
-        gate_weight = sample_mod.gate_proj.weight
-        down_weight = sample_mod.down_proj.weight
-        up = torch.empty(
-            (num_experts, *up_weight.shape),
-            device=up_weight.device,
-            dtype=up_weight.dtype,
-        )
-        gate = torch.empty(
-            (num_experts, *gate_weight.shape),
+        fused, gate, up = _allocate_fused_gate_up(
+            num_experts,
+            gate_weight.shape,
+            up_weight.shape,
             device=gate_weight.device,
             dtype=gate_weight.dtype,
         )
@@ -197,8 +217,8 @@ def _ensure_grouped_weights(
         )
         with torch.no_grad():
             for idx, mod in enumerate(expert_impls):
-                up[idx].copy_(mod.up_proj.weight.detach())
                 gate[idx].copy_(mod.gate_proj.weight.detach())
+                up[idx].copy_(mod.up_proj.weight.detach())
                 down[idx].copy_(mod.down_proj.weight.detach())
                 mod.up_proj.weight.detach_()
                 mod.up_proj.weight.set_(up[idx])
@@ -207,7 +227,6 @@ def _ensure_grouped_weights(
                 mod.down_proj.weight.detach_()
                 mod.down_proj.weight.set_(down[idx])
-        fused = torch.cat((gate, up), dim=1).contiguous()
         return _store(
             _GroupedWeightStorage(
                 pattern=pattern,
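
Note: the copy loops above re-point each expert parameter at its slice of the grouped buffer via detach_()/set_(). A minimal sketch of that aliasing pattern, using a made-up Linear module rather than the repo's expert class:

import torch

lin = torch.nn.Linear(8, 16, bias=False)   # hypothetical stand-in for one expert projection
grouped = torch.empty((4, 16, 8))          # (num_experts, out_features, in_features)

idx = 0
with torch.no_grad():
    grouped[idx].copy_(lin.weight.detach())  # seed the slice with the current weights
    lin.weight.detach_()                     # drop the parameter's old autograd history
    lin.weight.set_(grouped[idx])            # parameter now views the grouped slice

# Both refer to the same storage, so the weights exist only once.
assert lin.weight.data_ptr() == grouped[idx].data_ptr()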
@@ -268,33 +287,31 @@ def moe_ffn_forward_grouped(
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
-    perm = torch.argsort(flat_idx, stable=True)
-    sorted_experts = flat_idx[perm]
+    sorted_experts, perm = torch.sort(flat_idx)
     assignments = torch.bincount(sorted_experts, minlength=num_experts)
     if assignments.sum() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
-    token_indices_sorted = perm // top_k
-    scores_sorted = topk_weight.view(-1)[perm]
+    token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
+    scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
-    gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
-    routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
+    routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()
-    active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1)
+    active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
     if active_idx.numel() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
     counts_active = assignments[active_idx]
-    offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
-    offsets = offsets.to(torch.int32)
+    counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
+    offsets = torch.cumsum(counts_active_i32, dim=0)
     if offsets[-1].item() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
-    w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1)
-    w2_t = w2[active_idx].transpose(-2, -1)
+    w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
+    w2_t = w2.index_select(0, active_idx).transpose(-2, -1)
     routed_in = routed_input.to(expert_dtype)
     gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
@@ -307,5 +324,5 @@ def moe_ffn_forward_grouped(
     down_out.mul_(weights)
     combined = torch.zeros_like(x_flat)
-    combined.scatter_add_(0, gather_index, down_out)
+    combined.index_add_(0, token_indices_sorted, down_out)
     return combined.view(bsz, seqlen, hdim), router_logits
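
Note: a small self-contained sketch of the routing bookkeeping in moe_ffn_forward_grouped as it stands after this change: sort the flat token/expert assignments, derive per-expert counts and end offsets, gather inputs with index_select, and combine outputs back into token order with index_add_. The grouped matmul is replaced by a per-expert loop here so the sketch runs on any PyTorch build; the real code passes the int32 offsets to torch._grouped_mm instead.

import torch

num_tokens, hidden, top_k, num_experts = 6, 8, 2, 4
x_flat = torch.randn(num_tokens, hidden)
flat_idx = torch.randint(num_experts, (num_tokens * top_k,))  # chosen expert per (token, slot)
topk_weight = torch.rand(num_tokens, top_k)                   # router weights per (token, slot)

# Sort assignments by expert so each expert's tokens become contiguous.
sorted_experts, perm = torch.sort(flat_idx)
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor")
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)

counts = torch.bincount(sorted_experts, minlength=num_experts)
offsets = torch.cumsum(counts.to(torch.int32), dim=0)         # end offset of each expert's slice

routed_input = x_flat.index_select(0, token_indices_sorted)

# Stand-in for the grouped GEMM: apply each expert to its contiguous slice.
out = torch.empty_like(routed_input)
start = 0
for e in range(num_experts):
    end = int(offsets[e])
    out[start:end] = routed_input[start:end] * (e + 1)         # dummy "expert"
    start = end

# Weight by router scores and scatter-add back to token order.
out = out * scores_sorted.unsqueeze(-1)
combined = torch.zeros_like(x_flat)
combined.index_add_(0, token_indices_sorted, out)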

uv.lock (generated, new file): +6545 lines
File diff suppressed because it is too large.