diff --git a/tests/conftest.py b/tests/conftest.py
index 4d05d3a26..c71ea1e8c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,10 +14,15 @@ import datasets
 import pytest
 import requests
 from huggingface_hub import snapshot_download
+from huggingface_hub.errors import LocalEntryNotFoundError
 from tokenizers import AddedToken
 from transformers import AutoTokenizer
 
-from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
+from tests.hf_offline_utils import (
+    disable_hf_offline,
+    enable_hf_offline,
+    hf_offline_context,
+)
 
 
 def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -47,6 +52,16 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
 @retry_on_request_exceptions(max_retries=3, delay=5)
 @disable_hf_offline
 def snapshot_download_w_retry(*args, **kwargs):
+    """
+    Download a model or dataset from HF Hub, retrying on request failures. We also try to fetch it from the local
+    cache first using hf_hub_offline to avoid hitting HF Hub API rate limits. If it doesn't exist in the cache,
+    disable hf_hub_offline and actually fetch from the hub
+    """
+    with hf_offline_context(True):
+        try:
+            return snapshot_download(*args, **kwargs)
+        except LocalEntryNotFoundError:
+            pass
     return snapshot_download(*args, **kwargs)
 
 
diff --git a/tests/hf_offline_utils.py b/tests/hf_offline_utils.py
index 0ce878577..0c7b5d4a4 100644
--- a/tests/hf_offline_utils.py
+++ b/tests/hf_offline_utils.py
@@ -3,6 +3,7 @@ test utils for helpers and decorators
 """
 
 import os
+from contextlib import contextmanager
 from functools import wraps
 
 from huggingface_hub.utils import reset_sessions
@@ -83,3 +84,38 @@ def disable_hf_offline(test_func):
         reload_modules(False)
 
     return wrapper
+
+
+def _is_true_env_value(value):
+    """
+    Interpret an environment-variable string as a boolean flag.
+    :param value: Raw value from the environment, or None when unset.
+    :return: True only for "1", "ON", "YES" or "TRUE" (case-insensitive).
+    """
+    return value is not None and value.strip().upper() in {"1", "ON", "YES", "TRUE"}
+
+
+@contextmanager
+def hf_offline_context(hf_hub_offline):
+    """
+    Context manager that temporarily sets the HF_HUB_OFFLINE environment variable.
+    The previous value (or its absence) is restored on exit, even when the body raises.
+    :param hf_hub_offline: The new value for HF_HUB_OFFLINE; stored as str(hf_hub_offline).
+    :return: A context manager.
+    """
+    original_hf_offline = os.getenv("HF_HUB_OFFLINE")
+    new_hf_offline = str(hf_hub_offline)
+    os.environ["HF_HUB_OFFLINE"] = new_hf_offline
+    reload_modules(_is_true_env_value(new_hf_offline))
+    try:
+        yield
+    finally:
+        # Restore the original value of HF_HUB_OFFLINE environment variable
+        if original_hf_offline is not None:
+            os.environ["HF_HUB_OFFLINE"] = original_hf_offline
+            # bool("0") is True, so parse the string instead of casting it
+            reload_modules(_is_true_env_value(original_hf_offline))
+        else:
+            # pop() tolerates the body having already removed the variable
+            os.environ.pop("HF_HUB_OFFLINE", None)
+            reload_modules(False)