update
This commit is contained in:
@@ -35,7 +35,7 @@ LLAMA_TOKENIZER_CLASSES = {
|
|||||||
"CodeLlamaTokenizerFast",
|
"CodeLlamaTokenizerFast",
|
||||||
}
|
}
|
||||||
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
||||||
MISTRAL_MODEL_TYPES = {"mistral", "mistral3"}
|
MISTRAL_MODEL_TYPES = {"mistral", "mistral3", "mixtral"}
|
||||||
|
|
||||||
QWEN_DEFAULT_TOKEN = "<|endoftext|>" # nosec B105
|
QWEN_DEFAULT_TOKEN = "<|endoftext|>" # nosec B105
|
||||||
GPTNEOX_PAD_TOKEN = "[PAD]" # nosec B105
|
GPTNEOX_PAD_TOKEN = "[PAD]" # nosec B105
|
||||||
@@ -269,35 +269,99 @@ class TokenizerConfiguration:
|
|||||||
|
|
||||||
def should_use_mistral_tokenizer(self) -> bool:
|
def should_use_mistral_tokenizer(self) -> bool:
|
||||||
"""Determine if Mistral tokenizer should be used."""
|
"""Determine if Mistral tokenizer should be used."""
|
||||||
# Explicit configuration
|
|
||||||
return self.model_config.model_type in MISTRAL_MODEL_TYPES
|
return self.model_config.model_type in MISTRAL_MODEL_TYPES
|
||||||
|
|
||||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
||||||
"""Load Mistral tokenizer from model configuration."""
|
"""Load Mistral tokenizer from model configuration."""
|
||||||
model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
|
# Try to find the appropriate tokenizer file
|
||||||
self.cfg, "base_model", None
|
model_id = self.cfg.base_model
|
||||||
)
|
tokenizer_file = self._find_mistral_tokenizer_file(model_id)
|
||||||
if not model_id:
|
|
||||||
raise ValueError(
|
|
||||||
"model_name_or_path or base_model must be specified for Mistral tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Load the Mistral tokenizer and wrap for transformers compatibility
|
||||||
|
mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
|
||||||
|
wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||||
|
|
||||||
|
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
||||||
|
return wrapped_tokenizer
|
||||||
|
|
||||||
|
def _find_mistral_tokenizer_file(self, model_id: str) -> str:
|
||||||
|
"""Find the appropriate tokenizer file for the given model."""
|
||||||
|
# Generate all possible SentencePiece suffixes based on mistral_common patterns
|
||||||
|
instruct_versions = ["v1", "v2", "v3", "v7"]
|
||||||
|
mm_versions = ["m1", ""] # multimodal versions + empty string
|
||||||
|
|
||||||
|
# Create all possible .model file suffixes
|
||||||
|
sentencepiece_suffixes = [
|
||||||
|
f".model.{v}{m}" for v in instruct_versions for m in mm_versions
|
||||||
|
] + [".model"]
|
||||||
|
|
||||||
|
# List of tokenizer files to try, in order of preference
|
||||||
|
candidate_files = ["tekken.json"] # Try Tekken first
|
||||||
|
|
||||||
|
# Add SentencePiece candidates in preference order (newer versions first)
|
||||||
|
preferred_sp_files = [
|
||||||
|
"tokenizer.model.v7", # Latest instruction version
|
||||||
|
"tokenizer.model.v7m1", # Latest with multimodal
|
||||||
|
"tokenizer.model.v3", # Common version
|
||||||
|
"tokenizer.model.v3m1", # v3 with multimodal
|
||||||
|
"tokenizer.model.v2", # Older version
|
||||||
|
"tokenizer.model.v2m1", # v2 with multimodal
|
||||||
|
"tokenizer.model.v1", # Oldest versioned
|
||||||
|
"tokenizer.model.v1m1", # v1 with multimodal
|
||||||
|
"tokenizer.model", # Generic fallback
|
||||||
|
]
|
||||||
|
candidate_files.extend(preferred_sp_files)
|
||||||
|
|
||||||
|
# Try each candidate file
|
||||||
|
for filename in candidate_files:
|
||||||
|
try:
|
||||||
|
tokenizer_file = hf_hub_download(repo_id=model_id, filename=filename)
|
||||||
|
LOG.debug(f"Found tokenizer file: {filename}")
|
||||||
|
return tokenizer_file
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If no standard files found, try to list and find any matching files
|
||||||
try:
|
try:
|
||||||
# Download the tekken.json file for the tokenizer
|
from huggingface_hub import list_repo_files
|
||||||
tekken_file = hf_hub_download(repo_id=model_id, filename="tekken.json")
|
|
||||||
|
|
||||||
# Load the Mistral tokenizer
|
repo_files = list_repo_files(repo_id=model_id)
|
||||||
mistral_tokenizer = MistralTokenizer.from_file(tekken_file)
|
|
||||||
|
|
||||||
# Wrap it for compatibility
|
# Look for any files matching the SentencePiece patterns
|
||||||
wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
matching_files = []
|
||||||
|
for repo_file in repo_files:
|
||||||
|
if any(repo_file.endswith(suffix) for suffix in sentencepiece_suffixes):
|
||||||
|
matching_files.append(repo_file)
|
||||||
|
|
||||||
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
if matching_files:
|
||||||
return wrapped_tokenizer
|
# Sort by preference (newer versions first)
|
||||||
|
def sort_key(filename):
|
||||||
|
# Prioritize by version number (v7 > v3 > v2 > v1)
|
||||||
|
if "v7" in filename:
|
||||||
|
return 0
|
||||||
|
elif "v3" in filename:
|
||||||
|
return 1
|
||||||
|
elif "v2" in filename:
|
||||||
|
return 2
|
||||||
|
elif "v1" in filename:
|
||||||
|
return 3
|
||||||
|
else:
|
||||||
|
return 4
|
||||||
|
|
||||||
|
matching_files.sort(key=sort_key)
|
||||||
|
tokenizer_file = hf_hub_download(
|
||||||
|
repo_id=model_id, filename=matching_files[0]
|
||||||
|
)
|
||||||
|
LOG.debug(f"Using discovered tokenizer file: {matching_files[0]}")
|
||||||
|
return tokenizer_file
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(f"Failed to load Mistral tokenizer: {e}")
|
LOG.debug(f"Could not list repo files: {e}")
|
||||||
raise
|
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find suitable tokenizer file for {model_id}. "
|
||||||
|
f"Tried: {', '.join(candidate_files[:5])}... and {len(candidate_files)-5} others"
|
||||||
|
)
|
||||||
|
|
||||||
def get_tokenizer_class(self):
|
def get_tokenizer_class(self):
|
||||||
"""Get the appropriate tokenizer class."""
|
"""Get the appropriate tokenizer class."""
|
||||||
@@ -546,17 +610,17 @@ def load_tokenizer(cfg):
|
|||||||
Fully configured tokenizer instance.
|
Fully configured tokenizer instance.
|
||||||
"""
|
"""
|
||||||
# Configure tokenizer parameters
|
# Configure tokenizer parameters
|
||||||
config = TokenizerConfiguration(cfg)
|
tokenizer_config = TokenizerConfiguration(cfg)
|
||||||
|
|
||||||
# Check if we should use Mistral tokenizer
|
# Check if we should use Mistral tokenizer
|
||||||
if config.should_use_mistral_tokenizer():
|
if tokenizer_config.should_use_mistral_tokenizer():
|
||||||
tokenizer = config.load_mistral_tokenizer()
|
tokenizer = tokenizer_config.load_mistral_tokenizer()
|
||||||
else:
|
else:
|
||||||
# Standard tokenizer loading
|
# Standard tokenizer loading
|
||||||
tokenizer_cls = config.get_tokenizer_class()
|
tokenizer_cls = tokenizer_config.get_tokenizer_class()
|
||||||
tokenizer_path = config.get_tokenizer_path()
|
tokenizer_path = tokenizer_config.get_tokenizer_path()
|
||||||
use_fast = config.should_use_fast_tokenizer()
|
use_fast = tokenizer_config.should_use_fast_tokenizer()
|
||||||
tokenizer_kwargs = config.get_tokenizer_kwargs()
|
tokenizer_kwargs = tokenizer_config.get_tokenizer_kwargs()
|
||||||
|
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
|
|||||||
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Test cases for tokenizer loading."""
|
||||||
Test cases for the tokenizer loading
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -13,9 +9,7 @@ from tests.hf_offline_utils import enable_hf_offline
|
|||||||
|
|
||||||
|
|
||||||
class TestTokenizers:
|
class TestTokenizers:
|
||||||
"""
|
"""Test class for the load_tokenizer fn"""
|
||||||
test class for the load_tokenizer fn
|
|
||||||
"""
|
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_default_use_fast(self):
|
def test_default_use_fast(self):
|
||||||
@@ -155,6 +149,50 @@ class TestTokenizers:
|
|||||||
):
|
):
|
||||||
load_tokenizer(cfg)
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
def test_mistral_tokenizer_auto_detection(self):
|
||||||
|
"""Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||||
|
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def test_mixtral_tokenizer_auto_detection(self):
|
||||||
unittest.main()
|
"""Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "model-hub/Mixtral-8x7B-v0.1",
|
||||||
|
"tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
|
||||||
|
|
||||||
|
def test_mistral_tokenizer_basic_functionality(self):
|
||||||
|
"""Test basic encode/decode functionality of MistralTokenizerWrapper"""
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||||
|
"tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# Test basic encoding
|
||||||
|
text = "Hello, world!"
|
||||||
|
tokens = tokenizer.encode(text)
|
||||||
|
assert isinstance(tokens, list)
|
||||||
|
assert len(tokens) > 0
|
||||||
|
|
||||||
|
# Test basic decoding
|
||||||
|
decoded = tokenizer.decode(tokens)
|
||||||
|
assert isinstance(decoded, str)
|
||||||
|
|
||||||
|
# Test token properties are accessible
|
||||||
|
assert hasattr(tokenizer, "eos_token_id")
|
||||||
|
assert hasattr(tokenizer, "bos_token_id")
|
||||||
|
assert isinstance(tokenizer.eos_token_id, int)
|
||||||
|
assert isinstance(tokenizer.bos_token_id, int)
|
||||||
|
|||||||
Reference in New Issue
Block a user