diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py index a04af5359..ec42e81e4 100644 --- a/src/axolotl/kernels/moe/hf_triton.py +++ b/src/axolotl/kernels/moe/hf_triton.py @@ -86,19 +86,42 @@ def moe_ffn_forward_stub( routing_data, gather_idx, scatter_idx = handles.routing.routing_torch( router_logits, n_expts_act=top_k ) - # Prepare expert weights: shapes [E, K, N] + # Prepare and cache expert weights: shapes [E, K, N] + import torch + E = experts_module.num_experts - K = hdim - # up projections - W1 = [] - W3 = [] - for i in range(E): - exp = experts_module[i] - # Linear weight is [out, in]; need [in, out] - W1.append(exp.w1.weight.t()) - W3.append(exp.w3.weight.t()) - W1 = torch.stack(W1, dim=0).to(device=flat.device, dtype=flat.dtype) - W3 = torch.stack(W3, dim=0).to(device=flat.device, dtype=flat.dtype) + dev = flat.device + dt = flat.dtype + if ( + not hasattr(experts_module, "_stacked_w1") + or experts_module._stacked_w1.device != dev + or experts_module._stacked_w1.dtype != dt + ): + W1 = [] + W3 = [] + W2 = [] + for i in range(E): + exp = experts_module[i] + W1.append(exp.w1.weight.t()) + W3.append(exp.w3.weight.t()) + W2.append(exp.w2.weight.t()) + experts_module._stacked_w1 = ( + torch.stack(W1, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w3 = ( + torch.stack(W3, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(W2, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + W1 = experts_module._stacked_w1 + W3 = experts_module._stacked_w3 # compute gathered inputs X_g according to gather_idx via matmul_ogs gather # First matmul for w1: gather happens inside kernel using gather_indx Y1 = handles.matmul_ogs.matmul_ogs( @@ -123,8 +146,7 @@ def moe_ffn_forward_stub( # SwiGLU: silu(Y1) * Y3 Hidden = F.silu(Y1) * Y3 # Down projection weights [E, inter, hidden] - W2 = [experts_module[i].w2.weight.t() for i in range(E)] - W2 = torch.stack(W2, dim=0).to(device=flat.device, dtype=flat.dtype) + W2 = experts_module._stacked_w2 # Down matmul with fused scatter back using scatter_indx Out = handles.matmul_ogs.matmul_ogs( Hidden, @@ -134,6 +156,7 @@ def moe_ffn_forward_stub( gather_indx=None, scatter_indx=scatter_idx, precision_config=handles.matmul_ogs.PrecisionConfig(), + gammas=routing_data.gate_scal, ) return Out.view(bsz, seqlen, hdim), router_logits except Exception: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 3b4a890b5..35e7cf5ca 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -9,6 +9,7 @@ from pydantic import ( Field, StringConstraints, field_serializer, + field_validator, model_validator, ) @@ -138,6 +139,8 @@ class AxolotlInputConfig( "description": "Mixture-of-Experts backend to use: 'auto', 'hf_triton', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.", }, ) + + # Value is constrained by the Literal type; no normalization needed. qat: QATConfig | None = None quantization: PTQConfig | None = None reward_model: bool | None = Field(