Files
axolotl/tests/conftest.py
Wing Lian b6fc46ada8 Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]
* add grpo scale_rewards config for trl#3135

* options to connect to vllm server directly w grpo trl#3094

* temperature support trl#3029

* sampling/generation kwargs for grpo trl#2989

* make vllm_enable_prefix_caching a config param trl#2900

* grpo multi-step optimizeations trl#2899

* remove overrides for grpo trainer

* bump trl to 0.16.0

* add cli  to start vllm-serve via trl

* call the python module directly

* update to use vllm with 2.6.0 too now and call trl vllm serve from module

* vllm 0.8.1

* use python3

* use sys.executable

* remove context and wait for start

* fixes to make it actually work

* fixes so the grpo tests pass with new vllm paradigm

* explicit host/port and check in start vllm

* make sure that vllm doesn't hang by setting quiet so outouts go to dev null

* also bump bnb to latest release

* add option for wait from cli and nccl debugging for ci

* grpo + vllm test on separate devices for now

* make sure grpo + vllm tests runs single worker since pynccl comms would conflict

* fix cli

* remove wait and add caching for argilla dataset

* refactoring configs

* chore: lint

* add vllm config

* fixup vllm grpo args

* fix one more incorrect schema/config path

* fix another vlllm reference and increase timeout

* make the tests run a bit faster

* change mbsz back so it is correct for grpo

* another change mbsz back so it is correct for grpo

* fixing cli args

* nits

* adding docs

* docs

* include tensor parallel size for vllm in pydantic schema

* moving start_vllm, more docs

* limit output len for grpo vllm

* vllm enable_prefix_caching isn't a bool cli arg

* fix env ordering in tests and also use pid check when looking for vllm

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-03-31 15:47:11 -04:00

389 lines
12 KiB
Python

"""
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time
import pytest
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
) as exc:
if attempt < max_retries - 1:
wait = 2**attempt * delay # in seconds
time.sleep(wait)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download_w_retry(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
)
@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_intel_orca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset():
# download the dataset
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"huggyllama/llama-7b",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-3.2-1B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"casperhansen/mistral-7b-instruct-v0.1-awq",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"unsloth/gemma-2b-it",
revision="703fb4a",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/gemma-2-9b-it-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-2-7b-hf",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture
def temp_dir():
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
LlamaForCausalLM,
)
# original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
),
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
spec = importlib.util.spec_from_file_location(
module_name, sys.modules[module_name].__file__
)
sys.modules[module_name] = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sys.modules[module_name])
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass