Release update 20250331 (#2460) [skip ci]

* make torch 2.6.0 the default image

* fix tests against upstream main

* fix attribute access

* use fixture dataset

* fix dataset load

* correct the fixtures + tests

* more fixtures

* add accidentally removed shakespeare fixture

* fix conversion from unittest to pytest class

* nightly main ci caches

* build 12.6.3 cuda base image

* override for fix from huggingface/transformers#37162

* address PR feedback
This commit is contained in:
Wing Lian
2025-04-01 08:47:50 -04:00
committed by GitHub
parent 328d598114
commit e0aba74dd0
17 changed files with 347 additions and 169 deletions

View File

@@ -2,13 +2,8 @@
import json
import logging
import unittest
from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
@@ -61,24 +56,13 @@ test_data = {
}
class TestPromptTokenizationStrategies(unittest.TestCase):
class TestPromptTokenizationStrategies:
"""
Test class for prompt tokenization strategies.
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_no_sys_prompt(self):
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
def test_alpaca(self):
@enable_hf_offline
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx - 1] == -100
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
class TestInstructionWSystemPromptTokenizingStrategy:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_system_alpaca(self):
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest(unittest.TestCase):
class Llama2ChatTokenizationTest:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# woraround because official Meta repos are not open
def test_llama2_chat_integration(self):
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_llama2_7b,
False,
4096,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
# pytest assert equals
def compare_with_transformers_integration(self):
assert len(example[fields]) == len(tokenized_conversation[fields])
assert example[fields] == tokenized_conversation[fields]
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
# this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
generated_responses=answers,
)
# pylint: disable=W0212
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
self.assertEqual(
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
)
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
class OrpoTokenizationTest(unittest.TestCase):
class OrpoTokenizationTest:
"""test case for the ORPO tokenization"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken(
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
),
]
)
self.tokenizer = tokenizer
self.dataset = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self):
def test_orpo_integration(
self,
tokenizer_mistral_7b_instruct_chatml,
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
):
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
strat = load(
self.tokenizer,
tokenizer_mistral_7b_instruct_chatml,
DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}),
)
res = strat.tokenize_prompt(self.dataset[0])
res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0
if __name__ == "__main__":
unittest.main()