diff --git a/tests/monkeypatch/test_moe.py b/tests/monkeypatch/test_moe.py index 7fbc06462..dbcbe14ba 100644 --- a/tests/monkeypatch/test_moe.py +++ b/tests/monkeypatch/test_moe.py @@ -1,17 +1,18 @@ import torch -from copy import deepcopy +import pytest +from torch import nn +from torch.nn import functional as F from axolotl.monkeypatch.moe.mlp import FusedExperts from axolotl.monkeypatch.moe.moe import SparseMoeBlock + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig def test_fused_mixtral_moe(): # NOTE: Requires torch 2.2.0 # Set random seeds for reproducibility - torch.manual_seed(0) - torch.cuda.manual_seed(0) - torch.cuda.manual_seed_all(0) - torch.set_default_dtype(torch.float32) + torch.set_default_dtype(torch.float16) torch.set_default_device("cuda") + torch.manual_seed(0) # Define the configuration for the MixtralSparseMoeBlock config = MixtralConfig( @@ -23,7 +24,6 @@ def test_fused_mixtral_moe(): # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration mixtral_moe = MixtralSparseMoeBlock(config) - sparse_moe = SparseMoeBlock( experts=mixtral_moe.experts, gate=mixtral_moe.gate, @@ -33,6 +33,11 @@ def test_fused_mixtral_moe(): top_k=config.num_experts_per_tok ) + assert torch.cat([ + mixtral_moe.experts[0].w1.weight.data, + mixtral_moe.experts[0].w3.weight.data], dim=0 + ).equal(sparse_moe.experts.experts.weight[0]) + # Generate random input data batch_size = 16 sequence_length = 32 @@ -47,11 +52,9 @@ def test_fused_mixtral_moe(): output_diff = torch.abs(mixtral_output - sparse_output).mean().item() router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item() - print(output_diff, router_diff) - # Define the tolerance for the difference tolerance = 0.05 # # Check if the difference is within the tolerance - assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" - assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" \ No newline at end of file + assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" + assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"