Refactor names, bugfixes

This commit is contained in:
Casper Hansen
2024-03-15 12:39:11 +00:00
parent 1bc008e901
commit 26fc10df01
4 changed files with 31 additions and 23 deletions

View File

@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(torch.float32)
torch.set_default_device("cuda")
# Define the configuration for the MixtralSparseMoeBlock
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
sparse_moe = SparseMoeBlock(
experts=mixtral_moe.experts,
gate=mixtral_moe.gate,
hidden_dim=config.hidden_size,
ffn_dim=config.intermediate_size,
num_experts=config.num_local_experts,
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
# Run the forward pass with gradients for both models
with torch.no_grad():
mixtral_output, _ = mixtral_moe(input_data)
sparse_output, _ = sparse_moe(input_data)
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
sparse_output, sparse_router_logits = sparse_moe(input_data)
# Compute the difference between the outputs and router logits
# Compute the difference between the outputs
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
print(output_diff, router_diff)
# Define the tolerance for the difference
tolerance = 0.1
tolerance = 0.05
# # Check if the difference is within the tolerance
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
assert router_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"