fixing tests

This commit is contained in:
Salman Mohammadi
2025-04-08 17:23:21 +01:00
parent 6f47b1e896
commit 2f147cc6ff

View File

@@ -457,7 +457,7 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_2_7_0
@require_torch_2_6_0
@pytest.mark.parametrize(
"fsdp_reshard_after_forward",
[True, False],
@@ -503,7 +503,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
"flex_attention": True,
"flash_attention": True,
}
)
# write cfg to yaml file
@@ -527,7 +527,7 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
)
@require_torch_2_6_0
@require_torch_2_7_0
@pytest.mark.parametrize(
"fsdp_reshard_after_forward",
[True, False],
@@ -573,7 +573,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
"flash_attention": True,
"flex_attention": True,
}
)
# write cfg to yaml file