diff --git a/requirements.txt b/requirements.txt index 2f7eb46b4..13718b3e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ liger-kernel==0.5.5 packaging==23.2 peft==0.15.0 -transformers==4.51.0 +transformers>=4.50.3,<=4.51.0 tokenizers>=0.21.1 accelerate==1.6.0 datasets==3.5.0 diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 433d06835..c2bb12a02 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -7,9 +7,11 @@ import os from pathlib import Path import pytest +import transformers import yaml from accelerate.test_utils import execute_subprocess_async from huggingface_hub import snapshot_download +from packaging import version from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault @@ -28,6 +30,10 @@ def download_model(): snapshot_download("HuggingFaceTB/SmolLM2-135M") +def transformers_version_eq(required_version): + return version.parse(transformers.__version__) == version.parse(required_version) + + class TestMultiGPULlama: """ Test case for Llama models using LoRA @@ -612,8 +618,11 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" ) - @pytest.mark.skip( - reason="ds-zero3 broken in main until transformers#37281 resolved" + # TODO: remove skip once deepspeed regression is fixed + # see https://github.com/huggingface/transformers/pull/37324 + @pytest.mark.skipif( + transformers_version_eq("4.51.0"), + reason="zero3 is not supported with transformers==4.51.0", ) @pytest.mark.parametrize( "gradient_accumulation_steps",