From 035e6806311fd7e2d48e7c721d7beb1a05e88764 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 15 Mar 2024 13:58:12 +0000 Subject: [PATCH] Update test --- tests/monkeypatch/test_moe.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/monkeypatch/test_moe.py b/tests/monkeypatch/test_moe.py index 7fbc06462..dbcbe14ba 100644 --- a/tests/monkeypatch/test_moe.py +++ b/tests/monkeypatch/test_moe.py @@ -1,17 +1,18 @@ import torch -from copy import deepcopy +import pytest +from torch import nn +from torch.nn import functional as F from axolotl.monkeypatch.moe.mlp import FusedExperts from axolotl.monkeypatch.moe.moe import SparseMoeBlock + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig def test_fused_mixtral_moe(): # NOTE: Requires torch 2.2.0 # Set random seeds for reproducibility - torch.manual_seed(0) - torch.cuda.manual_seed(0) - torch.cuda.manual_seed_all(0) - torch.set_default_dtype(torch.float32) + torch.set_default_dtype(torch.float16) torch.set_default_device("cuda") + torch.manual_seed(0) # Define the configuration for the MixtralSparseMoeBlock config = MixtralConfig( @@ -23,7 +24,6 @@ def test_fused_mixtral_moe(): # Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration mixtral_moe = MixtralSparseMoeBlock(config) - sparse_moe = SparseMoeBlock( experts=mixtral_moe.experts, gate=mixtral_moe.gate, @@ -33,6 +33,11 @@ def test_fused_mixtral_moe(): top_k=config.num_experts_per_tok ) + assert torch.cat([ + mixtral_moe.experts[0].w1.weight.data, + mixtral_moe.experts[0].w3.weight.data], dim=0 + ).equal(sparse_moe.experts.experts.weight[0]) + # Generate random input data batch_size = 16 sequence_length = 32 @@ -47,11 +52,9 @@ def test_fused_mixtral_moe(): output_diff = torch.abs(mixtral_output - sparse_output).mean().item() router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item() - print(output_diff, router_diff) - # Define the tolerance for the difference tolerance = 0.05 # # Check if the difference is within the tolerance - assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" - assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" \ No newline at end of file + assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}" + assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"