Release update 20250331 (#2460) [skip ci]
* make torch 2.6.0 the default image
* fix tests against upstream main
* fix attribute access
* use fixture dataset
* fix dataset load
* correct the fixtures + tests
* more fixtures
* add accidentally removed shakespeare fixture
* fix conversion from unittest to pytest class
* nightly main ci caches
* build 12.6.3 cuda base image
* override for fix from huggingface/transformers#37162
* address PR feedback
This commit is contained in:
@@ -8,11 +8,13 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_ds_fixture_bundle():
|
||||
ds_dir = snapshot_download_w_retry(
|
||||
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
||||
)
|
||||
return Path(ds_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_smollm2_135m_model():
|
||||
# download the model
|
||||
@@ -108,43 +118,43 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
)
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_fozzie_alpaca_dpo_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
)
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
repo_type="dataset",
|
||||
revision="ea82cff",
|
||||
)
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_fozzie_alpaca_dpo_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
# )
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
# repo_type="dataset",
|
||||
# revision="ea82cff",
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
)
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -271,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture
|
||||
def download_llama2_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
@@ -281,7 +291,7 @@ def download_llama2_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
@@ -292,6 +302,57 @@ def tokenizer_huggyllama(
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
return tokenizer_huggyllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
tokenizer_mistral_7b_instruct.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer_mistral_7b_instruct.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
return tokenizer_mistral_7b_instruct
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
# Create a temporary directory
|
||||
@@ -357,6 +418,60 @@ def cleanup_monkeypatches():
|
||||
globals().pop(module_global, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
# def test_load_fixtures(
|
||||
# download_smollm2_135m_model,
|
||||
|
||||
@@ -324,7 +324,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_hub_with_revision_with_dpo(
|
||||
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
):
|
||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||
|
||||
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
# Set up the mock to return the same fixture dataset on every call
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
def test_load_local_hub_with_revision(self, tokenizer):
|
||||
def test_load_local_hub_with_revision(
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
|
||||
):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return the same fixture dataset on every call
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_loading_local_dataset_folder(self, tokenizer):
|
||||
|
||||
@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_with_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_without_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
|
||||
import pytest
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from datasets import concatenate_datasets
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
@@ -41,14 +40,17 @@ class TestBatchedSamplerPacking:
|
||||
@pytest.mark.parametrize("sequential", [True, False])
|
||||
@enable_hf_offline
|
||||
def test_packing(
|
||||
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
|
||||
self,
|
||||
dataset_winglian_tiny_shakespeare,
|
||||
batch_size,
|
||||
num_workers,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
sequential,
|
||||
):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = load_dataset(
|
||||
"winglian/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -58,7 +60,7 @@ class TestBatchedSamplerPacking:
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "Text",
|
||||
"field": "text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
|
||||
@@ -2,13 +2,8 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
@@ -61,24 +56,13 @@ test_data = {
|
||||
}
|
||||
|
||||
|
||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
class TestPromptTokenizationStrategies:
|
||||
"""
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
def test_alpaca(self):
|
||||
@enable_hf_offline
|
||||
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
class TestInstructionWSystemPromptTokenizingStrategy:
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_system_alpaca(self):
|
||||
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
||||
strat = InstructionWSystemPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
assert example["input_ids"][8] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
class Llama2ChatTokenizationTest:
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
# workaround because official Meta repos are not open
|
||||
|
||||
def test_llama2_chat_integration(self):
|
||||
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
prompter = Llama2ChatPrompter()
|
||||
strat = LLama2ChatTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_llama2_7b,
|
||||
False,
|
||||
4096,
|
||||
)
|
||||
example = strat.tokenize_prompt(conversation)
|
||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
# pytest assert equals
|
||||
|
||||
def compare_with_transformers_integration(self):
|
||||
assert len(example[fields]) == len(tokenized_conversation[fields])
|
||||
assert example[fields] == tokenized_conversation[fields]
|
||||
|
||||
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
|
||||
# this needs transformers >= v4.31.0
|
||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
||||
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
|
||||
|
||||
self.assertEqual(
|
||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
)
|
||||
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
|
||||
|
||||
class OrpoTokenizationTest(unittest.TestCase):
|
||||
class OrpoTokenizationTest:
|
||||
"""test case for the ORPO tokenization"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(
|
||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
||||
),
|
||||
]
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||
).select([0])
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
def test_orpo_integration(self):
|
||||
def test_orpo_integration(
|
||||
self,
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
|
||||
):
|
||||
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
|
||||
strat = load(
|
||||
self.tokenizer,
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
DictDefault({"train_on_inputs": False}),
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(self.dataset[0])
|
||||
res = strat.tokenize_prompt(ds[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
|
||||
|
||||
assert res["prompt_attention_mask"][0] == 1
|
||||
assert res["prompt_attention_mask"][-1] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user