add retries for load datasets requests failures (#2007)

This commit is contained in:
Wing Lian
2024-10-31 13:26:14 -04:00
committed by GitHub
parent d4dbfa02fe
commit dc1de7d81b

View File

@@ -2,9 +2,11 @@
import functools import functools
import logging import logging
import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import requests
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None): def prepare_dataset(cfg, tokenizer, processor=None):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset: