Update test
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
import pytest
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
|
||||
|
||||
def test_fused_mixtral_moe():
|
||||
# NOTE: Requires torch 2.2.0
|
||||
# Set random seeds for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.cuda.manual_seed_all(0)
|
||||
torch.set_default_dtype(torch.float32)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Define the configuration for the MixtralSparseMoeBlock
|
||||
config = MixtralConfig(
|
||||
@@ -23,7 +24,6 @@ def test_fused_mixtral_moe():
|
||||
|
||||
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||
|
||||
sparse_moe = SparseMoeBlock(
|
||||
experts=mixtral_moe.experts,
|
||||
gate=mixtral_moe.gate,
|
||||
@@ -33,6 +33,11 @@ def test_fused_mixtral_moe():
|
||||
top_k=config.num_experts_per_tok
|
||||
)
|
||||
|
||||
assert torch.cat([
|
||||
mixtral_moe.experts[0].w1.weight.data,
|
||||
mixtral_moe.experts[0].w3.weight.data], dim=0
|
||||
).equal(sparse_moe.experts.experts.weight[0])
|
||||
|
||||
# Generate random input data
|
||||
batch_size = 16
|
||||
sequence_length = 32
|
||||
@@ -47,11 +52,9 @@ def test_fused_mixtral_moe():
|
||||
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
||||
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
||||
|
||||
print(output_diff, router_diff)
|
||||
|
||||
# Define the tolerance for the difference
|
||||
tolerance = 0.05
|
||||
|
||||
# # Check if the difference is within the tolerance
|
||||
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
|
||||
Reference in New Issue
Block a user