feat: add sonicmoe fused lora support (#3519)
* feat: add sonicmoe fused lora support * fix: forgot to add file * feat: add test * feat: add lora support for other routes * fix: add int8 lora support * fix: add qwen35_moe interleave support * fix: qwen3_5_moe loss * chore: lint * address some pr comments * fix test imports * add support matrix for moe kernels [skip ci] --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -51,7 +51,7 @@ def _create_tiny_qwen3_config():
|
||||
|
||||
def _interleave_gate_up_weights(model):
|
||||
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
|
||||
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
interleave_gate_up,
|
||||
)
|
||||
|
||||
@@ -80,7 +80,7 @@ class TestSonicMoEForwardCorrectness:
|
||||
def test_forward_output_matches(self):
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
|
||||
@@ -117,8 +117,8 @@ class TestSonicMoEGradientCorrectness:
|
||||
"""Verify all parameter gradients match between original and patched."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
deinterleave_gate_up,
|
||||
)
|
||||
|
||||
@@ -191,7 +191,7 @@ class TestSonicMoEGradientCorrectness:
|
||||
"""Verify that router (gate) weights get non-zero gradients."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
|
||||
@@ -223,7 +223,7 @@ class TestSonicMoETrainingConvergence:
|
||||
"""Run 30 training steps, verify loss decreases and no NaN/Inf."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
@@ -254,7 +254,7 @@ class TestSonicMoETrainingConvergence:
|
||||
"""Verify expert weights change during training (not frozen)."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
318
tests/e2e/integrations/test_sonicmoe_lora.py
Normal file
318
tests/e2e/integrations/test_sonicmoe_lora.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
End-to-end tests for SonicMoE + LoRA integration.
|
||||
|
||||
Verifies that PEFT-wrapped MoE models work correctly with SonicMoE's
|
||||
runtime LoRA materialization: gradients flow to adapters, base weights
|
||||
stay frozen, and loss converges.
|
||||
|
||||
Requires:
|
||||
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
|
||||
- sonicmoe package installed
|
||||
- peft package installed
|
||||
- transformers with Qwen3MoE support
|
||||
|
||||
Usage:
|
||||
pytest tests/e2e/integrations/test_sonicmoe_lora.py -v -s
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
|
||||
pytest.mark.skipif(
|
||||
not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
|
||||
),
|
||||
pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
|
||||
pytest.mark.skipif(not _peft_available, reason="PEFT not installed"),
|
||||
]
|
||||
|
||||
|
||||
def _create_tiny_qwen3_config():
|
||||
"""Create a minimal Qwen3MoE config for fast testing."""
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.for_model("qwen3_moe")
|
||||
config.hidden_size = 512
|
||||
config.intermediate_size = 1024
|
||||
config.moe_intermediate_size = 64
|
||||
config.num_attention_heads = 16
|
||||
config.num_key_value_heads = 2
|
||||
config.head_dim = 32
|
||||
config.num_hidden_layers = 2
|
||||
config.num_experts = 8
|
||||
config.num_experts_per_tok = 2
|
||||
config.vocab_size = 1000
|
||||
config.max_position_embeddings = 128
|
||||
config.norm_topk_prob = True
|
||||
config.torch_dtype = torch.bfloat16
|
||||
return config
|
||||
|
||||
|
||||
def _interleave_gate_up_weights(model):
|
||||
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
interleave_gate_up,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for name, param in model.named_parameters():
|
||||
if "gate_up_proj" in name:
|
||||
param.copy_(interleave_gate_up(param))
|
||||
|
||||
|
||||
def _unpatch_sonicmoe():
|
||||
"""Restore original forward on the MoE block class if it was patched."""
|
||||
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
|
||||
|
||||
for moe_cls in resolve_moe_block_classes("qwen3_moe"):
|
||||
if hasattr(moe_cls, "_original_forward"):
|
||||
moe_cls.forward = moe_cls._original_forward
|
||||
del moe_cls._original_forward
|
||||
|
||||
|
||||
def _apply_lora(model, target_modules):
|
||||
"""Apply PEFT LoRA to the model."""
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=target_modules,
|
||||
lora_dropout=0.0,
|
||||
bias="none",
|
||||
)
|
||||
return get_peft_model(model, lora_config)
|
||||
|
||||
|
||||
class TestSonicMoELoRATraining:
|
||||
"""Verify SonicMoE + LoRA training works end-to-end."""
|
||||
|
||||
def teardown_method(self):
|
||||
_unpatch_sonicmoe()
|
||||
|
||||
def test_loss_decreases(self):
|
||||
"""Run 30 training steps with LoRA on experts, verify loss decreases."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[p for p in model.parameters() if p.requires_grad], lr=1e-3
|
||||
)
|
||||
losses = []
|
||||
|
||||
for step in range(30):
|
||||
out = model(input_ids, labels=input_ids)
|
||||
loss = out.loss
|
||||
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
|
||||
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
|
||||
losses.append(loss.item())
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
assert losses[-1] < losses[0], (
|
||||
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
|
||||
)
|
||||
|
||||
def test_base_weights_frozen(self):
|
||||
"""Verify base (non-LoRA) weights don't change during training."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
|
||||
|
||||
# Snapshot frozen weights
|
||||
frozen_before = {}
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
frozen_before[name] = param.data.clone()
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[p for p in model.parameters() if p.requires_grad], lr=1e-3
|
||||
)
|
||||
for _ in range(5):
|
||||
out = model(input_ids, labels=input_ids)
|
||||
out.loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if name in frozen_before:
|
||||
assert torch.equal(param.data, frozen_before[name]), (
|
||||
f"Frozen weight changed: {name}"
|
||||
)
|
||||
|
||||
def test_lora_adapters_receive_gradients(self):
|
||||
"""Verify LoRA A and B matrices get non-zero gradients."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
|
||||
|
||||
out = model(input_ids, labels=input_ids)
|
||||
out.loss.backward()
|
||||
|
||||
lora_grads_found = 0
|
||||
for name, param in model.named_parameters():
|
||||
if "lora_" in name and param.requires_grad:
|
||||
assert param.grad is not None, f"No gradient for LoRA param: {name}"
|
||||
assert param.grad.abs().max() > 0, (
|
||||
f"Zero gradient for LoRA param: {name}"
|
||||
)
|
||||
lora_grads_found += 1
|
||||
|
||||
assert lora_grads_found > 0, "No LoRA parameters found with gradients"
|
||||
|
||||
def test_lora_adapters_update(self):
|
||||
"""Verify LoRA adapter weights change during training."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
|
||||
|
||||
# Snapshot LoRA weights
|
||||
lora_before = {}
|
||||
for name, param in model.named_parameters():
|
||||
if "lora_" in name and param.requires_grad:
|
||||
lora_before[name] = param.data.clone()
|
||||
|
||||
assert lora_before, "No LoRA parameters found"
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[p for p in model.parameters() if p.requires_grad], lr=1e-3
|
||||
)
|
||||
for _ in range(5):
|
||||
out = model(input_ids, labels=input_ids)
|
||||
out.loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
changed = sum(
|
||||
1
|
||||
for name, param in model.named_parameters()
|
||||
if name in lora_before and not torch.equal(param.data, lora_before[name])
|
||||
)
|
||||
assert changed > 0, "No LoRA weights changed after 5 training steps"
|
||||
|
||||
|
||||
class TestSonicMoEGateOnlyLoRA:
|
||||
"""Verify LoRA targeting only the gate (router) works with SonicMoE."""
|
||||
|
||||
def teardown_method(self):
|
||||
_unpatch_sonicmoe()
|
||||
|
||||
def test_gate_only_lora_loss_decreases(self):
|
||||
"""LoRA only on gate — expert path should have zero materialization overhead."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
# Only target the gate (router), not expert projections
|
||||
model = _apply_lora(model, ["gate"])
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
[p for p in model.parameters() if p.requires_grad], lr=1e-3
|
||||
)
|
||||
losses = []
|
||||
|
||||
for step in range(20):
|
||||
out = model(input_ids, labels=input_ids)
|
||||
loss = out.loss
|
||||
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
|
||||
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
|
||||
losses.append(loss.item())
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
assert losses[-1] < losses[0], (
|
||||
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
|
||||
)
|
||||
|
||||
|
||||
class TestSonicMoENoLoRARegression:
|
||||
"""Verify SonicMoE without LoRA still works after LoRA code was added."""
|
||||
|
||||
def teardown_method(self):
|
||||
_unpatch_sonicmoe()
|
||||
|
||||
def test_no_lora_loss_decreases(self):
|
||||
"""Full fine-tuning (no PEFT) with SonicMoE — regression test."""
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
|
||||
|
||||
config = _create_tiny_qwen3_config()
|
||||
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
|
||||
patch_sonicmoe("qwen3_moe")
|
||||
_interleave_gate_up_weights(model)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
|
||||
losses = []
|
||||
|
||||
for step in range(20):
|
||||
out = model(input_ids, labels=input_ids)
|
||||
loss = out.loss
|
||||
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
|
||||
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
|
||||
losses.append(loss.item())
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
assert losses[-1] < losses[0], (
|
||||
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
|
||||
)
|
||||
@@ -93,7 +93,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
|
||||
@@ -120,7 +122,9 @@ class TestSoftmaxRoutingParity:
|
||||
|
||||
def test_logits_not_returned_by_scattermoe(self):
|
||||
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
|
||||
@@ -131,7 +135,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
gate.norm_topk_prob = False
|
||||
@@ -152,7 +158,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
|
||||
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
|
||||
@@ -190,7 +198,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
||||
@@ -226,7 +236,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
|
||||
@@ -254,7 +266,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, bias_on_gate=False
|
||||
@@ -281,7 +295,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
@@ -309,7 +325,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
@@ -349,7 +367,7 @@ class TestSharedExpertParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert as scatter_compute,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import (
|
||||
_compute_shared_expert as sonic_compute,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
softmax_topk_routing,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
ConcatenatedToInterleaved,
|
||||
InterleavedToConcatenated,
|
||||
register_sonicmoe_weight_converter,
|
||||
@@ -212,9 +212,40 @@ class TestWeightConverterRegistration:
|
||||
)
|
||||
break
|
||||
|
||||
def test_register_unsupported_model_type_warns(self):
|
||||
# A model type with no conversion mapping should warn but not raise
|
||||
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
|
||||
def test_register_adds_same_key_converter(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
register_sonicmoe_weight_converter("qwen3_moe")
|
||||
|
||||
modified = get_checkpoint_conversion_mapping("qwen3_moe")
|
||||
# Should have a same-key converter for already-fused checkpoints
|
||||
same_key = [
|
||||
c
|
||||
for c in modified
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
def test_register_creates_mapping_when_none(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
# qwen3_5_moe has no conversion mapping in transformers
|
||||
register_sonicmoe_weight_converter("qwen3_5_moe")
|
||||
|
||||
mapping = get_checkpoint_conversion_mapping("qwen3_5_moe")
|
||||
assert mapping is not None
|
||||
same_key = [
|
||||
c
|
||||
for c in mapping
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
|
||||
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
|
||||
@@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
"""Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (expert_idx < E).all()
|
||||
|
||||
def test_renormalized_scores_sum_to_one(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
|
||||
def test_bias_affects_expert_selection(self):
|
||||
"""Large positive bias on expert 0 should make it always selected."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
"""Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""
|
||||
|
||||
def test_output_shapes_group_limited(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_output_shapes_greedy(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (expert_idx < E).all()
|
||||
|
||||
def test_scaling_factor_applied(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
|
||||
def test_group_selection_restricts_experts(self):
|
||||
"""With num_group=4 and topk_group=1, experts should come from selected groups."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (groups == groups[0]).all()
|
||||
|
||||
def test_unsupported_topk_method_raises(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
"""Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
|
||||
def test_renormalized_scores_sum_to_one(self):
|
||||
"""HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
|
||||
def test_uses_gate_wg_weight(self):
|
||||
"""Verify that modifying gate.wg.weight changes the routing output."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ code path where routing happens in float32.
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
328
tests/integrations/test_sonicmoe_lora.py
Normal file
328
tests/integrations/test_sonicmoe_lora.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Unit tests for SonicMoE LoRA support."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.lora import (
|
||||
MoELoRAMaterialize,
|
||||
get_lora_params_from_wrapper,
|
||||
has_lora,
|
||||
materialize_expert_lora,
|
||||
unwrap_experts_lora,
|
||||
unwrap_gate_lora,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Helpers: mock PEFT modules
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_mock_lora_module(weight_A, weight_B, scaling_val, param_name=None):
|
||||
"""Create a mock PEFT-wrapped module with LoRA attributes."""
|
||||
mock = MagicMock()
|
||||
|
||||
lora_A_linear = MagicMock()
|
||||
lora_A_linear.weight = weight_A
|
||||
|
||||
lora_B_linear = MagicMock()
|
||||
lora_B_linear.weight = weight_B
|
||||
|
||||
mock.lora_A = {"default": lora_A_linear}
|
||||
mock.lora_B = {"default": lora_B_linear}
|
||||
mock.scaling = {"default": scaling_val}
|
||||
mock.active_adapters = ["default"]
|
||||
|
||||
if param_name is not None:
|
||||
mock.parameter_name = param_name
|
||||
|
||||
return mock
|
||||
|
||||
|
||||
def _make_peft_gate(hidden_size, num_experts, rank, scaling=0.5):
|
||||
"""Create a mock PEFT-wrapped gate module."""
|
||||
base_gate = MagicMock()
|
||||
base_gate.weight = torch.randn(num_experts, hidden_size)
|
||||
base_gate.top_k = 2
|
||||
base_gate.norm_topk_prob = True
|
||||
|
||||
lora_A = torch.randn(rank, hidden_size)
|
||||
lora_B = torch.randn(num_experts, rank)
|
||||
|
||||
wrapper = _make_mock_lora_module(lora_A, lora_B, scaling)
|
||||
wrapper.base_layer = base_gate
|
||||
return wrapper, base_gate
|
||||
|
||||
|
||||
def _make_peft_experts(
|
||||
num_experts, gate_up_dim, down_dim, hidden_size, rank, scaling=0.5
|
||||
):
|
||||
"""Create a mock PEFT-wrapped experts chain.
|
||||
|
||||
Simulates: ParamWrapper(down_proj) -> ParamWrapper(gate_up_proj) -> Experts
|
||||
"""
|
||||
base_experts = MagicMock()
|
||||
base_experts.gate_up_proj = torch.randn(num_experts, gate_up_dim, hidden_size)
|
||||
base_experts.down_proj = torch.randn(num_experts, hidden_size, down_dim)
|
||||
# Remove base_layer and lora_A from base_experts so the chain walk stops
|
||||
del base_experts.base_layer
|
||||
del base_experts.lora_A
|
||||
|
||||
# gate_up_proj wrapper
|
||||
gup_A = torch.randn(rank * num_experts, hidden_size)
|
||||
gup_B = torch.randn(gate_up_dim, rank * num_experts)
|
||||
gup_wrapper = _make_mock_lora_module(gup_A, gup_B, scaling, "gate_up_proj")
|
||||
gup_wrapper.base_layer = base_experts
|
||||
|
||||
# down_proj wrapper (outermost)
|
||||
down_A = torch.randn(rank * num_experts, down_dim)
|
||||
down_B = torch.randn(hidden_size, rank * num_experts)
|
||||
down_wrapper = _make_mock_lora_module(down_A, down_B, scaling, "down_proj")
|
||||
down_wrapper.base_layer = gup_wrapper
|
||||
|
||||
return down_wrapper, base_experts, (gup_A, gup_B), (down_A, down_B)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: has_lora
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestHasLora:
|
||||
def test_plain_module(self):
|
||||
m = MagicMock(spec=["weight"])
|
||||
del m.base_layer
|
||||
del m.lora_A
|
||||
assert not has_lora(m)
|
||||
|
||||
def test_wrapped_module(self):
|
||||
m = MagicMock()
|
||||
m.base_layer = MagicMock()
|
||||
m.lora_A = {"default": MagicMock()}
|
||||
assert has_lora(m)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: get_lora_params_from_wrapper
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetLoraParams:
|
||||
def test_no_lora_attrs(self):
|
||||
m = MagicMock(spec=["weight"])
|
||||
del m.lora_A
|
||||
del m.lora_B
|
||||
assert get_lora_params_from_wrapper(m) == (None, None, None)
|
||||
|
||||
def test_extracts_params(self):
|
||||
A = torch.randn(4, 8)
|
||||
B = torch.randn(16, 4)
|
||||
wrapper = _make_mock_lora_module(A, B, 0.5)
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
|
||||
assert torch.equal(lora_A, A)
|
||||
assert torch.equal(lora_B, B)
|
||||
assert scaling == 0.5
|
||||
|
||||
def test_no_active_adapters(self):
|
||||
wrapper = _make_mock_lora_module(torch.randn(4, 8), torch.randn(16, 4), 0.5)
|
||||
wrapper.active_adapters = []
|
||||
assert get_lora_params_from_wrapper(wrapper) == (None, None, None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: unwrap_gate_lora
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestUnwrapGateLora:
|
||||
def test_plain_gate(self):
|
||||
gate = MagicMock(spec=["weight", "top_k"])
|
||||
del gate.base_layer
|
||||
del gate.lora_A
|
||||
gate.weight = torch.randn(8, 64)
|
||||
base, weight, delta = unwrap_gate_lora(gate)
|
||||
assert base is gate
|
||||
assert torch.equal(weight, gate.weight)
|
||||
assert delta is None
|
||||
|
||||
def test_wrapped_gate(self):
|
||||
wrapper, base_gate = _make_peft_gate(
|
||||
hidden_size=64, num_experts=8, rank=4, scaling=0.5
|
||||
)
|
||||
base, weight, delta = unwrap_gate_lora(wrapper)
|
||||
assert base is base_gate
|
||||
assert torch.equal(weight, base_gate.weight)
|
||||
assert delta is not None
|
||||
assert delta.shape == base_gate.weight.shape
|
||||
|
||||
# Verify delta = scaling * B @ A
|
||||
lora_A = wrapper.lora_A["default"].weight
|
||||
lora_B = wrapper.lora_B["default"].weight
|
||||
expected = 0.5 * (lora_B @ lora_A)
|
||||
assert torch.allclose(delta, expected)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: unwrap_experts_lora
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestUnwrapExpertsLora:
|
||||
def test_plain_experts(self):
|
||||
experts = MagicMock(spec=["gate_up_proj", "down_proj"])
|
||||
del experts.base_layer
|
||||
del experts.lora_A
|
||||
base, lora_dict = unwrap_experts_lora(experts)
|
||||
assert base is experts
|
||||
assert lora_dict == {}
|
||||
|
||||
def test_wrapped_experts(self):
|
||||
E, I2, I, H, r = 4, 256, 128, 64, 8 # noqa: E741
|
||||
wrapper, base_experts, (gup_A, gup_B), (down_A, down_B) = _make_peft_experts(
|
||||
E, I2, I, H, r, scaling=0.25
|
||||
)
|
||||
base, lora_dict = unwrap_experts_lora(wrapper)
|
||||
assert base is base_experts
|
||||
assert "gate_up_proj" in lora_dict
|
||||
assert "down_proj" in lora_dict
|
||||
|
||||
gup_lA, gup_lB, gup_s = lora_dict["gate_up_proj"]
|
||||
assert torch.equal(gup_lA, gup_A)
|
||||
assert torch.equal(gup_lB, gup_B)
|
||||
assert gup_s == 0.25
|
||||
|
||||
down_lA, down_lB, down_s = lora_dict["down_proj"]
|
||||
assert torch.equal(down_lA, down_A)
|
||||
assert torch.equal(down_lB, down_B)
|
||||
assert down_s == 0.25
|
||||
|
||||
def test_partial_lora(self):
|
||||
"""Only gate_up_proj has LoRA, down_proj does not."""
|
||||
base_experts = MagicMock(spec=["gate_up_proj", "down_proj"])
|
||||
del base_experts.base_layer
|
||||
del base_experts.lora_A
|
||||
|
||||
gup_A = torch.randn(16, 64)
|
||||
gup_B = torch.randn(256, 16)
|
||||
gup_wrapper = _make_mock_lora_module(gup_A, gup_B, 0.5, "gate_up_proj")
|
||||
gup_wrapper.base_layer = base_experts
|
||||
|
||||
base, lora_dict = unwrap_experts_lora(gup_wrapper)
|
||||
assert base is base_experts
|
||||
assert "gate_up_proj" in lora_dict
|
||||
assert "down_proj" not in lora_dict
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: MoELoRAMaterialize
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMoELoRAMaterialize:
|
||||
@pytest.fixture()
|
||||
def setup(self):
|
||||
E, dim1, dim2, r = 4, 32, 16, 4
|
||||
scaling = 0.5
|
||||
W = torch.randn(E, dim1, dim2, dtype=torch.float64, requires_grad=False)
|
||||
A = torch.randn(r * E, dim2, dtype=torch.float64, requires_grad=True)
|
||||
B = torch.randn(dim1, r * E, dtype=torch.float64, requires_grad=True)
|
||||
return W, A, B, scaling, E, r
|
||||
|
||||
def test_forward_shape(self, setup):
|
||||
W, A, B, scaling, E, r = setup
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
|
||||
assert W_eff.shape == W.shape
|
||||
|
||||
def test_forward_correctness(self, setup):
|
||||
W, A, B, scaling, E, r = setup
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
|
||||
|
||||
# Manual per-expert computation.
|
||||
# lora_A is expert-major: [r*E, dim2] -> rows [e*r:(e+1)*r] = expert e
|
||||
# lora_B is rank-major: [dim1, r*E] -> reshape [dim1, r, E], slice [:, :, e]
|
||||
_, dim1, dim2 = W.shape
|
||||
expected = W.clone()
|
||||
B_3d = B.reshape(dim1, r, E)
|
||||
for e in range(E):
|
||||
A_e = A[e * r : (e + 1) * r, :] # [r, dim2]
|
||||
B_e = B_3d[:, :, e] # [dim1, r]
|
||||
expected[e] += scaling * (B_e @ A_e)
|
||||
|
||||
assert torch.allclose(W_eff, expected, atol=1e-10)
|
||||
|
||||
def test_backward_gradcheck(self, setup):
|
||||
W, A, B, scaling, E, r = setup
|
||||
# gradcheck requires float64
|
||||
assert torch.autograd.gradcheck(
|
||||
lambda a, b: MoELoRAMaterialize.apply(W, a, b, scaling),
|
||||
(A, B),
|
||||
eps=1e-6,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
def test_no_grad_for_base_weight(self, setup):
|
||||
W, A, B, scaling, E, r = setup
|
||||
W.requires_grad_(True)
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
|
||||
loss = W_eff.sum()
|
||||
loss.backward()
|
||||
assert W.grad is None
|
||||
assert A.grad is not None
|
||||
assert B.grad is not None
|
||||
|
||||
def test_scaling_zero(self, setup):
|
||||
W, A, B, _, E, r = setup
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, 0.0)
|
||||
assert torch.allclose(W_eff, W)
|
||||
|
||||
def test_gate_up_proj_shapes(self):
|
||||
"""Test with realistic gate_up_proj shapes [E, 2*I, H]."""
|
||||
E, I2, H, r = 8, 512, 256, 16
|
||||
W = torch.randn(E, I2, H, dtype=torch.float64)
|
||||
A = torch.randn(r * E, H, dtype=torch.float64, requires_grad=True)
|
||||
B = torch.randn(I2, r * E, dtype=torch.float64, requires_grad=True)
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
|
||||
assert W_eff.shape == (E, I2, H)
|
||||
loss = W_eff.sum()
|
||||
loss.backward()
|
||||
assert A.grad.shape == A.shape
|
||||
assert B.grad.shape == B.shape
|
||||
|
||||
def test_down_proj_shapes(self):
|
||||
"""Test with realistic down_proj shapes [E, H, I]."""
|
||||
E, H, I, r = 8, 256, 512, 16 # noqa: E741
|
||||
W = torch.randn(E, H, I, dtype=torch.float64)
|
||||
A = torch.randn(r * E, I, dtype=torch.float64, requires_grad=True)
|
||||
B = torch.randn(H, r * E, dtype=torch.float64, requires_grad=True)
|
||||
W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
|
||||
assert W_eff.shape == (E, H, I)
|
||||
loss = W_eff.sum()
|
||||
loss.backward()
|
||||
assert A.grad.shape == A.shape
|
||||
assert B.grad.shape == B.shape
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests: materialize_expert_lora
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMaterializeExpertLora:
|
||||
def test_none_passthrough(self):
|
||||
W = torch.randn(4, 32, 16)
|
||||
result = materialize_expert_lora(W, None)
|
||||
assert result is W
|
||||
|
||||
def test_with_lora(self):
|
||||
E, dim1, dim2, r = 4, 32, 16, 4
|
||||
W = torch.randn(E, dim1, dim2)
|
||||
A = torch.randn(r * E, dim2, requires_grad=True)
|
||||
B = torch.randn(dim1, r * E, requires_grad=True)
|
||||
result = materialize_expert_lora(W, (A, B, 0.5))
|
||||
assert result.shape == W.shape
|
||||
assert not torch.equal(result, W)
|
||||
Reference in New Issue
Block a user