be flexible on transformers version and skip test on version
This commit is contained in:
@@ -12,7 +12,7 @@ liger-kernel==0.5.5
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.0
|
||||
transformers==4.51.0
|
||||
transformers>=4.50.3,<=4.51.0
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
|
||||
@@ -7,9 +7,11 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import transformers
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from huggingface_hub import snapshot_download
|
||||
from packaging import version
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -28,6 +30,10 @@ def download_model():
|
||||
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
||||
|
||||
|
||||
def transformers_version_eq(required_version):
|
||||
return version.parse(transformers.__version__) == version.parse(required_version)
|
||||
|
||||
|
||||
class TestMultiGPULlama:
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
@@ -612,8 +618,11 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="ds-zero3 broken in main until transformers#37281 resolved"
|
||||
# TODO: remove skip once deepspeed regression is fixed
|
||||
# see https://github.com/huggingface/transformers/pull/37324
|
||||
@pytest.mark.skipif(
|
||||
transformers_version_eq("4.51.0"),
|
||||
reason="zero3 is not supported with transformers==4.51.0",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
|
||||
Reference in New Issue
Block a user