reduce test concurrency to avoid HF rate limiting, test suite parity (#2128)

* reduce test concurrency to avoid HF rate limiting, test suite parity

* make val_set_size smaller to speed up e2e tests

* more retries for pytest fixture downloads

* val_set_size was too small

* move retry_on_request_exceptions to data utils and add retry strategy

* pre-download ultrafeedback as a test fixture

* refactor download retry into it's own fn

* don't import from data utils

* use retry mechanism now for fixtures
This commit is contained in:
Wing Lian
2024-12-06 10:20:20 -05:00
committed by GitHub
parent 08fa133177
commit 5e9fa33f3d
12 changed files with 126 additions and 47 deletions

View File

@@ -2,11 +2,9 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -44,7 +42,11 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -55,27 +57,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []

View File

@@ -1,13 +1,57 @@
"""data handling helpers"""
import functools
import hashlib
import logging
import time
from enum import Enum
import huggingface_hub
import requests
from datasets import Dataset
LOG = logging.getLogger("axolotl")
class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""
CONSTANT = 1
LINEAR = 2
EXPONENTIAL = 3
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
if retry_strategy == RetryStrategy.EXPONENTIAL:
step_delay = delay * 2**attempt
elif retry_strategy == RetryStrategy.LINEAR:
step_delay = delay * (attempt + 1)
else:
step_delay = delay # Use constant delay.
time.sleep(step_delay)
else:
raise exc
return wrapper
return decorator
def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()