* feat: add cut_cross_entropy * fix: add to input * fix: remove from setup.py * feat: refactor into an integration * chore: ignore lint * feat: add test for cce * fix: set max_steps for liger test * chore: Update base model following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: update special_tokens following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: remove with_temp_dir following comments * fix: plugins aren't loaded * chore: update quotes in error message * chore: lint * chore: lint * feat: enable FA on test * chore: refactor get_pytorch_version * fix: lock cce commit version * fix: remove subclassing UT * fix: downcast even if not using FA and config check * feat: add test to check different attentions * feat: add install to CI * chore: refactor to use parametrize for attention * fix: pytest not detecting test * feat: handle torch lower than 2.4 * fix args/kwargs to match docs * use release version cut-cross-entropy==24.11.4 * fix quotes * fix: use named params for clarity for modal builder * fix: handle install from pip * fix: test check only top level module install * fix: re-add import check * uninstall existing version if no transformers submodule in cce * more dataset fixtures into the cache --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
77 lines
2.0 KiB
Python
77 lines
2.0 KiB
Python
"""
|
|
shared pytest fixtures
|
|
"""
|
|
import shutil
|
|
import tempfile
|
|
|
|
import pytest
|
|
from huggingface_hub import snapshot_download
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
    """Fetch the SmolLM2-135M model once per test session so tests hit the local cache."""
    snapshot_download(repo_id="HuggingFaceTB/SmolLM2-135M")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
    """Fetch the llama-68m model once per test session so tests hit the local cache."""
    snapshot_download(repo_id="JackFram/llama-68m")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
    """Fetch the Qwen2.5-0.5B model once per test session so tests hit the local cache."""
    snapshot_download(repo_id="Qwen/Qwen2.5-0.5B")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
    """Fetch the tatsu-lab/alpaca dataset once per test session so tests hit the local cache."""
    snapshot_download(repo_id="tatsu-lab/alpaca", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
    """Fetch the alpaca_2k_test dataset once per test session so tests hit the local cache."""
    snapshot_download(repo_id="mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
    """Fetch a pinned revision of the alpaca_2k_test dataset once per test session."""
    snapshot_download(
        repo_id="mhenrichsen/alpaca_2k_test",
        repo_type="dataset",
        revision="d05c1cb",
    )
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
    """Fetch the mlabonne/FineTome-100k dataset once per test session.

    Fix: this helper was missing the ``@pytest.fixture(scope="session",
    autouse=True)`` decorator that every sibling pre-download helper in this
    file carries, so pytest never invoked it and the dataset was never
    pre-cached. The decorator makes it run automatically like the others.
    """
    # download the dataset
    snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
|
|
|
|
|
|
@pytest.fixture
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
    """Fetch the capybara DPO dataset on demand (per-test, not session-autouse)."""
    snapshot_download(
        repo_id="argilla/distilabel-capybara-dpo-7k-binarized",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
    """Fetch the orca DPO pairs dataset on demand (per-test, not session-autouse)."""
    snapshot_download(
        repo_id="arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
        repo_type="dataset",
    )
|
|
|
|
|
|
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory path and delete it after the test.

    Fix: the cleanup now runs in a ``finally`` clause, so the directory is
    removed even if an exception is thrown into the generator during
    teardown (e.g. on test interruption); previously the ``rmtree`` line
    after ``yield`` would be skipped in that case and the directory leaked.
    """
    # Create a temporary directory
    _temp_dir = tempfile.mkdtemp()
    try:
        yield _temp_dir
    finally:
        # Clean up the directory (and its contents) after the test
        shutil.rmtree(_temp_dir)
|