re-enable DS zero3 ci with updated transformers (#2533)
This commit is contained in:
@@ -621,12 +621,6 @@ class TestMultiGPULlama:
|
|||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: remove skip once deepspeed regression is fixed
|
|
||||||
# see https://github.com/huggingface/transformers/pull/37324
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
transformers_version_eq("4.51.0"),
|
|
||||||
reason="zero3 is not supported with transformers==4.51.0",
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"gradient_accumulation_steps",
|
"gradient_accumulation_steps",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
|
|||||||
Reference in New Issue
Block a user