"""Test that axolotl's fused SparseMoeBlock matches Mixtral's reference MoE block."""

import torch

from copy import deepcopy

from axolotl.monkeypatch.moe.mlp import FusedExperts

from axolotl.monkeypatch.moe.moe import SparseMoeBlock

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
def test_fused_mixtral_moe():
    """Check that the fused SparseMoeBlock produces the same outputs and
    router logits as the reference MixtralSparseMoeBlock when both are
    built from identical weights and fed identical inputs."""
    # Function-scope import: only needed to build a lightweight config object.
    from types import SimpleNamespace

    # Seed all RNGs so weight init and the random input are reproducible.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    # BUGFIX: MixtralSparseMoeBlock reads the configuration via attribute
    # access (config.hidden_size, config.num_local_experts, ...), and the
    # per-expert MLPs additionally read config.hidden_act — a plain dict
    # raises AttributeError here. Use an attribute-style namespace instead.
    config = SimpleNamespace(
        hidden_size=128,
        intermediate_size=512,
        num_local_experts=8,
        num_experts_per_tok=2,
        hidden_act="silu",  # Mixtral's default activation; required by the expert MLPs
    )

    # Reference block, plus a deep copy whose weights seed the fused block
    # so the two implementations start from identical parameters.
    mixtral_moe = MixtralSparseMoeBlock(config)
    mixtral_moe_copy = deepcopy(mixtral_moe)

    experts = FusedExperts(
        experts=mixtral_moe_copy.experts,
        # NOTE(review): input_size receives ffn_dim and hidden_size receives
        # hidden_dim — looks possibly swapped; confirm against FusedExperts'
        # parameter semantics before relying on this mapping.
        input_size=mixtral_moe_copy.ffn_dim,
        hidden_size=mixtral_moe_copy.hidden_dim,
        num_experts=mixtral_moe_copy.num_experts,
        top_k=mixtral_moe_copy.top_k,
        activation=mixtral_moe_copy.experts[0].act_fn,
    )
    sparse_moe = SparseMoeBlock(
        experts,
        hidden_dim=config.hidden_size,
        ffn_dim=config.intermediate_size,
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
    )

    # Random activations shaped (batch, seq, hidden) for both forward passes.
    batch_size = 16
    sequence_length = 32
    input_data = torch.randn(batch_size, sequence_length, config.hidden_size)

    # Run both implementations on the identical input.
    mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
    sparse_output, sparse_router_logits = sparse_moe(input_data)

    # Mean absolute deviation between the two implementations.
    output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
    router_logits_diff = (
        torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
    )

    # Tight tolerance: the fused path should be numerically equivalent.
    tolerance = 0.00001

    assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
    assert router_logits_diff < tolerance, f"Router logits difference is {router_logits_diff}, which is greater than the tolerance of {tolerance}"