fixing tests
This commit is contained in:
@@ -457,7 +457,7 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_7_0
|
||||
@require_torch_2_6_0
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_reshard_after_forward",
|
||||
[True, False],
|
||||
@@ -503,7 +503,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"flex_attention": True,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
# write cfg to yaml file
|
||||
@@ -527,7 +527,7 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_reshard_after_forward",
|
||||
[True, False],
|
||||
@@ -573,7 +573,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"flash_attention": True,
|
||||
"flex_attention": True,
|
||||
}
|
||||
)
|
||||
# write cfg to yaml file
|
||||
|
||||
Reference in New Issue
Block a user