fixing tests
This commit is contained in:
@@ -457,7 +457,7 @@ class TestMultiGPULlama:
|
|||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_6_0
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fsdp_reshard_after_forward",
|
"fsdp_reshard_after_forward",
|
||||||
[True, False],
|
[True, False],
|
||||||
@@ -503,7 +503,7 @@ class TestMultiGPULlama:
|
|||||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||||
},
|
},
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"flex_attention": True,
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -527,7 +527,7 @@ class TestMultiGPULlama:
|
|||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_7_0
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fsdp_reshard_after_forward",
|
"fsdp_reshard_after_forward",
|
||||||
[True, False],
|
[True, False],
|
||||||
@@ -573,7 +573,7 @@ class TestMultiGPULlama:
|
|||||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||||
},
|
},
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"flash_attention": True,
|
"flex_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
|
|||||||
Reference in New Issue
Block a user