Update test
This commit is contained in:
@@ -1,17 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
from copy import deepcopy
|
import pytest
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
|
||||||
|
|
||||||
def test_fused_mixtral_moe():
|
def test_fused_mixtral_moe():
|
||||||
# NOTE: Requires torch 2.2.0
|
# NOTE: Requires torch 2.2.0
|
||||||
# Set random seeds for reproducibility
|
# Set random seeds for reproducibility
|
||||||
torch.manual_seed(0)
|
torch.set_default_dtype(torch.float16)
|
||||||
torch.cuda.manual_seed(0)
|
|
||||||
torch.cuda.manual_seed_all(0)
|
|
||||||
torch.set_default_dtype(torch.float32)
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
# Define the configuration for the MixtralSparseMoeBlock
|
# Define the configuration for the MixtralSparseMoeBlock
|
||||||
config = MixtralConfig(
|
config = MixtralConfig(
|
||||||
@@ -23,7 +24,6 @@ def test_fused_mixtral_moe():
|
|||||||
|
|
||||||
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||||
mixtral_moe = MixtralSparseMoeBlock(config)
|
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||||
|
|
||||||
sparse_moe = SparseMoeBlock(
|
sparse_moe = SparseMoeBlock(
|
||||||
experts=mixtral_moe.experts,
|
experts=mixtral_moe.experts,
|
||||||
gate=mixtral_moe.gate,
|
gate=mixtral_moe.gate,
|
||||||
@@ -33,6 +33,11 @@ def test_fused_mixtral_moe():
|
|||||||
top_k=config.num_experts_per_tok
|
top_k=config.num_experts_per_tok
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert torch.cat([
|
||||||
|
mixtral_moe.experts[0].w1.weight.data,
|
||||||
|
mixtral_moe.experts[0].w3.weight.data], dim=0
|
||||||
|
).equal(sparse_moe.experts.experts.weight[0])
|
||||||
|
|
||||||
# Generate random input data
|
# Generate random input data
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
sequence_length = 32
|
sequence_length = 32
|
||||||
@@ -47,11 +52,9 @@ def test_fused_mixtral_moe():
|
|||||||
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
||||||
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
||||||
|
|
||||||
print(output_diff, router_diff)
|
|
||||||
|
|
||||||
# Define the tolerance for the difference
|
# Define the tolerance for the difference
|
||||||
tolerance = 0.05
|
tolerance = 0.05
|
||||||
|
|
||||||
# # Check if the difference is within the tolerance
|
# # Check if the difference is within the tolerance
|
||||||
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
|
|||||||
Reference in New Issue
Block a user