590 lines
18 KiB
Python
590 lines
18 KiB
Python
"""
|
|
shared pytest fixtures
|
|
"""
|
|
|
|
import functools
|
|
import importlib
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
|
|
import datasets
|
|
import pytest
|
|
import requests
|
|
import torch
|
|
from huggingface_hub import snapshot_download
|
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
from tokenizers import AddedToken
|
|
from transformers import AutoTokenizer
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
from tests.hf_offline_utils import (
|
|
enable_hf_offline,
|
|
hf_offline_context,
|
|
)
|
|
|
|
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
|
|
|
|
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
|
# pylint: disable=duplicate-code
|
|
def decorator(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
|
for attempt in range(max_retries):
|
|
try:
|
|
return func(*args, **kwargs)
|
|
except (
|
|
requests.exceptions.ReadTimeout,
|
|
requests.exceptions.ConnectionError,
|
|
requests.exceptions.HTTPError,
|
|
) as exc:
|
|
if attempt < max_retries - 1:
|
|
wait = 2**attempt * delay # in seconds
|
|
time.sleep(wait)
|
|
else:
|
|
raise exc
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
|
def snapshot_download_w_retry(*args, **kwargs):
|
|
"""
|
|
download a model or dataset from HF Hub, retrying in requests failures. We also try to fetch it from the local
|
|
cache first using hf_hub_offline to avoid hitting HF Hub API rate limits. If it doesn't exist in the cache,
|
|
disable hf_hub_offline and actually fetch from the hub
|
|
"""
|
|
with hf_offline_context(True):
|
|
try:
|
|
return snapshot_download(*args, **kwargs)
|
|
except LocalEntryNotFoundError:
|
|
pass
|
|
with hf_offline_context(False):
|
|
return snapshot_download(*args, **kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_ds_fixture_bundle():
|
|
ds_dir = snapshot_download_w_retry(
|
|
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
|
)
|
|
return Path(ds_dir)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_smollm2_135m_model():
|
|
# download the model
|
|
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_smollm2_135m_gptq_model():
|
|
# download the model
|
|
snapshot_download_w_retry("lilmeaty/SmolLM2-135M-Instruct-GPTQ", repo_type="model")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_qwen_2_5_half_billion_model():
|
|
# download the model
|
|
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_qwen3_half_billion_model():
|
|
# download the model
|
|
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_tatsu_lab_alpaca_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_mhenrichsen_alpaca_2k_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_mlabonne_finetome_100k_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_argilla_distilabel_intel_orca_dpo_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
|
)
|
|
|
|
|
|
# @pytest.fixture(scope="session", autouse=True)
|
|
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
|
# # download the dataset
|
|
# snapshot_download_w_retry(
|
|
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
|
# )
|
|
|
|
|
|
# @pytest.fixture(scope="session", autouse=True)
|
|
# def download_fozzie_alpaca_dpo_dataset():
|
|
# # download the dataset
|
|
# snapshot_download_w_retry(
|
|
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
|
# )
|
|
# snapshot_download_w_retry(
|
|
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
|
# repo_type="dataset",
|
|
# revision="ea82cff",
|
|
# )
|
|
|
|
|
|
# @pytest.fixture(scope="session")
|
|
# @disable_hf_offline
|
|
# def dataset_fozzie_alpaca_dpo_dataset(
|
|
# download_fozzie_alpaca_dpo_dataset,
|
|
# ): # pylint: disable=unused-argument,redefined-outer-name
|
|
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
|
#
|
|
#
|
|
# @pytest.fixture(scope="session")
|
|
# @disable_hf_offline
|
|
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
|
# download_fozzie_alpaca_dpo_dataset,
|
|
# ): # pylint: disable=unused-argument,redefined-outer-name
|
|
# return load_dataset(
|
|
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
|
# )
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_argilla_dpo_pairs_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_tiny_shakespeare_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_evolkit_kd_sample_dataset():
|
|
# download the dataset
|
|
snapshot_download_w_retry(
|
|
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_deepseek_model_fixture():
|
|
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_huggyllama_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"huggyllama/llama-7b",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_llama33_70b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_llama_1b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"NousResearch/Llama-3.2-1B",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_llama3_8b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_llama3_8b_instruct_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"NousResearch/Meta-Llama-3-8B-Instruct",
|
|
repo_type="model",
|
|
allow_patterns=["*token*"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_phi_35_mini_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_phi_3_medium_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"microsoft/Phi-3-medium-128k-instruct",
|
|
repo_type="model",
|
|
allow_patterns=["*token*"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_mistral_7b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"casperhansen/mistral-7b-instruct-v0.1-awq",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_gemma3_4b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"mlx-community/gemma-3-4b-it-8bit",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_gemma_2b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"unsloth/gemma-2b-it",
|
|
revision="703fb4a",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_gemma2_9b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"mlx-community/gemma-2-9b-it-4bit",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_mlx_mistral_7b_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def download_llama2_model_fixture():
|
|
# download the tokenizer only
|
|
snapshot_download_w_retry(
|
|
"NousResearch/Llama-2-7b-hf",
|
|
repo_type="model",
|
|
allow_patterns=["*token*", "config.json"],
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
|
|
def download_llama32_1b_model_fixture():
|
|
snapshot_download_w_retry(
|
|
"osllmai-community/Llama-3.2-1B",
|
|
repo_type="model",
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
@enable_hf_offline
|
|
def tokenizer_huggyllama(
|
|
download_huggyllama_model_fixture,
|
|
): # pylint: disable=unused-argument,redefined-outer-name
|
|
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
tokenizer.pad_token = "</s>"
|
|
|
|
return tokenizer
|
|
|
|
|
|
@pytest.fixture
|
|
@enable_hf_offline
|
|
def tokenizer_huggyllama_w_special_tokens(
|
|
tokenizer_huggyllama,
|
|
): # pylint: disable=redefined-outer-name
|
|
tokenizer_huggyllama.add_special_tokens(
|
|
{
|
|
"bos_token": "<s>",
|
|
"eos_token": "</s>",
|
|
"unk_token": "<unk>",
|
|
}
|
|
)
|
|
|
|
return tokenizer_huggyllama
|
|
|
|
|
|
@pytest.fixture
|
|
@enable_hf_offline
|
|
def tokenizer_llama2_7b(
|
|
download_llama2_model_fixture,
|
|
): # pylint: disable=unused-argument,redefined-outer-name
|
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
|
|
|
return tokenizer
|
|
|
|
|
|
@pytest.fixture
|
|
@enable_hf_offline
|
|
def tokenizer_mistral_7b_instruct(
|
|
download_mlx_mistral_7b_model_fixture,
|
|
): # pylint: disable=unused-argument,redefined-outer-name
|
|
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
|
|
|
|
|
@pytest.fixture
|
|
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
|
tokenizer_mistral_7b_instruct.add_special_tokens(
|
|
{
|
|
"eos_token": AddedToken(
|
|
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
|
)
|
|
}
|
|
)
|
|
tokenizer_mistral_7b_instruct.add_tokens(
|
|
[
|
|
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
|
]
|
|
)
|
|
return tokenizer_mistral_7b_instruct
|
|
|
|
|
|
@pytest.fixture
|
|
def temp_dir() -> Generator[str, None, None]:
|
|
# Create a temporary directory
|
|
_temp_dir = tempfile.mkdtemp()
|
|
yield _temp_dir
|
|
# Clean up the directory after the test
|
|
shutil.rmtree(_temp_dir)
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
|
|
def torch_manual_seed():
|
|
torch.manual_seed(42)
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
|
|
def cleanup_monkeypatches():
|
|
from transformers import Trainer
|
|
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
|
LlamaAttention,
|
|
LlamaForCausalLM,
|
|
)
|
|
|
|
# original_fa2_forward = LlamaFlashAttention2.forward
|
|
original_llama_attn_forward = LlamaAttention.forward
|
|
original_llama_forward = LlamaForCausalLM.forward
|
|
original_trainer_inner_training_loop = (
|
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
|
)
|
|
original_trainer_training_step = Trainer.training_step
|
|
# monkey patches can happen inside the tests
|
|
yield
|
|
# Reset LlamaFlashAttention2 forward
|
|
# LlamaFlashAttention2.forward = original_fa2_forward
|
|
LlamaAttention.forward = original_llama_attn_forward
|
|
LlamaForCausalLM.forward = original_llama_forward
|
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
|
original_trainer_inner_training_loop
|
|
)
|
|
Trainer.training_step = original_trainer_training_step
|
|
|
|
# Reset other known monkeypatches
|
|
modules_to_reset: list[tuple[str, list[str]]] = [
|
|
("transformers.models.llama",),
|
|
(
|
|
"transformers.models.llama.modeling_llama",
|
|
[
|
|
# "LlamaFlashAttention2",
|
|
"LlamaAttention",
|
|
],
|
|
),
|
|
("transformers.trainer",),
|
|
("transformers", ["Trainer"]),
|
|
("transformers.loss.loss_utils",),
|
|
]
|
|
for module_name_tuple in modules_to_reset:
|
|
module_name = module_name_tuple[0]
|
|
|
|
spec = importlib.util.spec_from_file_location(
|
|
module_name, sys.modules[module_name].__file__
|
|
)
|
|
sys.modules[module_name] = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(sys.modules[module_name])
|
|
|
|
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
|
|
if len(module_name_tuple) > 1:
|
|
module_globals = module_name_tuple[1]
|
|
for module_global in module_globals:
|
|
globals().pop(module_global, None)
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_winglian_tiny_shakespeare(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
|
return datasets.load_from_disk(ds_path)
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_tatsu_lab_alpaca(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
|
return datasets.load_from_disk(ds_path)["train"]
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_mhenrichsen_alpaca_2k_test(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
|
return datasets.load_from_disk(ds_path)["train"]
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = (
|
|
download_ds_fixture_bundle
|
|
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
|
)
|
|
return datasets.load_from_disk(ds_path)["train"]
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
|
return datasets.load_from_disk(ds_path)["train"]
|
|
|
|
|
|
@pytest.fixture
|
|
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
|
download_ds_fixture_bundle: Path,
|
|
): # pylint: disable=redefined-outer-name
|
|
ds_path = (
|
|
download_ds_fixture_bundle
|
|
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
|
)
|
|
return datasets.load_from_disk(ds_path)["train"]
|
|
|
|
|
|
@pytest.fixture(name="min_base_cfg")
|
|
def fixture_min_base_cfg():
|
|
return DictDefault(
|
|
base_model="HuggingFaceTB/SmolLM2-135M",
|
|
learning_rate=1e-3,
|
|
datasets=[
|
|
{
|
|
"path": "mhenrichsen/alpaca_2k_test",
|
|
"type": "alpaca",
|
|
},
|
|
],
|
|
micro_batch_size=1,
|
|
gradient_accumulation_steps=1,
|
|
)
|
|
|
|
|
|
# # pylint: disable=redefined-outer-name,unused-argument
|
|
@pytest.mark.skipif(
|
|
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
|
reason="Not running in CI cache preload",
|
|
)
|
|
def test_load_fixtures(
|
|
download_smollm2_135m_model,
|
|
download_qwen_2_5_half_billion_model,
|
|
download_tatsu_lab_alpaca_dataset,
|
|
download_mhenrichsen_alpaca_2k_dataset,
|
|
download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
|
download_mlabonne_finetome_100k_dataset,
|
|
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
|
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
|
download_argilla_dpo_pairs_dataset,
|
|
download_tiny_shakespeare_dataset,
|
|
download_deepseek_model_fixture,
|
|
download_huggyllama_model_fixture,
|
|
download_llama_1b_model_fixture,
|
|
download_llama3_8b_model_fixture,
|
|
download_llama3_8b_instruct_model_fixture,
|
|
download_phi_35_mini_model_fixture,
|
|
download_phi_3_medium_model_fixture,
|
|
download_mistral_7b_model_fixture,
|
|
download_gemma_2b_model_fixture,
|
|
download_gemma2_9b_model_fixture,
|
|
download_mlx_mistral_7b_model_fixture,
|
|
download_llama2_model_fixture,
|
|
):
|
|
pass
|