Feat: Add Magistral and mistral-common tokenizer support (#2780)

This commit is contained in:
NanoCode012
2025-06-12 16:18:33 -07:00
committed by GitHub
parent ace9287c96
commit eac4a61f55
15 changed files with 1213 additions and 14 deletions

View File

@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True

View File

@@ -189,7 +189,7 @@ class KDStrategyLoader(StrategyLoader):
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self):
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):

View File

@@ -121,6 +121,19 @@ def modify_tokenizer_files(
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config."""
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
return tokenizer
if cfg.tokenizer_use_mistral_common:
return _load_mistral_common_tokenizer(cfg)
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default

View File

@@ -2,8 +2,10 @@
HF Chat Templates prompt strategy
"""
# pylint: disable=too-many-lines
from collections import defaultdict
from typing import Any, Dict, List, Set, Union
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
@@ -15,6 +17,9 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
# Configure the logger
LOG = get_logger(__name__)
LOG.setLevel("INFO")
@@ -81,7 +86,7 @@ class ChatTemplatePrompter(Prompter):
def build_prompt(
self,
conversation,
conversation: list[dict],
add_generation_prompt=False,
images=None,
tools=None,
@@ -271,9 +276,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = (
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
)
self.eot_tokens = []
if eot_tokens is not None:
self.eot_tokens = eot_tokens
elif (
hasattr(self.tokenizer, "eos_token")
and self.tokenizer.eos_token is not None
):
self.eot_tokens = [self.tokenizer.eos_token]
self.split_thinking = split_thinking
self.images = "images"
@@ -796,14 +807,104 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
)
class MistralStrategy(ChatTemplateStrategy):
"""
Mistral strategy for chat template.
"""
def __init__(
self,
prompter: "ChatTemplatePrompter",
tokenizer: "HFMistralTokenizer",
train_on_inputs: bool,
sequence_len: int,
roles_to_train: list[str] | None = None,
train_on_eos: str | None = None,
train_on_eot: str | None = None,
eot_tokens: list[str] | None = None,
split_thinking: bool | None = False,
):
# Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation
# pylint: disable=non-parent-init-called,super-init-not-called
PromptTokenizingStrategy.__init__(
self, prompter, tokenizer, train_on_inputs, sequence_len
)
self.prompter: ChatTemplatePrompter = prompter
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.train_on_eos = train_on_eos
# Backward compatibility, load from train_on_eos
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = []
if eot_tokens is not None:
self.eot_tokens = eot_tokens
else:
# set eot_tokens to the eos_token
self.eot_tokens = [self.tokenizer.eos_token]
self.split_thinking = split_thinking
self.images = "images"
LOG.debug(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)
# Skip the validation that ChatTemplateStrategy calls
# TODO: address this in the future with mistral-specific checks
# self._validate_eot_and_eos_tokens()
@property
def supports_multiprocessing(self) -> bool:
"""
Whether this tokenizing strategy supports multiprocessing.
mistral_common tokenizers cannot be pickled for multiprocessing.
"""
return False
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# mistral-common tokenizer does not support eot_tokens
return self.find_first_eos_token(input_ids, start_idx)
class MistralPrompter(ChatTemplatePrompter):
"""
Mistral prompter for chat template.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"])
class StrategyLoader:
"""
Load chat template strategy based on configuration.
"""
def _get_strategy_cls(self):
def _get_strategy_cls(self, cfg):
if cfg.tokenizer_use_mistral_common:
return MistralStrategy
return ChatTemplateStrategy
def _get_prompter_cls(self, cfg):
if cfg.tokenizer_use_mistral_common:
return MistralPrompter
return ChatTemplatePrompter
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
return {
"train_on_inputs": cfg.train_on_inputs,
@@ -829,9 +930,14 @@ class StrategyLoader:
else:
dataset_config = ds_cfg
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
if cfg.tokenizer_use_mistral_common:
# mistral-common does not use this, so we pass an empty string
chat_template_string = ""
else:
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
@@ -857,10 +963,11 @@ class StrategyLoader:
}
strategy_params = self._get_strategy_params(cfg, dataset_config)
strategy_cls = self._get_strategy_cls()
strategy_cls = self._get_strategy_cls(cfg)
prompter_cls = self._get_prompter_cls(cfg)
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params),
prompter_cls(**prompter_params),
tokenizer=tokenizer,
**strategy_params,
)

View File

@@ -70,6 +70,14 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self):
return False
@property
def supports_multiprocessing(self):
"""
Whether this tokenizing strategy supports multiprocessing.
Should return False if the tokenizer has unpicklable objects.
"""
return True
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:

View File

@@ -0,0 +1,567 @@
"""Wrapper for MistralTokenizer from mistral-common"""
import math
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.core import IGNORE_INDEX
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import ChatCompletionRequest
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
"""Get the file path from local or HF Hub"""
if os.path.exists(path_or_repo_id):
maybe_file_path = os.path.join(path_or_repo_id, filename)
if os.path.exists(maybe_file_path):
return maybe_file_path
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
class HFMistralTokenizer:
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
"""
def __init__(
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
):
"""
Args:
mistral: The mistral-common tokenizer to wrap.
name_or_path: The name or path to the tokenizer files or the repo id.
"""
self._mistral = mistral
self._padding_side = "right"
self._name_or_path = name_or_path
self._tokenizer_path = tokenizer_path
# Manual set to training mode
from mistral_common.protocol.instruct.validator import (
MistralRequestValidator,
ValidationMode,
)
# Check if MistralRequestValidator has a _mode attribute.
# This is a private API and may change in the future.
# pylint: disable=protected-access
if not (
hasattr(self._mistral, "_chat_completion_request_validator")
and isinstance(
self._mistral._chat_completion_request_validator,
MistralRequestValidator,
)
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
):
raise RuntimeError(
"Unable to switch mistral tokenizer to finetuning mode "
"private API `_chat_completion_request_validator._mode` missing."
)
self._mistral._chat_completion_request_validator._mode = (
ValidationMode.finetuning
)
def _load_system_prompt(self, path_or_repo_id: str) -> str:
"""Load system prompt from local or HF Hub.
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
not provide one.
Args:
path_or_repo_id: The path to the tokenizer files or the repo id.
Returns:
The system prompt.
"""
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
if not os.path.exists(file_path):
raise FileNotFoundError(f"System prompt file not found at {file_path}")
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
@property
def bos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.bos_id
@property
def eos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.eos_id
@property
def pad_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.pad_id
@property
def unk_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.unk_id
@property
def bos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
@property
def eos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
@property
def pad_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
@property
def unk_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
@property
def padding_side(self) -> str:
return self._padding_side
@property
def name_or_path(self) -> str:
return self._name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
def __len__(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.n_words
@classmethod
def from_pretrained(
cls,
name_or_path: str,
*,
revision: Optional[str] = None,
**kwargs, # pylint: disable=unused-argument
) -> "HFMistralTokenizer":
"""
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
Args:
path_or_repo_id: The path to the tokenizer files or the repo id.
revision: The revision of the tokenizer to download.
kwargs: Additional keyword arguments.
Returns:
A HFMistralTokenizer instance.
"""
if revision:
raise NotImplementedError(
"Revision not supported yet for mistral-common tokenizer"
)
# only support Tekken tokenizer for now
# downloads from HF Hub if not local
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
base = MistralTokenizer.from_file(tokenizer_path)
return cls(
base,
name_or_path=name_or_path,
tokenizer_path=tokenizer_path,
)
def save_pretrained(self, save_directory: str) -> None:
"""
Save the Tekken/SentencePiece model file so that from_pretrained can pick it up again.
Only Tekken models are supported.
Args:
save_directory: The directory to save the tokenizer files.
"""
inner = self._mistral.instruct_tokenizer.tokenizer
if isinstance(inner, Tekkenizer):
# Create the directory and save the model
try:
os.makedirs(save_directory, exist_ok=True)
# Verify directory was created
if not os.path.exists(save_directory):
raise RuntimeError(f"Failed to create directory: {save_directory}")
# Verify source file exists
if not os.path.exists(self._tokenizer_path):
raise FileNotFoundError(
f"Source tokenizer file not found: {self._tokenizer_path}"
)
destination_path = os.path.join(save_directory, "tekken.json")
copyfile(self._tokenizer_path, destination_path)
except Exception as e:
raise RuntimeError(
f"Failed to save tokenizer to {save_directory}: {e}. "
f"Source path: {self._tokenizer_path}, "
f"Directory exists: {os.path.exists(save_directory)}"
) from e
else:
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
"""
Encode a text string into a list of token IDs.
Args:
text: The text string to encode.
add_special_tokens: Whether to add special tokens to the encoded tokens.
Returns:
A list of token IDs.
"""
return self._mistral.instruct_tokenizer.tokenizer.encode(
text,
bos=add_special_tokens,
eos=add_special_tokens,
)
def decode(
self, token_ids: int | list[int], skip_special_tokens: bool = False
) -> str:
"""
Decode a list of token IDs into a text string.
Args:
token_ids: The int or list of token IDs to decode.
skip_special_tokens: Whether to skip special tokens in the decoded text.
Returns:
The decoded text string.
"""
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
# to_string returns a string with special tokens
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None
) -> "ChatCompletionRequest":
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
[]
)
for turn in conversation:
role = turn.get("role")
if role == "user":
messages.append(UserMessage(content=turn["content"]))
elif role == "assistant":
messages.append(
AssistantMessage(
content=turn.get("content"),
tool_calls=turn.get("tool_calls"),
)
)
elif role == "tool":
messages.append(
ToolMessage(
content=turn.get("content"),
tool_call_id=turn.get("tool_call_id"),
name=turn.get("name"),
)
)
elif role == "system":
messages.append(SystemMessage(content=turn["content"]))
else:
raise ValueError(
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
)
tool_calls: list[Tool] = []
if tools:
# convert to Tool
for tool in tools:
if tool["type"] != "function":
continue
function = tool["function"]
tool_calls.append(
Tool(
function=Function(
name=function["name"],
description=function["description"],
# set parameters to empty dict if not provided
parameters=function.get("parameters", {}),
)
)
)
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
messages=messages,
tools=tool_calls,
)
return chat_completion
def apply_chat_template(
self,
messages: list[dict],
tokenize: bool = True,
tools: list[dict] | None = None,
chat_template: str | None = None, # pylint: disable=unused-argument
add_generation_prompt: bool = False, # pylint: disable=unused-argument
) -> list[int] | str:
if chat_template:
raise NotImplementedError("chat_template not supported yet")
if add_generation_prompt:
raise NotImplementedError("add_generation_prompt not supported yet")
chat_completion: ChatCompletionRequest = (
self._create_mistral_chat_completion_request(messages, tools)
)
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
if tokenize:
return tokens
return self.decode(tokens)
def pad(
self,
features: list[dict[str, list[int] | np.ndarray]],
*,
padding: bool | str | PaddingStrategy = True,
max_length: int | None = None,
pad_to_multiple_of: int | None = None,
return_tensors: str | None = None, # "np", "pt", or "tf"
) -> dict[str, np.ndarray | Tensor]:
"""
HF-style pad method that properly handles all sequence-related features:
- pad 'input_ids' & 'labels' to the longest (or to max_length)
"""
import torch
from torch.nn import functional as F
# Check for unsupported fields
if any("token_type_ids" in f for f in features):
raise ValueError("token_type_ids is not supported by this tokenizer")
# Determine desired sequence length
lengths = [len(f["input_ids"]) for f in features]
if padding in (True, "longest", PaddingStrategy.LONGEST):
target_length = max(lengths)
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
if max_length is None:
raise ValueError("max_length must be set for 'max_length' padding")
target_length = max_length
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
target_length = None
else:
raise ValueError(f"Unknown padding strategy: {padding}")
# Apply pad_to_multiple_of
if target_length is not None and pad_to_multiple_of is not None:
target_length = (
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
)
# If no padding requested, just stack tensors
do_pad = target_length is not None
# Pad sequences using torch.nn.utils.rnn.pad_sequence
input_ids = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
)
labels = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=IGNORE_INDEX,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=0,
)
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
if "position_ids" in features[0]:
if self.padding_side == "left":
# Likely not needed, but keeping for now
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
position_ids = torch.nn.utils.rnn.pad_sequence(
[
torch.tensor(x["position_ids"], dtype=torch.long)
for x in features
],
batch_first=True,
padding_value=0,
)
else:
# For right padding, continue the sequence
max_pos_len = max(len(f["position_ids"]) for f in features)
position_ids_list = []
for f in features:
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
if len(pos_seq) < max_pos_len:
# Continue the sequence
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
pad_len = max_pos_len - len(pos_seq)
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
pos_seq = torch.cat([pos_seq, pad_positions])
position_ids_list.append(pos_seq)
position_ids = torch.stack(position_ids_list)
else:
# Create position_ids if not present
seq_len = input_ids.size(1)
position_ids = (
torch.arange(seq_len, dtype=torch.long)
.unsqueeze(0)
.expand(input_ids.size(0), -1)
)
# Ensure all tensors have the same sequence length
max_seq_len = max(
input_ids.size(1),
labels.size(1),
attention_mask.size(1),
position_ids.size(1),
)
# TODO: check if trimming is needed? and correct.
if do_pad and target_length is not None:
max_seq_len = target_length
# Pad all tensors to the same length
if input_ids.size(1) < max_seq_len:
pad_len = max_seq_len - input_ids.size(1)
if self.padding_side == "right":
input_ids = F.pad(
input_ids,
(0, pad_len),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
else:
input_ids = F.pad(
input_ids,
(pad_len, 0),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
elif input_ids.size(1) > max_seq_len:
input_ids = input_ids[:, :max_seq_len]
if labels.size(1) < max_seq_len:
pad_len = max_seq_len - labels.size(1)
if self.padding_side == "right":
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
else:
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
elif labels.size(1) > max_seq_len:
labels = labels[:, :max_seq_len]
if attention_mask.size(1) < max_seq_len:
pad_len = max_seq_len - attention_mask.size(1)
if self.padding_side == "right":
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
else:
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
elif attention_mask.size(1) > max_seq_len:
attention_mask = attention_mask[:, :max_seq_len]
if position_ids.size(1) < max_seq_len:
pad_len = max_seq_len - position_ids.size(1)
if self.padding_side == "right":
batch_size = position_ids.size(0)
new_position_ids = []
for i in range(batch_size):
seq = position_ids[i]
if len(seq) > 0:
# get last position and pad with sequential values
last_pos = seq[-1].item()
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
new_seq = torch.cat([seq, pad_positions])
else:
new_seq = torch.arange(pad_len, dtype=torch.long)
new_position_ids.append(new_seq)
position_ids = torch.stack(new_position_ids)
else:
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
elif position_ids.size(1) > max_seq_len:
position_ids = position_ids[:, :max_seq_len]
final_batch = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
# Handle non-sequence fields (raise error)
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
for f in features:
for key in f.keys():
if key not in sequence_fields:
raise NotImplementedError(
f"Non-sequence field {key} not handled yet"
)
# Convert to requested tensor type
if return_tensors is None or return_tensors == "np":
result = {}
for k, v in final_batch.items():
if isinstance(v, torch.Tensor):
result[k] = v.numpy().astype(np.long)
else:
result[k] = v
return result
if return_tensors == "pt":
return final_batch
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
"""
Convert a list of token IDs to a list of tokens.
Args:
ids: The list of token IDs to convert.
Returns:
The list of tokens.
"""
return [
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
]

View File

@@ -1265,6 +1265,68 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_tokenizer_use_mistral_common(cls, data):
if data.get("tokenizer_use_mistral_common") is None:
if any(
"magistral" in name.lower()
for name in [
data.get("base_model", ""),
data.get("base_model_config", ""),
data.get("tokenizer_config", ""),
]
):
LOG.warning(
"tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer."
)
data["tokenizer_use_mistral_common"] = True
return data
@field_validator("tokenizer_use_mistral_common", mode="after")
@classmethod
def check_mistral_common_import(cls, tokenizer_use_mistral_common):
if tokenizer_use_mistral_common:
try:
import mistral_common # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`."
) from exception
return tokenizer_use_mistral_common
@model_validator(mode="before")
@classmethod
def check_mistral_common_incompatible_options(cls, data):
if not data.get("tokenizer_use_mistral_common"):
return data
# NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment
if data.get("added_tokens_overrides"):
raise ValueError(
"added_tokens_overrides is not supported with mistral-common tokenizer"
)
if data.get("special_tokens"):
raise ValueError(
"special_tokens override is not supported with mistral-common tokenizer"
)
if data.get("tokens"):
raise ValueError(
"tokens override is not supported with mistral-common tokenizer"
)
if data.get("chat_template"):
raise ValueError(
"Setting chat_template is not supported with mistral-common tokenizer"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -18,6 +18,7 @@ class ModelInputConfig(BaseModel):
tokenizer_config: str | None = None
tokenizer_use_fast: bool | None = None
tokenizer_legacy: bool | None = None
tokenizer_use_mistral_common: bool | None = None
tokenizer_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
)