reduce test concurrency to avoid HF rate limiting, test suite parity (#2128)

* reduce test concurrency to avoid HF rate limiting, test suite parity

* make val_set_size smaller to speed up e2e tests

* more retries for pytest fixture downloads

* val_set_size was too small

* move retry_on_request_exceptions to data utils and add retry strategy

* pre-download ultrafeedback as a test fixture

* refactor download retry into it's own fn

* don't import from data utils

* use retry mechanism now for fixtures
This commit is contained in:
Wing Lian
2024-12-06 10:20:20 -05:00
committed by GitHub
parent 08fa133177
commit 5e9fa33f3d
12 changed files with 126 additions and 47 deletions

View File

@@ -23,9 +23,15 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
max-parallel: 2
matrix: matrix:
python_version: ["3.10", "3.11"] python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"] pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -55,6 +61,7 @@ jobs:
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging pip3 install --upgrade packaging
pip3 install -U -e . pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -45,9 +45,15 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
max-parallel: 2
matrix: matrix:
python_version: ["3.10", "3.11"] python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"] pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -95,6 +101,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
max-parallel: 1
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"] pytorch_version: ["2.4.1", "2.5.1"]
@@ -124,6 +131,8 @@ jobs:
pip3 show torch pip3 show torch
python3 setup.py sdist python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed - name: Ensure axolotl CLI was installed

View File

@@ -2,11 +2,9 @@
import functools import functools
import logging import logging
import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import requests
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -44,7 +42,11 @@ from axolotl.prompters import (
UnsupportedPrompter, UnsupportedPrompter,
) )
from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
@@ -55,27 +57,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None): def prepare_dataset(cfg, tokenizer, processor=None):
prompters = [] prompters = []

View File

@@ -1,13 +1,57 @@
"""data handling helpers""" """data handling helpers"""
import functools
import hashlib import hashlib
import logging import logging
import time
from enum import Enum
import huggingface_hub
import requests
from datasets import Dataset from datasets import Dataset
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""
CONSTANT = 1
LINEAR = 2
EXPONENTIAL = 3
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
if retry_strategy == RetryStrategy.EXPONENTIAL:
step_delay = delay * 2**attempt
elif retry_strategy == RetryStrategy.LINEAR:
step_delay = delay * (attempt + 1)
else:
step_delay = delay # Use constant delay.
time.sleep(step_delay)
else:
raise exc
return wrapper
return decorator
def md5(to_hash: str, encoding: str = "utf-8") -> str: def md5(to_hash: str, encoding: str = "utf-8") -> str:
try: try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()

View File

@@ -1,47 +1,77 @@
""" """
shared pytest fixtures shared pytest fixtures
""" """
import functools
import shutil import shutil
import tempfile import tempfile
import time
import pytest import pytest
import requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model(): def download_smollm2_135m_model():
# download the model # download the model
snapshot_download("HuggingFaceTB/SmolLM2-135M") snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model(): def download_llama_68m_random_model():
# download the model # download the model
snapshot_download("JackFram/llama-68m") snapshot_download_w_retry("JackFram/llama-68m")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model(): def download_qwen_2_5_half_billion_model():
# download the model # download the model
snapshot_download("Qwen/Qwen2.5-0.5B") snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset(): def download_tatsu_lab_alpaca_dataset():
# download the dataset # download the dataset
snapshot_download("tatsu-lab/alpaca", repo_type="dataset") snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset(): def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset # download the dataset
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset") snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset(): def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset # download the dataset
snapshot_download( snapshot_download_w_retry(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb" "mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
) )
@@ -49,21 +79,29 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset(): def download_mlabonne_finetome_100k_dataset():
# download the dataset # download the dataset
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset") snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@pytest.fixture @pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset(): def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset # download the dataset
snapshot_download( snapshot_download_w_retry(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset" "argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
) )
@pytest.fixture @pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset(): def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset # download the dataset
snapshot_download( snapshot_download_w_retry(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset" "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
) )

View File

@@ -42,7 +42,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.1, "val_set_size": 0.02,
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",
@@ -86,7 +86,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.02,
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mhenrichsen/alpaca_2k_test",

View File

@@ -40,7 +40,7 @@ class TestFalconPatched(unittest.TestCase):
"lora_dropout": 0.1, "lora_dropout": 0.1,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"], "lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"bos_token": "<|endoftext|>", "bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
@@ -80,7 +80,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"bos_token": "<|endoftext|>", "bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",

View File

@@ -38,7 +38,7 @@ class TestFusedLlama(unittest.TestCase):
"flash_attn_fuse_mlp": True, "flash_attn_fuse_mlp": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.1, "val_set_size": 0.02,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -98,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_alpha": 64, "lora_alpha": 64,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.02,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -39,7 +39,7 @@ class TestMistral(unittest.TestCase):
"lora_alpha": 64, "lora_alpha": 64,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",
@@ -80,7 +80,7 @@ class TestMistral(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -40,7 +40,7 @@ class TestMixtral(unittest.TestCase):
"lora_alpha": 32, "lora_alpha": 32,
"lora_dropout": 0.1, "lora_dropout": 0.1,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": {}, "special_tokens": {},
"datasets": [ "datasets": [
{ {
@@ -78,7 +78,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": {}, "special_tokens": {},
"datasets": [ "datasets": [
{ {

View File

@@ -38,7 +38,7 @@ class TestPhiMultipack(unittest.TestCase):
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"load_in_8bit": False, "load_in_8bit": False,
"adapter": None, "adapter": None,
"val_set_size": 0.1, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },