This commit is contained in:
Dan Saunders
2025-05-29 20:04:35 +00:00
parent 9581a9efed
commit b1570ed0fa
2 changed files with 138 additions and 36 deletions

View File

@@ -1,8 +1,4 @@
"""
Test cases for the tokenizer loading
"""
import unittest
"""Test cases for tokenizer loading."""
import pytest
@@ -13,9 +9,7 @@ from tests.hf_offline_utils import enable_hf_offline
class TestTokenizers:
"""
test class for the load_tokenizer fn
"""
"""Test class for the load_tokenizer fn"""
@enable_hf_offline
def test_default_use_fast(self):
@@ -155,6 +149,50 @@ class TestTokenizers:
):
load_tokenizer(cfg)
def test_mistral_tokenizer_auto_detection(self):
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
if __name__ == "__main__":
unittest.main()
def test_mixtral_tokenizer_auto_detection(self):
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "model-hub/Mixtral-8x7B-v0.1",
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
def test_mistral_tokenizer_basic_functionality(self):
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
cfg = DictDefault(
{
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
}
)
tokenizer = load_tokenizer(cfg)
# Test basic encoding
text = "Hello, world!"
tokens = tokenizer.encode(text)
assert isinstance(tokens, list)
assert len(tokens) > 0
# Test basic decoding
decoded = tokenizer.decode(tokens)
assert isinstance(decoded, str)
# Test token properties are accessible
assert hasattr(tokenizer, "eos_token_id")
assert hasattr(tokenizer, "bos_token_id")
assert isinstance(tokenizer.eos_token_id, int)
assert isinstance(tokenizer.bos_token_id, int)