"""Shared pytest fixtures"""

import functools
import importlib
import importlib.util
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Generator

import datasets
import pytest
import requests
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import (
    enable_hf_offline,
    hf_offline_context,
)

# Only let filelock log at CRITICAL; presumably to keep HF-cache lock chatter
# out of the test output.
logging.getLogger("filelock").setLevel(logging.CRITICAL)


def retry_on_request_exceptions(max_retries=3, delay=1):
    """Build a decorator that retries a call on transient ``requests`` errors.

    The wrapped function is attempted up to ``max_retries`` times, and always
    at least once (a non-positive ``max_retries`` previously meant the
    function was never called and ``None`` was returned silently). After the
    failed attempt ``i`` (0-based) it sleeps ``2**i * delay`` seconds before
    retrying. The final failure is re-raised unchanged.

    Only ``requests`` ``ReadTimeout``, ``ConnectionError`` and ``HTTPError``
    are retried; any other exception propagates immediately.
    """
    attempts = max(1, max_retries)
    retryable = (
        requests.exceptions.ReadTimeout,
        requests.exceptions.ConnectionError,
        requests.exceptions.HTTPError,
    )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except retryable:
                    if attempt == attempts - 1:
                        raise  # out of attempts: surface the original error
                    time.sleep(2**attempt * delay)  # exponential backoff, in seconds

        return wrapper

    return decorator


@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
    """
    Download a model or dataset from the HF Hub, retrying on request failures.

    We first try to resolve it from the local cache with HF offline mode
    enabled (``local_files_only=True``) to avoid hitting HF Hub API rate
    limits. If it is not in the cache, offline mode is disabled and the
    snapshot is actually fetched from the hub.

    All arguments are forwarded to ``huggingface_hub.snapshot_download``; the
    return value is whatever it returns (the local snapshot path).
    """
    with hf_offline_context(True):
        try:
            return snapshot_download(*args, local_files_only=True, **kwargs)
        except LocalEntryNotFoundError:
            pass  # cache miss: fall through to a real download

    with hf_offline_context(False):
        return snapshot_download(*args, **kwargs)


@pytest.fixture(scope="session", autouse=True)
def download_ds_fixture_bundle():
    """Fetch the internal dataset-fixture bundle and return its local path."""
    ds_dir = snapshot_download_w_retry(
        "axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
    )
    return Path(ds_dir)


@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
    """Cache the full HuggingFaceTB/SmolLM2-135M model."""
    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_instruct_model():
    """Cache the full HuggingFaceTB/SmolLM2-135M-Instruct model."""
    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M-Instruct", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_gptq_model():
    """Cache the full lilmeaty/SmolLM2-135M-Instruct-GPTQ model."""
    snapshot_download_w_retry("lilmeaty/SmolLM2-135M-Instruct-GPTQ", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
    """Cache the full Qwen/Qwen2.5-0.5B model."""
    snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
    """Cache the full Qwen/Qwen3-0.6B model."""
    snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
    """Cache the tatsu-lab/alpaca dataset."""
    snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
    """Cache the mhenrichsen/alpaca_2k_test dataset (default revision)."""
    snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
    """Cache the mhenrichsen/alpaca_2k_test dataset pinned at revision d05c1cb."""
    snapshot_download_w_retry(
        "mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
    )


@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
    """Cache the mlabonne/FineTome-100k dataset."""
    snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
    """Cache the argilla/distilabel-capybara-dpo-7k-binarized dataset."""
    snapshot_download_w_retry(
        "argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_intel_orca_dpo_dataset():
    """Cache the argilla/distilabel-intel-orca-dpo-pairs dataset."""
    snapshot_download_w_retry(
        "argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
    """Cache the argilla/ultrafeedback-binarized-preferences-cleaned dataset."""
    snapshot_download_w_retry(
        "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset():
    """Cache the argilla/ultrafeedback-binarized-preferences-cleaned-kto dataset."""
    snapshot_download_w_retry(
        "argilla/ultrafeedback-binarized-preferences-cleaned-kto", repo_type="dataset"
    )


# @pytest.fixture(scope="session", autouse=True)
# def download_fozzie_alpaca_dpo_dataset():
#     # download the dataset
#     snapshot_download_w_retry(
#         "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
#     )
#     snapshot_download_w_retry(
#         "fozziethebeat/alpaca_messages_2k_dpo_test",
#         repo_type="dataset",
#         revision="ea82cff",
#     )


# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset(
#     download_fozzie_alpaca_dpo_dataset,
# ):
#     return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
#
#
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
#     download_fozzie_alpaca_dpo_dataset,
# ):
#     return load_dataset(
#         "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
#     )


@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
    """Cache the arcee-ai/distilabel-intel-orca-dpo-pairs-binarized dataset."""
    snapshot_download_w_retry(
        "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
    """Cache argilla/distilabel-intel-orca-dpo-pairs.

    NOTE(review): same repo as
    ``download_argilla_distilabel_intel_orca_dpo_dataset``; kept because tests
    may request either fixture name.
    """
    snapshot_download_w_retry(
        "argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset():
    """Cache the winglian/tiny-shakespeare dataset."""
    snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
    """Cache the evolkit logprobs KD sample dataset."""
    snapshot_download_w_retry(
        "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
    )


@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
    """Cache the full axolotl-ai-co/DeepSeek-V3-11M model."""
    snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
    """Cache only the tokenizer files and config of huggyllama/llama-7b."""
    snapshot_download_w_retry(
        "huggyllama/llama-7b",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
    """Cache only the tokenizer files and config of the Llama-3.3-70B tokenizer repo."""
    snapshot_download_w_retry(
        "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
    """Cache only the tokenizer files and config of NousResearch/Llama-3.2-1B."""
    snapshot_download_w_retry(
        "NousResearch/Llama-3.2-1B",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
    """Cache only the tokenizer files and config of NousResearch/Meta-Llama-3-8B."""
    snapshot_download_w_retry(
        "NousResearch/Meta-Llama-3-8B",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
    """Cache only the tokenizer files and config of NousResearch/Meta-Llama-3-8B-Instruct."""
    snapshot_download_w_retry(
        "NousResearch/Meta-Llama-3-8B-Instruct",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
    """Cache only the tokenizer files and config of microsoft/Phi-3.5-mini-instruct."""
    snapshot_download_w_retry(
        "microsoft/Phi-3.5-mini-instruct",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_phi_4_reasoning_model_fixture():
    """Cache only the tokenizer files and config of microsoft/Phi-4-reasoning."""
    snapshot_download_w_retry(
        "microsoft/Phi-4-reasoning",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
    """Cache only the tokenizer files and config of microsoft/Phi-3-medium-128k-instruct."""
    snapshot_download_w_retry(
        "microsoft/Phi-3-medium-128k-instruct",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
    """Cache only the tokenizer files and config of casperhansen/mistral-7b-instruct-v0.1-awq."""
    snapshot_download_w_retry(
        "casperhansen/mistral-7b-instruct-v0.1-awq",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_gemma3_4b_model_fixture():
    """Cache only the tokenizer files and config of mlx-community/gemma-3-4b-it-8bit."""
    snapshot_download_w_retry(
        "mlx-community/gemma-3-4b-it-8bit",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
    """Cache only the tokenizer files and config of unsloth/gemma-2b-it at revision 703fb4a."""
    snapshot_download_w_retry(
        "unsloth/gemma-2b-it",
        revision="703fb4a",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
    """Cache only the tokenizer files and config of mlx-community/gemma-2-9b-it-4bit."""
    snapshot_download_w_retry(
        "mlx-community/gemma-2-9b-it-4bit",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
    """Cache only the tokenizer files and config of mlx-community/Mistral-7B-Instruct-v0.3-4bit."""
    snapshot_download_w_retry(
        "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


# Session-scoped like every other download fixture so the cache lookup runs
# once per session instead of once per requesting test. Still not autouse, so
# it is only resolved when a test (or tokenizer_llama2_7b) asks for it.
@pytest.fixture(scope="session")
def download_llama2_model_fixture():
    """Cache only the tokenizer files and config of NousResearch/Llama-2-7b-hf."""
    snapshot_download_w_retry(
        "NousResearch/Llama-2-7b-hf",
        repo_type="model",
        allow_patterns=["*token*", "config.json"],
    )


@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
    """Cache the full osllmai-community/Llama-3.2-1B model."""
    snapshot_download_w_retry(
        "osllmai-community/Llama-3.2-1B",
        repo_type="model",
    )


@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
    download_huggyllama_model_fixture,
):
    """huggyllama/llama-7b tokenizer, loaded offline, with ``</s>`` as pad token."""
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    # An empty-string pad token is not a usable token; reuse the LLaMA EOS.
    tokenizer.pad_token = "</s>"
    return tokenizer


@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama_w_special_tokens(
    tokenizer_huggyllama,
):
    """huggyllama tokenizer with the standard LLaMA bos/eos/unk specials registered."""
    tokenizer_huggyllama.add_special_tokens(
        {
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
        }
    )
    return tokenizer_huggyllama


@pytest.fixture
@enable_hf_offline
def tokenizer_llama2_7b(
    download_llama2_model_fixture,
):
    """NousResearch/Llama-2-7b-hf tokenizer, loaded offline."""
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
    return tokenizer


@pytest.fixture
@enable_hf_offline
def tokenizer_mistral_7b_instruct(
    download_mistral_7b_model_fixture,
):
    """casperhansen/mistral-7b-instruct-v0.1-awq tokenizer, loaded offline.

    Depends on ``download_mistral_7b_model_fixture``, which is the fixture that
    actually caches this repo (it previously pointed at the mlx-community
    Mistral fixture, a different repo).
    """
    return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")


@pytest.fixture
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
    """Mistral tokenizer with ChatML markers: ``<|im_end|>`` as EOS plus ``<|im_start|>``."""
    tokenizer_mistral_7b_instruct.add_special_tokens(
        {
            "eos_token": AddedToken(
                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
            )
        }
    )
    tokenizer_mistral_7b_instruct.add_tokens(
        [
            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
        ]
    )
    return tokenizer_mistral_7b_instruct


@pytest.fixture
def temp_dir() -> Generator[str, None, None]:
    """Yield a fresh temporary directory path and delete it after the test.

    Cleanup ignores errors so a test that already removed (or partially
    removed) the directory does not turn into a teardown error.
    """
    _temp_dir = tempfile.mkdtemp()
    yield _temp_dir
    shutil.rmtree(_temp_dir, ignore_errors=True)


@pytest.fixture(scope="function", autouse=True)
def torch_manual_seed():
    """Seed torch's RNG with 42 before every test for reproducibility."""
    torch.manual_seed(42)


@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
    """Undo monkeypatches that a test may apply to transformers.

    Before the test, the original ``forward`` / training-loop methods of a few
    transformers classes are saved; afterwards they are restored and a set of
    transformers modules is re-executed and reloaded so module-level patches
    are discarded too.
    """
    from transformers import Trainer
    from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
        LlamaAttention,
        LlamaForCausalLM,
    )

    # original_fa2_forward = LlamaFlashAttention2.forward
    original_llama_attn_forward = LlamaAttention.forward
    original_llama_forward = LlamaForCausalLM.forward
    original_trainer_inner_training_loop = Trainer._inner_training_loop
    original_trainer_training_step = Trainer.training_step

    # monkey patches can happen inside the tests
    yield

    # Reset LlamaFlashAttention2 forward
    # LlamaFlashAttention2.forward = original_fa2_forward
    LlamaAttention.forward = original_llama_attn_forward
    LlamaForCausalLM.forward = original_llama_forward
    Trainer._inner_training_loop = original_trainer_inner_training_loop
    Trainer.training_step = original_trainer_training_step

    # Reset other known monkeypatches. Each entry is (module_name,) or
    # (module_name, [global names to drop afterwards]).
    modules_to_reset: list[tuple[str] | tuple[str, list[str]]] = [
        ("transformers.models.llama",),
        (
            "transformers.models.llama.modeling_llama",
            [
                # "LlamaFlashAttention2",
                "LlamaAttention",
            ],
        ),
        ("transformers.trainer",),
        ("transformers", ["Trainer"]),
        ("transformers.loss.loss_utils",),
    ]
    for module_name_tuple in modules_to_reset:
        module_name = module_name_tuple[0]
        # Re-create the module from its source file, then reload it, so that
        # attributes patched at module level are replaced by fresh definitions.
        spec = importlib.util.spec_from_file_location(
            module_name, sys.modules[module_name].__file__
        )
        sys.modules[module_name] = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(sys.modules[module_name])
        sys.modules[module_name] = importlib.reload(sys.modules[module_name])
        if len(module_name_tuple) > 1:
            module_globals = module_name_tuple[1]
            for module_global in module_globals:
                # NOTE(review): this pops from *this conftest's* globals, where
                # these names are not defined (they are imported locally above),
                # so it appears to be a no-op — confirm intent.
                globals().pop(module_global, None)


@pytest.fixture
def dataset_winglian_tiny_shakespeare(
    download_ds_fixture_bundle: Path,
):
    """tiny-shakespeare loaded from the fixture bundle (the whole saved object, no split selected)."""
    ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
    return datasets.load_from_disk(ds_path)


@pytest.fixture
def dataset_tatsu_lab_alpaca(
    download_ds_fixture_bundle: Path,
):
    """``train`` split of tatsu-lab/alpaca from the fixture bundle."""
    ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
    return datasets.load_from_disk(ds_path)["train"]


@pytest.fixture
def dataset_mhenrichsen_alpaca_2k_test(
    download_ds_fixture_bundle: Path,
):
    """``train`` split of mhenrichsen/alpaca_2k_test from the fixture bundle."""
    ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
    return datasets.load_from_disk(ds_path)["train"]


@pytest.fixture
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
    download_ds_fixture_bundle: Path,
):
    """``train`` split of argilla/ultrafeedback-binarized-preferences-cleaned from the bundle."""
    ds_path = (
        download_ds_fixture_bundle
        / "argilla__ultrafeedback-binarized-preferences-cleaned"
    )
    return datasets.load_from_disk(ds_path)["train"]


@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
    download_ds_fixture_bundle: Path,
):
    """``train`` split of fozziethebeat/alpaca_messages_2k_dpo_test from the bundle."""
    ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
    return datasets.load_from_disk(ds_path)["train"]


@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
    download_ds_fixture_bundle: Path,
):
    """``train`` split of fozziethebeat/alpaca_messages_2k_dpo_test at rev ea82cff from the bundle."""
    ds_path = (
        download_ds_fixture_bundle
        / "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
    )
    return datasets.load_from_disk(ds_path)["train"]


@pytest.fixture(name="min_base_cfg")
def fixture_min_base_cfg():
    """Minimal SmolLM2-135M + alpaca_2k_test config shared by many tests."""
    return DictDefault(
        base_model="HuggingFaceTB/SmolLM2-135M",
        learning_rate=1e-3,
        datasets=[
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
        micro_batch_size=1,
        gradient_accumulation_steps=1,
    )


@pytest.mark.skipif(
    os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
    reason="Not running in CI cache preload",
)
def test_load_fixtures(
    download_smollm2_135m_model,
    download_qwen_2_5_half_billion_model,
    download_tatsu_lab_alpaca_dataset,
    download_mhenrichsen_alpaca_2k_dataset,
    download_mhenrichsen_alpaca_2k_w_revision_dataset,
    download_mlabonne_finetome_100k_dataset,
    download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
    download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset,
    download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
    download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
    download_argilla_dpo_pairs_dataset,
    download_tiny_shakespeare_dataset,
    download_deepseek_model_fixture,
    download_huggyllama_model_fixture,
    download_llama_1b_model_fixture,
    download_llama3_8b_model_fixture,
    download_llama3_8b_instruct_model_fixture,
    download_phi_35_mini_model_fixture,
    download_phi_3_medium_model_fixture,
    download_phi_4_reasoning_model_fixture,
    download_mistral_7b_model_fixture,
    download_gemma_2b_model_fixture,
    download_gemma2_9b_model_fixture,
    download_mlx_mistral_7b_model_fixture,
    download_llama2_model_fixture,
):
    """Warm the HF cache in CI (only runs when AXOLOTL_IS_CI_CACHE_PRELOAD=1).

    Requesting the fixtures is the whole point; autouse session fixtures not
    listed here are instantiated anyway, while non-autouse ones (e.g.
    ``download_llama2_model_fixture``) must be listed explicitly.
    """


@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
    """Set AXOLOTL_DO_NOT_TRACK=1 for every test (reverted by monkeypatch afterwards)."""
    monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
    yield