Merge branch 'main' into telemetry-opt-in
This commit is contained in:
@@ -2,32 +2,38 @@
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import (
|
||||
enable_hf_offline,
|
||||
hf_offline_context,
|
||||
)
|
||||
|
||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
# pylint: disable=duplicate-code
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
def wrapper(*args, **kwargs):
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -162,7 +168,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
@@ -170,7 +176,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
@@ -350,7 +356,7 @@ def download_llama32_1b_model_fixture():
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
|
||||
@@ -361,7 +367,7 @@ def tokenizer_huggyllama(
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
@@ -377,7 +383,7 @@ def tokenizer_huggyllama_w_special_tokens(
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
@@ -387,7 +393,7 @@ def tokenizer_llama2_7b(
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@@ -409,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
def temp_dir() -> Generator[str, None, None]:
|
||||
# Create a temporary directory
|
||||
_temp_dir = tempfile.mkdtemp()
|
||||
yield _temp_dir
|
||||
@@ -417,6 +423,11 @@ def temp_dir():
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def torch_manual_seed():
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
@@ -428,9 +439,7 @@ def cleanup_monkeypatches():
|
||||
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_inner_training_loop = Trainer._inner_training_loop
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
@@ -438,9 +447,7 @@ def cleanup_monkeypatches():
|
||||
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer._inner_training_loop = original_trainer_inner_training_loop
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
|
||||
# Reset other known monkeypatches
|
||||
@@ -476,7 +483,7 @@ def cleanup_monkeypatches():
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
@@ -484,7 +491,7 @@ def dataset_winglian_tiny_shakespeare(
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -492,7 +499,7 @@ def dataset_tatsu_lab_alpaca(
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -500,7 +507,7 @@ def dataset_mhenrichsen_alpaca_2k_test(
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
@@ -511,7 +518,7 @@ def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -519,7 +526,7 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
@@ -527,7 +534,23 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
@pytest.fixture(name="min_base_cfg")
|
||||
def fixture_min_base_cfg():
|
||||
return DictDefault(
|
||||
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||
learning_rate=1e-3,
|
||||
datasets=[
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
micro_batch_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
reason="Not running in CI cache preload",
|
||||
|
||||
Reference in New Issue
Block a user