diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 90d6e9828..467dfeb89 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -93,18 +93,16 @@ def moe_ffn_forward_grouped( device = hidden_states.device routing_dtype = gate_linear.weight.dtype - expert_dtype = hidden_states.dtype - x_flat = hidden_states.view(tokens, hdim) - router_logits = gate_linear(x_flat.to(routing_dtype)) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) - topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) - - sample = getattr( + sample_mod = getattr( experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0]) ) - if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"): + if ( + hasattr(sample_mod, "w1") + and hasattr(sample_mod, "w3") + and hasattr(sample_mod, "w2") + ): + expert_dtype = sample_mod.w1.weight.dtype w13 = _stack_weights( experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device ) @@ -112,11 +110,12 @@ def moe_ffn_forward_grouped( experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device ) else: - names13 = ( - ("gate_up_proj",) - if hasattr(sample, "gate_up_proj") - else ("up_proj", "gate_proj") - ) + if hasattr(sample_mod, "gate_up_proj"): + expert_dtype = sample_mod.gate_up_proj.weight.dtype + names13: Tuple[str, ...] = ("gate_up_proj",) + else: + expert_dtype = sample_mod.up_proj.weight.dtype + names13 = ("up_proj", "gate_proj") w13 = _stack_weights( experts_module, names13, key="w13", dtype=expert_dtype, device=device ) @@ -124,8 +123,15 @@ def moe_ffn_forward_grouped( experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device ) + x_flat = hidden_states.view(tokens, hdim).to(expert_dtype) + router_logits = gate_linear(x_flat.to(routing_dtype)) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) + flat_idx = topk_idx.view(-1) - x_rep = x_flat.to(expert_dtype).repeat_interleave(top_k, dim=0) + x_rep = x_flat.repeat_interleave(top_k, dim=0) as_list: List[torch.Tensor] = [] bs_list: List[torch.Tensor] = []