def retry_on_request_exceptions(max_retries=3, delay=1):
    """Build a decorator that retries a call on transient ``requests`` errors.

    Only ``requests.exceptions.ReadTimeout`` and
    ``requests.exceptions.ConnectionError`` trigger a retry; any other
    exception propagates immediately from the first failing attempt.

    Args:
        max_retries: Total number of attempts (not additional retries).
            Values below 1 are treated as 1 so the wrapped function is
            always invoked at least once instead of silently returning
            ``None`` without ever being called.
        delay: Seconds to sleep between two consecutive attempts.

    Returns:
        A decorator that wraps a callable with the retry loop.

    Raises:
        requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError:
            re-raised from the final attempt once all attempts are used up.
    """
    # Guarantee at least one call even for max_retries <= 0.
    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                ) as exc:
                    if attempt == attempts:
                        # Out of attempts: bare raise re-raises the active
                        # exception without rebinding it.
                        raise
                    # Make the retry visible; otherwise the sleep looks
                    # like an unexplained hang.
                    logging.getLogger("axolotl").warning(
                        "%s failed with %s (attempt %d/%d); retrying in %ss",
                        getattr(func, "__name__", repr(func)),
                        type(exc).__name__,
                        attempt,
                        attempts,
                        delay,
                    )
                    time.sleep(delay)

        return wrapper

    return decorator