Refactor names, bugfixes
@@ -10,7 +10,7 @@ def test_fused_mixtral_moe():
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     torch.cuda.manual_seed_all(0)
-    torch.set_default_dtype(torch.float16)
+    torch.set_default_dtype(torch.float32)
     torch.set_default_device("cuda")
 
     # Define the configuration for the MixtralSparseMoeBlock
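The float16 → float32 switch here is what makes the tighter tolerance in the last hunk realistic. A standalone sketch (illustrative, not part of the commit) of the rounding error that float16 alone introduces, before any accumulation error from matmuls:

import torch

torch.manual_seed(0)
x = torch.randn(1 << 20)

# Round-tripping through float16 quantizes each value to an 11-bit
# mantissa; for standard-normal data the mean absolute error lands on
# the order of 1e-4 per element, which compounds across expert matmuls.
fp16_roundtrip_err = (x - x.half().float()).abs().mean().item()
print(f"mean abs float16 rounding error: {fp16_roundtrip_err:.2e}")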
@@ -26,6 +26,7 @@ def test_fused_mixtral_moe():
 
     sparse_moe = SparseMoeBlock(
         experts=mixtral_moe.experts,
+        gate=mixtral_moe.gate,
         hidden_dim=config.hidden_size,
         ffn_dim=config.intermediate_size,
         num_experts=config.num_local_experts,
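The new gate= argument means the fused block now shares the router weights, not just the experts, with the reference block, which is what makes the router-logits comparison in the next hunk meaningful. For context, a hedged sketch of how the reference mixtral_moe and config above are presumably built; MixtralConfig and MixtralSparseMoeBlock come from Hugging Face transformers, and the sizes here are illustrative, not taken from this test:

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

config = MixtralConfig(
    hidden_size=1024,        # illustrative sizes, not from the commit
    intermediate_size=2048,
    num_local_experts=8,
    num_experts_per_tok=2,
)
mixtral_moe = MixtralSparseMoeBlock(config)

# mixtral_moe.experts is the nn.ModuleList of expert MLPs and
# mixtral_moe.gate the router Linear; the fused SparseMoeBlock reuses
# both, so the two blocks run with identical weights.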
@@ -39,14 +40,18 @@ def test_fused_mixtral_moe():
 
     # Run the forward pass without gradients for both models
     with torch.no_grad():
-        mixtral_output, _ = mixtral_moe(input_data)
-        sparse_output, _ = sparse_moe(input_data)
+        mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
+        sparse_output, sparse_router_logits = sparse_moe(input_data)
 
-    # Compute the difference between the outputs
+    # Compute the difference between the outputs and router logits
     output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
+    router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
+
+    print(output_diff, router_diff)
 
     # Define the tolerance for the difference
-    tolerance = 0.1
+    tolerance = 0.05
 
     # Check if the difference is within the tolerance
     assert output_diff < tolerance, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
+    assert router_diff < tolerance, f"Router logits difference is {router_diff}, which is greater than the tolerance of {tolerance}"
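A possible follow-up, not part of this commit: torch.testing.assert_close checks element-wise with rtol/atol instead of asserting on the mean absolute difference, so a single badly wrong element cannot hide inside an otherwise small average. A self-contained sketch with stand-in tensors:

import torch

a = torch.randn(4, 16)
b = a + 0.01 * torch.randn(4, 16)  # stand-ins for the two model outputs

# atol mirrors the 0.05 tolerance above; rtol=0 keeps the comparison
# absolute, but unlike the mean-abs-diff assert this fails on any single
# element that drifts past the bound, with a detailed mismatch report.
torch.testing.assert_close(b, a, rtol=0, atol=0.05)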