diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 43e9a0c7a..3725ca2e7 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -457,7 +457,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" ) - @require_torch_2_7_0 + @require_torch_2_6_0 @pytest.mark.parametrize( "fsdp_reshard_after_forward", [True, False], @@ -503,7 +503,7 @@ class TestMultiGPULlama: "fsdp_reshard_after_forward": fsdp_reshard_after_forward, }, "use_tensorboard": True, - "flex_attention": True, + "flash_attention": True, } ) # write cfg to yaml file @@ -527,7 +527,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" ) - @require_torch_2_6_0 + @require_torch_2_7_0 @pytest.mark.parametrize( "fsdp_reshard_after_forward", [True, False], @@ -573,7 +573,7 @@ class TestMultiGPULlama: "fsdp_reshard_after_forward": fsdp_reshard_after_forward, }, "use_tensorboard": True, - "flash_attention": True, + "flex_attention": True, } ) # write cfg to yaml file