Simplify mistral tokenizer identification (depends on upstream PR)
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
"""Tokenizer loading functionality and associated utils."""
|
"""Tokenizer loading functionality and associated utils"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -9,9 +9,14 @@ import transformers
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
|
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
from mistral_common.tokens.tokenizers.mistral import (
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
MODEL_NAME_TO_TOKENIZER_CLS,
|
||||||
from transformers import AddedToken, AutoTokenizer
|
MistralTokenizer,
|
||||||
|
)
|
||||||
|
from transformers import (
|
||||||
|
AddedToken,
|
||||||
|
AutoTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||||
@@ -34,51 +39,146 @@ LLAMA_TOKENIZER_CLASSES = {
|
|||||||
"CodeLlamaTokenizer",
|
"CodeLlamaTokenizer",
|
||||||
"CodeLlamaTokenizerFast",
|
"CodeLlamaTokenizerFast",
|
||||||
}
|
}
|
||||||
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
|
||||||
MISTRAL_MODEL_TYPES = {"mistral", "mistral3", "mixtral"}
|
|
||||||
|
|
||||||
QWEN_DEFAULT_TOKEN = "<|endoftext|>" # nosec B105
|
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
||||||
GPTNEOX_PAD_TOKEN = "[PAD]" # nosec B105
|
|
||||||
|
QWEN_DEFAULT_TOKEN = "<|endoftext|>"
|
||||||
|
GPTNEOX_PAD_TOKEN = "[PAD]"
|
||||||
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
|
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
|
||||||
|
|
||||||
|
|
||||||
class MistralTokenizerWrapper:
|
class MistralTokenizerWrapper:
|
||||||
"""
|
"""
|
||||||
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer
|
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
|
||||||
interface. This provides a bridge between Mistral's native tokenizer and axolotl's
|
This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
|
||||||
expectations.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
    """Wrap ``mistral_tokenizer`` and prepare the cached wrapper state.

    Args:
        mistral_tokenizer: The Mistral tokenizer instance to delegate to.
        model_id: Identifier of the model; also used as the Hub repo ID when
            looking for a ``SYSTEM_PROMPT.txt`` file.
    """
    self.mistral_tokenizer = mistral_tokenizer
    self.model_id = model_id
    self.chat_template = None
    self.padding_side = "right"  # right-padding is the default
    self._system_prompt = None

    # Token IDs are resolved once from the wrapped tokenizer and cached
    self._token_ids = self._discover_token_ids()

    # The system prompt is optional: any failure to fetch it is only logged
    # and leaves ``_system_prompt`` as None
    try:
        prompt_text = self._load_system_prompt(model_id, "SYSTEM_PROMPT.txt")
    except Exception as e:
        LOG.debug(f"Could not load system prompt: {e}")
    else:
        self._system_prompt = prompt_text
|
||||||
|
|
||||||
|
def _discover_token_ids(self) -> Dict[str, int]:
    """Discover the special-token IDs used by the wrapped Mistral tokenizer.

    Discovery is best effort. The tokenizer is probed through its
    ``instruct_tokenizer`` attribute and that object's inner ``tokenizer``;
    attributes that are missing are skipped, and any exception raised while
    probing is logged at debug level. IDs that could not be discovered fall
    back to BOS=1, EOS=2 and UNK=0, and the pad ID falls back to the EOS ID.

    Returns:
        A dict with the keys ``bos_token_id``, ``eos_token_id``,
        ``unk_token_id`` and ``pad_token_id``.
    """
    eos_strings = ("</s>", "<|endoftext|>")
    unk_strings = ("<unk>", "<UNK>")
    token_ids: Dict[str, int] = {}

    try:
        instruct_tokenizer = getattr(
            self.mistral_tokenizer, "instruct_tokenizer", None
        )
        tekkenizer = getattr(instruct_tokenizer, "tokenizer", None)

        if instruct_tokenizer is not None:
            token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)

        if tekkenizer is not None:
            # The inner tokenizer's own BOS ID takes precedence when present
            if hasattr(tekkenizer, "bos_id"):
                token_ids["bos_token_id"] = tekkenizer.bos_id

            special_tokens = getattr(tekkenizer, "_all_special_tokens", None)
            if special_tokens is not None:
                keys = (
                    list(special_tokens.keys())
                    if hasattr(special_tokens, "keys")
                    else special_tokens
                )
                LOG.debug(f"Special tokens available: {keys}")

                # Only mapping-like containers can be queried by token string
                if hasattr(special_tokens, "get"):
                    for eos_pattern in (*eos_strings, "eos", "EOS"):
                        if eos_pattern in special_tokens:
                            token_ids["eos_token_id"] = special_tokens[eos_pattern]
                            break

            reverse_vocab = getattr(
                tekkenizer, "_special_tokens_reverse_vocab", None
            )
            if reverse_vocab is not None:
                LOG.debug(f"Reverse special tokens: {reverse_vocab}")

                # The mapping is presumably token string -> ID (verify against
                # mistral_common); unpacking items as (id, string) would then
                # never match. Normalise so either orientation works.
                str_to_id = {}
                for key, value in reverse_vocab.items():
                    if isinstance(key, str):
                        str_to_id[key] = value
                    else:
                        str_to_id[value] = key

                # A hit here overrides an EOS ID found in the special tokens
                for token_str in eos_strings:
                    if token_str in str_to_id:
                        token_ids["eos_token_id"] = str_to_id[token_str]
                        break
                for token_str in unk_strings:
                    if token_str in str_to_id:
                        token_ids["unk_token_id"] = str_to_id[token_str]
                        break

            # Last resort: decode a few candidate IDs and look for an EOS string
            vocab_size = getattr(tekkenizer, "_vocab_size", None)
            if "eos_token_id" not in token_ids and vocab_size:
                for candidate_id in (2, vocab_size - 1, vocab_size - 2):
                    try:
                        decoded = tekkenizer.decode([candidate_id])
                    except Exception:
                        continue
                    # NOTE(review): an empty decode is accepted as EOS-like,
                    # as before; confirm that this is specific enough.
                    if decoded in ("</s>", "<|endoftext|>", ""):
                        token_ids["eos_token_id"] = candidate_id
                        break

    except Exception as e:
        LOG.debug(f"Could not discover token IDs: {e}")

    # Set reasonable defaults for any missing token IDs
    token_ids.setdefault("bos_token_id", 1)
    token_ids.setdefault("eos_token_id", 2)
    token_ids.setdefault("unk_token_id", 0)
    token_ids.setdefault("pad_token_id", token_ids["eos_token_id"])  # EOS as pad

    LOG.info(f"Discovered Mistral token IDs: {token_ids}")
    return token_ids
|
||||||
|
|
||||||
|
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
    """Load a system prompt file from the HuggingFace Hub.

    Args:
        repo_id: Hub repository to download from.
        filename: Name of the file inside that repository.

    Returns:
        The full text content of the downloaded file.

    Errors raised by ``hf_hub_download`` or by reading the file are not
    handled here and propagate to the caller.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which can mis-decode non-ASCII prompt text
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
|
||||||
|
|
||||||
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
    """Encode a plain string into token IDs via a chat-completion request.

    The text is sent as a single user message. When a system prompt has been
    loaded and ``add_special_tokens`` is true, the prompt is placed in front
    of it as a system message. Extra keyword arguments are accepted and
    ignored.

    Raises:
        ValueError: If ``text`` is not a string.
    """
    if not isinstance(text, str):
        raise ValueError("MistralTokenizer wrapper only supports string input")

    system_part = (
        [SystemMessage(content=self._system_prompt)]
        if self._system_prompt and add_special_tokens
        else []
    )
    request = ChatCompletionRequest(
        messages=[*system_part, UserMessage(content=text)]
    )
    return self.mistral_tokenizer.encode_chat_completion(request).tokens
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
token_ids: Union[List[int], torch.Tensor],
|
token_ids: Union[List[int], torch.Tensor],
|
||||||
skip_special_tokens: bool = True, # pylint: disable=unused-argument
|
skip_special_tokens: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Decode token IDs to text"""
|
"""Decode token IDs to text"""
|
||||||
if isinstance(token_ids, torch.Tensor):
|
if isinstance(token_ids, torch.Tensor):
|
||||||
@@ -91,19 +191,28 @@ class MistralTokenizerWrapper:
|
|||||||
return {"input_ids": torch.tensor([tokens])}
|
return {"input_ids": torch.tensor([tokens])}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def special_tokens_reverse_vocab(self):
|
def eos_token_id(self):
|
||||||
# pylint: disable=protected-access
|
return self._token_ids["eos_token_id"]
|
||||||
return (
|
|
||||||
self.mistral_tokenizer.instruct_tokenizer.tokenizer._special_tokens_reverse_vocab
|
@property
|
||||||
)
|
def bos_token_id(self):
|
||||||
|
return self._token_ids["bos_token_id"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_token_id(self):
|
||||||
|
return self._token_ids["pad_token_id"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unk_token_id(self):
|
||||||
|
return self._token_ids["unk_token_id"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eos_token(self):
|
def eos_token(self):
|
||||||
return SpecialTokens.eos
|
return "</s>" # Standard Mistral EOS token
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bos_token(self):
|
def bos_token(self):
|
||||||
return SpecialTokens.bos
|
return "<s>" # Standard Mistral BOS token
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token(self):
|
def pad_token(self):
|
||||||
@@ -111,25 +220,16 @@ class MistralTokenizerWrapper:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def unk_token(self):
|
def unk_token(self):
|
||||||
return SpecialTokens.unk
|
return "<unk>" # Standard UNK token
|
||||||
|
|
||||||
@property
def __class__(self):
    """Return the wrapper's real class for name-based compatibility checks.

    Returning ``type(self)`` keeps ``obj.__class__.__name__`` equal to the
    wrapper's own class name and gives the same object on every access. A
    throwaway class created per call would report its own name instead,
    because a ``__name__`` assigned inside a class body does not change what
    ``Class.__name__`` returns.
    """
    return type(self)
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
|
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
|
||||||
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
|
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -137,7 +237,6 @@ class MistralTokenizerWrapper:
|
|||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def add_tokens(self, tokens) -> int:
|
def add_tokens(self, tokens) -> int:
|
||||||
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
|
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -267,102 +366,43 @@ class TokenizerConfiguration:
|
|||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
|
|
||||||
def detect_by_model_name_mapping(self) -> bool:
    """Return whether the configured model name matches a known Mistral model.

    The model path is read from ``cfg.model_name_or_path``, falling back to
    ``cfg.base_model``. Its last path component is compared case-insensitively
    against the keys of ``MODEL_NAME_TO_TOKENIZER_CLS``; a key that occurs
    anywhere inside the name counts as a match.

    Returns:
        True when a key matches, False when no model path is configured or
        nothing matches.
    """
    model_path = getattr(self.cfg, "model_name_or_path", "") or getattr(
        self.cfg, "base_model", ""
    )
    if not model_path:
        return False

    # Strip trailing slashes so "org/model/" still yields "model", not ""
    model = model_path.rstrip("/").split("/")[-1].lower()
    if not model:
        return False

    # Lowercase both sides; the name is lowercased, so a mixed-case key could
    # otherwise never match
    return any(
        model_name.lower() in model for model_name in MODEL_NAME_TO_TOKENIZER_CLS
    )
|
||||||
|
|
||||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
    """Load the Mistral tokenizer for the configured model and wrap it.

    The model ID is read from ``cfg.model_name_or_path``, falling back to
    ``cfg.base_model``. The tokenizer factory is looked up in
    ``MODEL_NAME_TO_TOKENIZER_CLS`` by the last path component of the ID:
    first by exact name, then by the longest key contained in the lowercased
    name. The fallback mirrors a case-insensitive substring match, so a model
    accepted by such a detection step can also be loaded.

    Returns:
        A ``MistralTokenizerWrapper`` around the instantiated tokenizer.

    Raises:
        ValueError: If neither config attribute provides a model ID.
        KeyError: If no key of the mapping matches the model name.
    """
    model_id = getattr(self.cfg, "model_name_or_path", None) or getattr(
        self.cfg, "base_model", None
    )
    if not model_id:
        raise ValueError(
            "model_name_or_path or base_model must be specified for Mistral tokenizer"
        )

    # Last path component, tolerating a trailing slash
    model_name = model_id.rstrip("/").split("/")[-1]

    tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS.get(model_name)
    if tokenizer_factory is None:
        lowered_name = model_name.lower()
        matches = [
            key
            for key in MODEL_NAME_TO_TOKENIZER_CLS
            if key and key.lower() in lowered_name
        ]
        if not matches:
            # KeyError keeps the exception type of the plain dict lookup
            raise KeyError(
                f"No Mistral tokenizer is registered for model: {model_id}"
            )
        # The longest key is the most specific match
        tokenizer_factory = MODEL_NAME_TO_TOKENIZER_CLS[max(matches, key=len)]

    mistral_tokenizer = tokenizer_factory()

    # Wrap it for compatibility
    wrapped_tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)

    LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
    return wrapped_tokenizer
|
||||||
|
|
||||||
def _find_mistral_tokenizer_file(self, model_id: str) -> str:
    """Locate and download a tokenizer file for ``model_id`` from the Hub.

    A fixed list of file names is probed first: ``tekken.json``, then the
    versioned SentencePiece names from newest to oldest (plain before the
    ``m1`` variant), then ``tokenizer.model``. If none can be downloaded, the
    repository listing is searched for any file with a SentencePiece suffix
    and the one with the newest version tag is downloaded.

    Returns:
        The path returned by ``hf_hub_download`` for the chosen file.

    Raises:
        FileNotFoundError: If no candidate could be downloaded or discovered.
    """
    versions_newest_first = ("v7", "v3", "v2", "v1")

    # Suffixes that identify a SentencePiece model file in a repo listing
    sp_suffixes = tuple(
        f".model.{ver}{mm}" for ver in ("v1", "v2", "v3", "v7") for mm in ("m1", "")
    ) + (".model",)

    # File names to probe directly, in order of preference
    probe_order = (
        ["tekken.json"]
        + [
            f"tokenizer.model.{ver}{mm}"
            for ver in versions_newest_first
            for mm in ("", "m1")
        ]
        + ["tokenizer.model"]
    )

    for name in probe_order:
        try:
            local_path = hf_hub_download(repo_id=model_id, filename=name)
            LOG.debug(f"Found tokenizer file: {name}")
            return local_path
        except Exception:
            # This candidate is unavailable; move on to the next one
            continue

    # Nothing matched by name: fall back to scanning the repository listing
    try:
        from huggingface_hub import list_repo_files

        discovered = [
            path
            for path in list_repo_files(repo_id=model_id)
            if path.endswith(sp_suffixes)
        ]

        if discovered:

            def version_rank(path):
                # Position of the first version tag found in the name;
                # names without any tag rank last
                return next(
                    (
                        rank
                        for rank, ver in enumerate(versions_newest_first)
                        if ver in path
                    ),
                    len(versions_newest_first),
                )

            # min() keeps the first of equally ranked names, like a stable sort
            best = min(discovered, key=version_rank)
            local_path = hf_hub_download(repo_id=model_id, filename=best)
            LOG.debug(f"Using discovered tokenizer file: {best}")
            return local_path

    except Exception as e:
        LOG.debug(f"Could not list repo files: {e}")

    raise FileNotFoundError(
        f"Could not find suitable tokenizer file for {model_id}. "
        f"Tried: {', '.join(probe_order[:5])}... and {len(probe_order)-5} others"
    )
|
|
||||||
|
|
||||||
def get_tokenizer_class(self):
|
def get_tokenizer_class(self):
|
||||||
"""Get the appropriate tokenizer class."""
|
"""Get the appropriate tokenizer class."""
|
||||||
if self.cfg.tokenizer_type:
|
if self.cfg.tokenizer_type:
|
||||||
@@ -610,17 +650,17 @@ def load_tokenizer(cfg):
|
|||||||
Fully configured tokenizer instance.
|
Fully configured tokenizer instance.
|
||||||
"""
|
"""
|
||||||
# Configure tokenizer parameters
|
# Configure tokenizer parameters
|
||||||
tokenizer_config = TokenizerConfiguration(cfg)
|
config = TokenizerConfiguration(cfg)
|
||||||
|
|
||||||
# Check if we should use Mistral tokenizer
|
# Check if we should use Mistral tokenizer
|
||||||
if tokenizer_config.should_use_mistral_tokenizer():
|
if config.detect_by_model_name_mapping():
|
||||||
tokenizer = tokenizer_config.load_mistral_tokenizer()
|
tokenizer = config.load_mistral_tokenizer()
|
||||||
else:
|
else:
|
||||||
# Standard tokenizer loading
|
# Standard tokenizer loading
|
||||||
tokenizer_cls = tokenizer_config.get_tokenizer_class()
|
tokenizer_cls = config.get_tokenizer_class()
|
||||||
tokenizer_path = tokenizer_config.get_tokenizer_path()
|
tokenizer_path = config.get_tokenizer_path()
|
||||||
use_fast = tokenizer_config.should_use_fast_tokenizer()
|
use_fast = config.should_use_fast_tokenizer()
|
||||||
tokenizer_kwargs = tokenizer_config.get_tokenizer_kwargs()
|
tokenizer_kwargs = config.get_tokenizer_kwargs()
|
||||||
|
|
||||||
# Initialize the tokenizer
|
# Initialize the tokenizer
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
|
|||||||
Reference in New Issue
Block a user