moe kernels init scaffold
This commit is contained in:
18
docs/moe_backends.md
Normal file
18
docs/moe_backends.md
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
MoE Backends in Axolotl
|
||||||
|
|
||||||
|
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via an environment variable:
|
||||||
|
|
||||||
|
- AXOLOTL_MOE_BACKEND=auto|hf_triton|torch_grouped|naive
|
||||||
|
|
||||||
|
Behavior
|
||||||
|
- auto (default): prefers PyTorch 2.8+ grouped GEMM, then Hugging Face kernels hub, otherwise naive.
|
||||||
|
- hf_triton: uses the Hugging Face kernels hub (kernels-community/triton_kernels) when available.
|
||||||
|
- torch_grouped: targets PyTorch 2.8+ grouped GEMM.
|
||||||
|
- naive: keeps the reference per-expert loop.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
- Current implementation wires the backend selector and routes Mixtral MoE through it. The hf_triton path is initially a stub: it uses kernels hub for routing but still falls back to per-expert computation until grouped GEMM is fully integrated.
|
||||||
|
- No changes to training scripts are required; Axolotl wraps Transformers Trainer; selection happens inside the model forward.
|
||||||
|
|
||||||
|
Example
|
||||||
|
AXOLOTL_MOE_BACKEND=hf_triton accelerate launch -m axolotl.cli.train path/to/config.yaml
|
||||||
161
scripts/bench_moe.py
Normal file
161
scripts/bench_moe.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUMlp(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||||
|
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||||
|
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||||
|
self.act_fn = F.silu
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.w2(self.act_fn(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class Experts(nn.Module):
|
||||||
|
def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
SwiGLUMlp(hidden_size, intermediate_size) for _ in range(num_experts)
|
||||||
|
)
|
||||||
|
self.num_experts = num_experts
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.layers[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def forward_naive(
|
||||||
|
hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
|
||||||
|
):
|
||||||
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
x = hidden_states.view(-1, hdim)
|
||||||
|
router_logits = gate(x)
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
|
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
|
||||||
|
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||||
|
y = torch.empty_like(x_rep)
|
||||||
|
flat_idx = topk_idx.view(-1)
|
||||||
|
for i in range(experts.num_experts):
|
||||||
|
sel = flat_idx == i
|
||||||
|
if sel.any():
|
||||||
|
y[sel] = experts[i](x_rep[sel])
|
||||||
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
return y.view(bsz, seqlen, hdim)
|
||||||
|
|
||||||
|
|
||||||
|
def forward_hf_triton(
|
||||||
|
hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from axolotl.kernels.moe import hf_triton as _hf
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
|
||||||
|
return y
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||||
|
# warmup
|
||||||
|
for _ in range(warmup):
|
||||||
|
out = fn(*args)
|
||||||
|
if sync and torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# measure
|
||||||
|
times = []
|
||||||
|
for _ in range(iters):
|
||||||
|
if sync and torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
out = fn(*args)
|
||||||
|
if sync and torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
dt = (time.perf_counter() - t0) * 1000.0
|
||||||
|
times.append(dt)
|
||||||
|
return sum(times) / len(times)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser(description="MoE microbenchmark")
|
||||||
|
p.add_argument("--bsz", type=int, default=8)
|
||||||
|
p.add_argument("--seq", type=int, default=1024)
|
||||||
|
p.add_argument("--hidden", type=int, default=4096)
|
||||||
|
p.add_argument("--inter", type=int, default=14336)
|
||||||
|
p.add_argument("--experts", type=int, default=8)
|
||||||
|
p.add_argument("--top_k", type=int, default=2)
|
||||||
|
p.add_argument(
|
||||||
|
"--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
|
||||||
|
)
|
||||||
|
p.add_argument("--iters", type=int, default=50)
|
||||||
|
p.add_argument("--warmup", type=int, default=10)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
dtype = {
|
||||||
|
"bf16": torch.bfloat16,
|
||||||
|
"fp16": torch.float16,
|
||||||
|
"fp32": torch.float32,
|
||||||
|
}[args.dtype]
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
if device == "cuda":
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
experts = Experts(args.experts, args.hidden, args.inter).to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
gate = nn.Linear(args.hidden, args.experts, bias=False).to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# data
|
||||||
|
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Report config
|
||||||
|
tokens = args.bsz * args.seq
|
||||||
|
print(
|
||||||
|
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} experts={args.experts} top_k={args.top_k}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Naive baseline
|
||||||
|
t_naive = bench(
|
||||||
|
forward_naive,
|
||||||
|
x,
|
||||||
|
gate,
|
||||||
|
experts,
|
||||||
|
args.top_k,
|
||||||
|
iters=args.iters,
|
||||||
|
warmup=args.warmup,
|
||||||
|
)
|
||||||
|
print(f"naive {t_naive:.2f} ms {tokens / (t_naive / 1000):.1f} tok/s")
|
||||||
|
|
||||||
|
# HF Triton (routing + stub compute for now)
|
||||||
|
os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
|
||||||
|
t_hf = forward_hf_triton
|
||||||
|
y = t_hf(x, gate, experts, args.top_k)
|
||||||
|
if y is not None:
|
||||||
|
t_ms = bench(
|
||||||
|
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
|
||||||
|
)
|
||||||
|
print(f"hf_triton {t_ms:.2f} ms {tokens / (t_ms / 1000):.1f} tok/s")
|
||||||
|
else:
|
||||||
|
print("hf_triton N/A (kernels hub not available)")
|
||||||
|
|
||||||
|
# torch_grouped placeholder — not yet implemented
|
||||||
|
print("torch_grouped N/A (pending implementation)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
src/axolotl/kernels/moe/__init__.py
Normal file
3
src/axolotl/kernels/moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .backends import MOEBackend, get_moe_backend_name
|
||||||
|
|
||||||
|
__all__ = ["get_moe_backend_name", "MOEBackend"]
|
||||||
65
src/axolotl/kernels/moe/backends.py
Normal file
65
src/axolotl/kernels/moe/backends.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class MOEBackend(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
HF_TRITON = "hf_triton"
|
||||||
|
TORCH_GROUPED = "torch_grouped"
|
||||||
|
NAIVE = "naive"
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_torch_grouped() -> bool:
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
|
||||||
|
# Prefer a simple version check; exact APIs may vary across 2.8+.
|
||||||
|
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
||||||
|
return ver >= (2, 8)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_hf_triton() -> bool:
|
||||||
|
try:
|
||||||
|
# The hub loads kernels lazily; this import is a light probe.
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
importlib.import_module("kernels")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
||||||
|
"""
|
||||||
|
Resolve the desired MoE backend using, in order of precedence:
|
||||||
|
- explicit preferred argument
|
||||||
|
- environment variable AXOLOTL_MOE_BACKEND
|
||||||
|
- auto detection
|
||||||
|
"""
|
||||||
|
choice = (preferred or os.getenv("AXOLOTL_MOE_BACKEND") or "auto").lower()
|
||||||
|
try:
|
||||||
|
selected = MOEBackend(choice)
|
||||||
|
except ValueError:
|
||||||
|
warnings.warn(f"Unknown moe backend '{choice}', falling back to auto")
|
||||||
|
selected = MOEBackend.AUTO
|
||||||
|
|
||||||
|
if selected == MOEBackend.AUTO:
|
||||||
|
if _probe_torch_grouped():
|
||||||
|
return MOEBackend.TORCH_GROUPED
|
||||||
|
if _probe_hf_triton():
|
||||||
|
return MOEBackend.HF_TRITON
|
||||||
|
return MOEBackend.NAIVE
|
||||||
|
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
||||||
|
warnings.warn(
|
||||||
|
"torch_grouped requested but torch>=2.8 not detected; falling back to hf_triton/naive"
|
||||||
|
)
|
||||||
|
return MOEBackend.HF_TRITON if _probe_hf_triton() else MOEBackend.NAIVE
|
||||||
|
if selected == MOEBackend.HF_TRITON and not _probe_hf_triton():
|
||||||
|
warnings.warn(
|
||||||
|
"hf_triton requested but kernels hub not available; falling back to torch_grouped/naive"
|
||||||
|
)
|
||||||
|
return MOEBackend.TORCH_GROUPED if _probe_torch_grouped() else MOEBackend.NAIVE
|
||||||
|
return selected
|
||||||
104
src/axolotl/kernels/moe/hf_triton.py
Normal file
104
src/axolotl/kernels/moe/hf_triton.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
Adapter for Hugging Face kernels hub (kernels-community/triton_kernels).
|
||||||
|
This file provides light probes and placeholders for future integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HFTritonHandles:
|
||||||
|
routing: Any
|
||||||
|
matmul_ogs: Any
|
||||||
|
swiglu: Any
|
||||||
|
|
||||||
|
|
||||||
|
def available() -> bool:
|
||||||
|
try:
|
||||||
|
import kernels # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load() -> Optional[HFTritonHandles]:
|
||||||
|
try:
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
tk = get_kernel("kernels-community/triton_kernels")
|
||||||
|
return HFTritonHandles(
|
||||||
|
routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def route_topk(logits, top_k: int):
|
||||||
|
handles = load()
|
||||||
|
if handles is None:
|
||||||
|
return None
|
||||||
|
return handles.routing.routing_torch(logits, n_expts_act=top_k)
|
||||||
|
|
||||||
|
|
||||||
|
def swiglu(x, alpha, limit=1.0, routing_data=None):
|
||||||
|
handles = load()
|
||||||
|
if handles is None:
|
||||||
|
return None
|
||||||
|
pc = handles.swiglu.PrecisionConfig(limit=limit)
|
||||||
|
return handles.swiglu.swiglu(x, alpha, pc, routing_data)
|
||||||
|
|
||||||
|
|
||||||
|
def moe_ffn_forward_stub(
|
||||||
|
hidden_states, gate_linear, experts_module, top_k: int
|
||||||
|
) -> Tuple[object, object]:
|
||||||
|
"""
|
||||||
|
Temporary stub that uses kernels hub routing, but falls back to per-expert compute.
|
||||||
|
Returns (final_hidden_states, router_logits).
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
flat = hidden_states.view(-1, hdim)
|
||||||
|
router_logits = gate_linear(flat)
|
||||||
|
# use hub routing if available; otherwise fallback to softmax+topk
|
||||||
|
routed = None
|
||||||
|
if available():
|
||||||
|
try:
|
||||||
|
routed = route_topk(router_logits, top_k)
|
||||||
|
except Exception:
|
||||||
|
routed = None
|
||||||
|
if routed is None:
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
|
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
topk_weight = topk_weight.to(flat.dtype)
|
||||||
|
x_rep = flat.repeat_interleave(top_k, dim=0)
|
||||||
|
y = torch.empty_like(x_rep)
|
||||||
|
flat_idx = topk_idx.view(-1)
|
||||||
|
for i in range(experts_module.num_experts):
|
||||||
|
expert = experts_module[i]
|
||||||
|
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
||||||
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
return y.reshape(bsz, seqlen, hdim), router_logits
|
||||||
|
# If routed via hub, still fallback to per-expert compute until grouped GEMM path is wired.
|
||||||
|
ex_routing_data, gather_idx, scatter_idx = routed
|
||||||
|
# Convert to naive per-expert compute on packed tokens (future: call matmul_ogs + swiglu)
|
||||||
|
# For now, reconstruct the same result as naive path (no speedup but validates routing).
|
||||||
|
# We map the selected experts from gather_idx back to expert ids via router_logits argmax among top-k.
|
||||||
|
# Simpler: reuse naive computation for correctness; detailed integration will follow.
|
||||||
|
routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
|
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype)
|
||||||
|
x_rep = flat.repeat_interleave(top_k, dim=0)
|
||||||
|
y = torch.empty_like(x_rep)
|
||||||
|
flat_idx = topk_idx.view(-1)
|
||||||
|
for i in range(experts_module.num_experts):
|
||||||
|
expert = experts_module[i]
|
||||||
|
y[flat_idx == i] = expert(x_rep[flat_idx == i])
|
||||||
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
return y.reshape(bsz, seqlen, hdim), router_logits
|
||||||
16
src/axolotl/kernels/moe/torch_grouped.py
Normal file
16
src/axolotl/kernels/moe/torch_grouped.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Placeholder for PyTorch 2.8+ grouped GEMM MoE path.
|
||||||
|
Currently probes availability; full integration to be implemented.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def available() -> bool:
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
|
||||||
|
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
||||||
|
return ver >= (2, 8)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
@@ -6,8 +6,13 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def patch_mixtral_moe_forward_zero3() -> None:
|
def patch_mixtral_moe_forward_zero3() -> None:
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from axolotl.kernels.moe import backends as _moe_backends, hf_triton as _hf_triton
|
||||||
|
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||||
|
|
||||||
def mlp_forward(self, hidden_states):
|
def mlp_forward(self, hidden_states):
|
||||||
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
|
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
|
||||||
hidden_states
|
hidden_states
|
||||||
@@ -21,21 +26,42 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
# router_logits: (batch * sequence_length, n_experts)
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
|
backend = get_moe_backend_name()
|
||||||
|
if backend == MOEBackend.HF_TRITON and _hf_triton.available():
|
||||||
|
# Stub path: use kernels hub routing and fallback per-expert compute
|
||||||
|
try:
|
||||||
|
final_hidden_states, router_logits = _hf_triton.moe_ffn_forward_stub(
|
||||||
|
hidden_states.view(batch_size, sequence_length, hidden_dim),
|
||||||
|
self.gate,
|
||||||
|
self.experts,
|
||||||
|
self.top_k,
|
||||||
|
)
|
||||||
|
return final_hidden_states, router_logits
|
||||||
|
except Exception as e:
|
||||||
|
warnings.warn(f"hf_triton backend failed, falling back to naive: {e}")
|
||||||
|
elif (
|
||||||
|
backend == MOEBackend.TORCH_GROUPED
|
||||||
|
and not _moe_backends._probe_torch_grouped()
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"torch_grouped selected but not available; falling back to naive"
|
||||||
|
)
|
||||||
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
topk_weight, topk_idx = torch.topk(
|
topk_weight, topk_idx = torch.topk(
|
||||||
routing_weights, self.top_k, dim=-1, sorted=False
|
routing_weights, self.top_k, dim=-1, sorted=False
|
||||||
)
|
)
|
||||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||||
# we cast back to the input dtype
|
|
||||||
topk_weight = topk_weight.to(hidden_states.dtype)
|
topk_weight = topk_weight.to(hidden_states.dtype)
|
||||||
|
|
||||||
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
|
||||||
y = torch.empty_like(hidden_states)
|
y = torch.empty_like(hidden_states_rep)
|
||||||
flat_topk_idx = topk_idx.view(-1)
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
for i in range(self.num_experts):
|
for i in range(self.num_experts):
|
||||||
expert = self.experts[i]
|
expert = self.experts[i]
|
||||||
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
sel = flat_topk_idx == i
|
||||||
|
if sel.any():
|
||||||
|
y[sel] = expert(hidden_states_rep[sel])
|
||||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user