Update test

This commit is contained in:
Casper Hansen
2024-03-15 13:58:12 +00:00
parent 26fc10df01
commit 035e680631

View File

@@ -1,17 +1,18 @@
import torch
from copy import deepcopy
import pytest
from torch import nn
from torch.nn import functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
def test_fused_mixtral_moe():
# NOTE: Requires torch 2.2.0
# Set random seeds for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.set_default_dtype(torch.float32)
torch.set_default_dtype(torch.float16)
torch.set_default_device("cuda")
torch.manual_seed(0)
# Define the configuration for the MixtralSparseMoeBlock
config = MixtralConfig(
@@ -23,7 +24,6 @@ def test_fused_mixtral_moe():
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
mixtral_moe = MixtralSparseMoeBlock(config)
sparse_moe = SparseMoeBlock(
experts=mixtral_moe.experts,
gate=mixtral_moe.gate,
@@ -33,6 +33,11 @@ def test_fused_mixtral_moe():
top_k=config.num_experts_per_tok
)
assert torch.cat([
mixtral_moe.experts[0].w1.weight.data,
mixtral_moe.experts[0].w3.weight.data], dim=0
).equal(sparse_moe.experts.experts.weight[0])
# Generate random input data
batch_size = 16
sequence_length = 32
@@ -47,11 +52,9 @@ def test_fused_mixtral_moe():
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
print(output_diff, router_diff)
# Define the tolerance for the difference
tolerance = 0.05
# # Check if the difference is within the tolerance
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"