fixing tests

This commit is contained in:
Salman Mohammadi
2025-04-08 17:23:21 +01:00
parent 6f47b1e896
commit 2f147cc6ff

View File

@@ -457,7 +457,7 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
) )
@require_torch_2_7_0 @require_torch_2_6_0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fsdp_reshard_after_forward", "fsdp_reshard_after_forward",
[True, False], [True, False],
@@ -503,7 +503,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward, "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
}, },
"use_tensorboard": True, "use_tensorboard": True,
"flex_attention": True, "flash_attention": True,
} }
) )
# write cfg to yaml file # write cfg to yaml file
@@ -527,7 +527,7 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
) )
@require_torch_2_6_0 @require_torch_2_7_0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fsdp_reshard_after_forward", "fsdp_reshard_after_forward",
[True, False], [True, False],
@@ -573,7 +573,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward, "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
}, },
"use_tensorboard": True, "use_tensorboard": True,
"flash_attention": True, "flex_attention": True,
} }
) )
# write cfg to yaml file # write cfg to yaml file