* reduce test concurrency to avoid HF rate limiting, test suite parity * make val_set_size smaller to speed up e2e tests * more retries for pytest fixture downloads * val_set_size was too small * move retry_on_request_exceptions to data utils and add retry strategy * pre-download ultrafeedback as a test fixture * refactor download retry into its own fn * don't import from data utils * use retry mechanism now for fixtures
116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
"""
|
|
shared pytest fixtures
|
|
"""
|
|
import functools
|
|
import shutil
|
|
import tempfile
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
from huggingface_hub import snapshot_download
|
|
|
|
|
|
def retry_on_request_exceptions(max_retries=3, delay=1, exceptions=None):
    """Build a decorator that retries a call on transient network errors.

    The decorated function is attempted up to ``max_retries`` times in
    total. After every failed attempt except the last, the wrapper sleeps
    ``delay`` seconds and tries again. If the final attempt fails, its
    exception propagates unchanged.

    Args:
        max_retries: Total number of attempts; must be at least 1.
        delay: Seconds to sleep between consecutive attempts.
        exceptions: Exception type, or tuple of types, that triggers a
            retry. Defaults to ``requests.exceptions.ReadTimeout`` and
            ``requests.exceptions.ConnectionError``. Any other exception
            propagates immediately.

    Returns:
        A decorator that wraps a function with the retry behaviour.

    Raises:
        ValueError: If ``max_retries`` is less than 1. Otherwise the wrapped
            function would never be called and ``None`` would be returned.
    """
    # pylint: disable=duplicate-code
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if exceptions is None:
        exceptions = (
            requests.exceptions.ReadTimeout,
            requests.exceptions.ConnectionError,
        )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # All attempts but the last swallow a retryable error and wait.
            for _ in range(max_retries - 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    time.sleep(delay)
            # The final attempt runs outside the try block, so a failure
            # propagates with its original traceback.
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
@retry_on_request_exceptions(delay=5, max_retries=3)
def snapshot_download_w_retry(*download_args, **download_kwargs):
    """Run ``snapshot_download`` with up to three attempts, five seconds apart.

    All positional and keyword arguments are forwarded untouched, and the
    result of ``snapshot_download`` is handed back to the caller. Retries are
    driven by ``retry_on_request_exceptions`` (by default on ``requests``
    read timeouts and connection errors).
    """
    # NOTE(review): HTTP error responses (e.g. a 429 from Hub rate limiting)
    # presumably surface as exception types other than ReadTimeout /
    # ConnectionError, so they are not retried here. Confirm whether that
    # matters for the rate-limiting this helper was added for.
    return snapshot_download(*download_args, **download_kwargs)
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_smollm2_135m_model():
    """Fetch the ``HuggingFaceTB/SmolLM2-135M`` model snapshot once per session."""
    snapshot_download_w_retry(repo_id="HuggingFaceTB/SmolLM2-135M")
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_llama_68m_random_model():
    """Fetch the ``JackFram/llama-68m`` model snapshot once per session."""
    snapshot_download_w_retry(repo_id="JackFram/llama-68m")
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_qwen_2_5_half_billion_model():
    """Fetch the ``Qwen/Qwen2.5-0.5B`` model snapshot once per session."""
    snapshot_download_w_retry(repo_id="Qwen/Qwen2.5-0.5B")
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_tatsu_lab_alpaca_dataset():
    """Fetch the ``tatsu-lab/alpaca`` dataset snapshot once per session."""
    snapshot_download_w_retry(repo_id="tatsu-lab/alpaca", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_mhenrichsen_alpaca_2k_dataset():
    """Fetch the ``mhenrichsen/alpaca_2k_test`` dataset snapshot once per session."""
    snapshot_download_w_retry(
        repo_id="mhenrichsen/alpaca_2k_test",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
    """Fetch ``mhenrichsen/alpaca_2k_test`` pinned to revision ``d05c1cb``, once per session."""
    snapshot_download_w_retry(
        repo_id="mhenrichsen/alpaca_2k_test",
        repo_type="dataset",
        revision="d05c1cb",
    )
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_mlabonne_finetome_100k_dataset():
    """Fetch the ``mlabonne/FineTome-100k`` dataset snapshot once per session."""
    snapshot_download_w_retry(
        repo_id="mlabonne/FineTome-100k",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
    """Fetch ``argilla/distilabel-capybara-dpo-7k-binarized`` once per session."""
    snapshot_download_w_retry(
        repo_id="argilla/distilabel-capybara-dpo-7k-binarized",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
    """Fetch ``argilla/ultrafeedback-binarized-preferences-cleaned`` once per session."""
    snapshot_download_w_retry(
        repo_id="argilla/ultrafeedback-binarized-preferences-cleaned",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope="session")
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
    """Fetch ``arcee-ai/distilabel-intel-orca-dpo-pairs-binarized`` once per session."""
    snapshot_download_w_retry(
        repo_id="arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture
def temp_dir():
    """Yield the path of a fresh temporary directory that is removed after the test.

    Cleanup is delegated to ``tempfile.TemporaryDirectory``. Unlike a bare
    ``shutil.rmtree``, it does not error if the test already deleted the
    directory. It also tries to reset permissions on entries it cannot
    remove before giving up.
    """
    with tempfile.TemporaryDirectory() as _temp_dir:
        yield _temp_dir
|