Feat: Add voxtral, magistral small 1.1, and misc gemma3n fixes (#2979)

* fix: lock version in gemma3n docs

* feat: add sample configs and docs

* chore: move mistraltokenizer into mistral folder

* feat: update instructions

* feat: add dynamic load voxtral

* fix: remove incorrect vision config, add audio

* fix: support voxtral processing strategy and address none in data

* feat: patch mistraltokenizer subclass upstream and add missing

* feat: update cce commit to include voxtral

* fix: remove old comment

* fix: gemma3 patch not needed anymore

* fix: voxtral modeling code

* fix: remove incorrect ds path

* fix: adjust apply chat template parsing

* feat: enable voxtral patch

* fix: patch

* feat: update example datasets

* fix: target layer

* feat: update gemma3n docs

* feat: update voxtral docs

* feat: revert assistant parsing to rely on new upstream changes

* chore: skip test till next PR fix

* fix: override upstream decode due to missing handling

* feat: update readme

* fix: update

* feat: add magistral small think support

* feat: update mistral-common dep

* fix: lint

* fix: remove optional dep

* chore: typing

* chore: simply import

* feat(doc): update differences for 2507

* fix: coderrabbit comments

* feat: update clarify docs on new transformers
This commit is contained in:
NanoCode012
2025-07-30 15:57:05 +07:00
committed by GitHub
parent 1d2aa1e467
commit 90e5598930
29 changed files with 771 additions and 695 deletions

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@631d646"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"
```
## Usage

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@631d646"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"`'
)

View File

@@ -21,3 +21,11 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
}
try:
from transformers import VoxtralForConditionalGeneration
# transformers >4.53.2
MULTIMODAL_AUTO_MODEL_MAPPING["voxtral"] = VoxtralForConditionalGeneration
except ImportError:
pass

View File

@@ -64,12 +64,12 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -253,15 +253,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_gemma3_conditional_generation_forward_patch(self):
"""Apply gemma3 conditional generation forward patch."""
if self.model_config.model_type in ["gemma3", "gemma3_text"]:
from axolotl.monkeypatch.models.gemma3.modeling import (
patch_gemma3_conditional_generation_forward,
)
patch_gemma3_conditional_generation_forward()
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
@@ -285,6 +276,15 @@ class PatchManager:
cfg_num_shards=self.cfg.tiled_mlp_num_shards,
)
def _apply_voxtral_patches(self):
"""Apply patches for Voxtral model."""
if self.cfg.model_config_type == "voxtral":
from axolotl.monkeypatch.models.voxtral.modeling import (
patch_voxtral_conditional_generation_forward,
)
patch_voxtral_conditional_generation_forward()
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -124,7 +124,12 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
from transformers import tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
# patch
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)

View File

@@ -1,16 +0,0 @@
"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""
def patch_gemma3_conditional_generation_forward():
# Remove when https://github.com/huggingface/transformers/pull/37208 merged
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3ForConditionalGeneration,
)
setattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs", False)
def unpatch():
delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")
return unpatch

View File

@@ -0,0 +1,67 @@
"""Monkeypatch for voxtral to fix leaf node and dtype mismatch"""
from typing import Optional, Union
import torch
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def patch_voxtral_conditional_generation_forward():
from transformers.models.voxtral.modeling_voxtral import (
VoxtralForConditionalGeneration,
)
# Store the original forward method
old_forward = VoxtralForConditionalGeneration.forward
def _forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if input_features is not None:
audio_embeds = self.get_audio_embeds(input_features)
# Cast audio_embeds to match inputs_embeds dtype
audio_embeds = audio_embeds.to(inputs_embeds.dtype)
# replace text-audio token placeholders with audio embeddings
audio_token_mask = input_ids == self.config.audio_token_id
inputs_embeds = inputs_embeds.clone()
inputs_embeds[audio_token_mask] = audio_embeds
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
return outputs
# Apply the patch
VoxtralForConditionalGeneration.forward = _forward
def unpatch():
"""Restore the original forward method"""
VoxtralForConditionalGeneration.forward = old_forward
return unpatch

View File

@@ -6,9 +6,10 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin
from transformers import ProcessorMixin, VoxtralProcessor
from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -204,7 +205,7 @@ class ProcessingStrategy:
}
)
processed_examples.append(processed_example)
processed_examples.append(remove_none_values(processed_example))
return processed_examples
@@ -366,6 +367,34 @@ class Gemma3nProcessingStrategy(ProcessingStrategy):
return labels
class VoxtralProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Voxtral"""
def __init__(
self,
processor: VoxtralProcessor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
special_ids = (
processor.tokenizer.tokenizer.instruct_tokenizer.audio_encoder.special_ids
)
self.audio_token = special_ids.audio
self.begin_audio_token = special_ids.begin_audio
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.processor.tokenizer.pad_token_id] = -100
labels[labels == self.audio_token] = -100
labels[labels == self.begin_audio_token] = -100
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -395,4 +424,10 @@ def get_processing_strategy(
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if isinstance(processor, VoxtralProcessor):
return VoxtralProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")

View File

@@ -14,11 +14,12 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
from axolotl.utils.mistral import HFMistralTokenizer
# Configure the logger
LOG = get_logger(__name__)
@@ -379,21 +380,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Public method that can handle either a single prompt or a batch of prompts.
"""
def _remove_none_values(obj):
"""
Remove null from a dictionary-like obj or list.
These can appear due to Dataset loading causing schema merge.
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
"""
if hasattr(obj, "items"):
return {
k: _remove_none_values(v) for k, v in obj.items() if v is not None
}
if isinstance(obj, list):
return [_remove_none_values(elem) for elem in obj]
return obj
prompt = _remove_none_values(prompt)
prompt = remove_none_values(prompt)
if not self.is_prompt_batched(prompt) or not self.supports_batched:
return self._tokenize_single_prompt(prompt)
@@ -502,6 +489,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
# Block multi-content for now
if not isinstance(content, str):
raise ValueError(
"`train_detail` is not supported when `content` is not a string."
)
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
content, train_detail
)

View File

@@ -36,3 +36,16 @@ class DictDefault(Dict):
p[key] = self
object.__delattr__(self, "__parent")
object.__delattr__(self, "__key")
def remove_none_values(obj):
"""
Remove null from a dictionary-like obj or list.
These can appear due to Dataset loading causing schema merge.
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
"""
if hasattr(obj, "items"):
return {k: remove_none_values(v) for k, v in obj.items() if v is not None}
if isinstance(obj, list):
return [remove_none_values(elem) for elem in obj]
return obj

View File

@@ -0,0 +1,5 @@
"""Init for `axolotl.utils.mistral` module."""
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
__all__ = ["HFMistralTokenizer"]

View File

@@ -0,0 +1,220 @@
"""Wrapper for MistralTokenizer from mistral-common"""
import os
from typing import Optional
import numpy as np
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
from torch import Tensor
from transformers.tokenization_mistral_common import MistralCommonTokenizer
from transformers.tokenization_utils_base import VERY_LARGE_INTEGER
class HFMistralTokenizer(MistralCommonTokenizer):
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
"""
def __init__(self, name_or_path: str, **kwargs):
"""
Args:
name_or_path: The name or path to the tokenizer files or the repo id.
**kwargs: Additional keyword arguments passed to the parent class.
"""
kwargs.pop("mode", None)
mode = ValidationMode.finetuning
super().__init__(**kwargs, mode=mode)
self._name_or_path = name_or_path
# set mode as is not set upstream
self._set_mode(mode)
@property
def name_or_path(self) -> str:
return self._name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return "[This is a dummy chat template]"
def _set_mode(self, mode: ValidationMode):
"""Set the mode of the MistralRequestValidator.
Args:
mode: The mode to set.
Raises:
RuntimeError: If the MistralRequestValidator does not have a _mode attribute.
"""
# Check if MistralRequestValidator has a _mode attribute.
# This is a private API and may change in the future.
# pylint: disable=protected-access
from mistral_common.protocol.instruct.validator import MistralRequestValidator
if not (
hasattr(self.tokenizer, "_chat_completion_request_validator")
and isinstance(
self.tokenizer._chat_completion_request_validator,
MistralRequestValidator,
)
and hasattr(self.tokenizer._chat_completion_request_validator, "_mode")
):
raise RuntimeError(
f"Unable to switch mistral tokenizer to {mode.value} mode - "
"private API `_chat_completion_request_validator._mode` missing."
)
self.tokenizer._chat_completion_request_validator._mode = mode
def apply_chat_template( # type: ignore
self,
conversation: list[dict] | list[list[dict]],
chat_template: str | None = None, # pylint: disable=unused-argument
add_generation_prompt: bool = False,
**kwargs,
) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
try:
if add_generation_prompt:
self._set_mode(ValidationMode.serving)
kwargs["continue_final_message"] = True
out = super().apply_chat_template(conversation, **kwargs)
return out # type: ignore
finally:
if add_generation_prompt:
self._set_mode(ValidationMode.finetuning)
def decode( # type: ignore
self,
token_ids: int | list[int] | np.ndarray | Tensor,
**kwargs,
) -> str:
"""
Decode token_ids into str.
This overrides upstream.decode to convert int to list[int]
"""
if isinstance(token_ids, int):
token_ids = [token_ids]
return super().decode(token_ids, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike,
*init_inputs,
mode: ValidationMode = ValidationMode.test,
cache_dir: Optional[str | os.PathLike] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[str | bool] = None,
revision: str = "main",
model_max_length: int = VERY_LARGE_INTEGER,
padding_side: str = "left",
truncation_side: str = "right",
model_input_names: Optional[list[str]] = None,
clean_up_tokenization_spaces: bool = False,
**kwargs,
):
r"""
Patched fn to pass `name_or_path` and remove extra kwargs.
Instantiate a `MistralCommonTokenizer` from a predefined
tokenizer.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
- A path to a *directory* containing the tokenizer config, for instance saved
using the [`MistralCommonTokenizer.tokenization_mistral_common.save_pretrained`] method, e.g.,
`./my_model_directory/`.
mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
Validation mode for the `MistralTokenizer` tokenizer.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download the vocabulary files and override the cached versions if they
exist.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
max_length (`int`, *optional*):
Controls the maximum length to use by one of the truncation/padding parameters.
If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
is required by one of the truncation/padding parameters. If the model has no specific maximum input
length (like XLNet) truncation/padding to a maximum length will be deactivated.
padding_side (`str`, *optional*, defaults to `"left"`):
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
truncation_side (`str`, *optional*, defaults to `"right"`):
The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
model_input_names (`List[string]`, *optional*):
The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
`"attention_mask"`). Default value is picked from the class attribute of the same name.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process.
kwargs (additional keyword arguments, *optional*):
Not supported by `MistralCommonTokenizer.from_pretrained`.
Will raise an error if used.
"""
if init_inputs:
raise ValueError(
"`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`."
)
# Delete trust_remote_code as it does nothing
kwargs.pop("trust_remote_code", None)
# Delete tokenizer as it does nothing
kwargs.pop("tokenizer", None)
# Handle kwargs and AutoTokenizer case
if kwargs and not kwargs.keys() == {"_from_auto"}:
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
)
if not os.path.isfile(pretrained_model_name_or_path):
tokenizer_path = download_tokenizer_from_hf_hub(
repo_id=str(pretrained_model_name_or_path),
cache_dir=str(cache_dir),
token=token,
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
)
else:
tokenizer_path = str(pretrained_model_name_or_path)
return cls(
name_or_path=str(pretrained_model_name_or_path),
tokenizer_path=tokenizer_path,
mode=mode,
model_max_length=model_max_length,
padding_side=padding_side,
truncation_side=truncation_side,
model_input_names=model_input_names,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)

View File

@@ -1,627 +0,0 @@
"""Wrapper for MistralTokenizer from mistral-common"""
import math
import os
from shutil import copyfile
from typing import Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.core import IGNORE_INDEX
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
"""Get the file path from local or HF Hub"""
if os.path.exists(path_or_repo_id):
maybe_file_path = os.path.join(path_or_repo_id, filename)
if os.path.exists(maybe_file_path):
return maybe_file_path
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
class HFMistralTokenizer:
"""
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
and exposes HuggingFace API for special tokens.
"""
def __init__(
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
):
"""
Args:
mistral: The mistral-common tokenizer to wrap.
name_or_path: The name or path to the tokenizer files or the repo id.
"""
self._mistral = mistral
self._padding_side = "right"
self._name_or_path = name_or_path
self._tokenizer_path = tokenizer_path
# Manual set to training mode
from mistral_common.protocol.instruct.validator import (
MistralRequestValidator,
ValidationMode,
)
# Check if MistralRequestValidator has a _mode attribute.
# This is a private API and may change in the future.
# pylint: disable=protected-access
if not (
hasattr(self._mistral, "_chat_completion_request_validator")
and isinstance(
self._mistral._chat_completion_request_validator,
MistralRequestValidator,
)
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
):
raise RuntimeError(
"Unable to switch mistral tokenizer to finetuning mode "
"private API `_chat_completion_request_validator._mode` missing."
)
self._mistral._chat_completion_request_validator._mode = (
ValidationMode.finetuning
)
def _load_system_prompt(self, path_or_repo_id: str) -> str:
"""Load system prompt from local or HF Hub.
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
not provide one.
Args:
path_or_repo_id: The path to the tokenizer files or the repo id.
Returns:
The system prompt.
"""
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
if not os.path.exists(file_path):
raise FileNotFoundError(f"System prompt file not found at {file_path}")
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
@property
def bos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.bos_id
@property
def eos_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.eos_id
@property
def pad_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.pad_id
@property
def unk_token_id(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.unk_id
@property
def bos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
@property
def eos_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
@property
def pad_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
@property
def unk_token(self) -> str:
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
@property
def padding_side(self) -> str:
return self._padding_side
@property
def name_or_path(self) -> str:
return self._name_or_path
@property
def chat_template(self) -> str | None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
def __len__(self) -> int:
return self._mistral.instruct_tokenizer.tokenizer.n_words
@classmethod
def from_pretrained(
cls,
name_or_path: str,
*,
revision: Optional[str] = None,
**kwargs, # pylint: disable=unused-argument
) -> "HFMistralTokenizer":
"""
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
Args:
path_or_repo_id: The path to the tokenizer files or the repo id.
revision: The revision of the tokenizer to download.
kwargs: Additional keyword arguments.
Returns:
A HFMistralTokenizer instance.
"""
if revision:
raise NotImplementedError(
"Revision not supported yet for mistral-common tokenizer"
)
# only support Tekken tokenizer for now
# downloads from HF Hub if not local
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
base = MistralTokenizer.from_file(tokenizer_path)
return cls(
base,
name_or_path=name_or_path,
tokenizer_path=tokenizer_path,
)
def save_pretrained(self, save_directory: str) -> None:
"""
Save the Tekken/SentencePiece model file so that from_pretrained can pick it up again.
Only Tekken models are supported.
Args:
save_directory: The directory to save the tokenizer files.
"""
inner = self._mistral.instruct_tokenizer.tokenizer
if isinstance(inner, Tekkenizer):
# Create the directory and save the model
try:
os.makedirs(save_directory, exist_ok=True)
# Verify directory was created
if not os.path.exists(save_directory):
raise RuntimeError(f"Failed to create directory: {save_directory}")
# Verify source file exists
if not os.path.exists(self._tokenizer_path):
raise FileNotFoundError(
f"Source tokenizer file not found: {self._tokenizer_path}"
)
destination_path = os.path.join(save_directory, "tekken.json")
copyfile(self._tokenizer_path, destination_path)
except Exception as e:
raise RuntimeError(
f"Failed to save tokenizer to {save_directory}: {e}. "
f"Source path: {self._tokenizer_path}, "
f"Directory exists: {os.path.exists(save_directory)}"
) from e
else:
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
"""
Encode a text string into a list of token IDs.
Args:
text: The text string to encode.
add_special_tokens: Whether to add special tokens to the encoded tokens.
Returns:
A list of token IDs.
"""
return self._mistral.instruct_tokenizer.tokenizer.encode(
text,
bos=add_special_tokens,
eos=add_special_tokens,
)
def decode(
self, token_ids: int | list[int], skip_special_tokens: bool = False
) -> str:
"""
Decode a list of token IDs into a text string.
Args:
token_ids: The int or list of token IDs to decode.
skip_special_tokens: Whether to skip special tokens in the decoded text.
Returns:
The decoded text string.
"""
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def apply_chat_template(
self,
messages: list[dict],
tokenize: bool = True,
tools: list[dict] | None = None,
chat_template: str | None = None, # pylint: disable=unused-argument
add_generation_prompt: bool = False, # pylint: disable=unused-argument
) -> list[int] | str:
if chat_template:
raise NotImplementedError("chat_template not supported yet")
if add_generation_prompt:
raise NotImplementedError("add_generation_prompt not supported yet")
chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai(
messages, tools
)
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
if tokenize:
return tokens
return self.decode(tokens)
def pad(
self,
features: list[dict[str, list[int] | np.ndarray]],
*,
padding: bool | str | PaddingStrategy = True,
max_length: int | None = None,
pad_to_multiple_of: int | None = None,
return_tensors: str | None = None, # "np", "pt", or "tf"
) -> dict[str, np.ndarray | Tensor]:
"""
HF-style pad method that properly handles all sequence-related features:
- pad 'input_ids' & 'labels' to the longest (or to max_length)
"""
import torch
from torch.nn import functional as F
# Check for unsupported fields
if any("token_type_ids" in f for f in features):
raise ValueError("token_type_ids is not supported by this tokenizer")
# Determine desired sequence length
lengths = [len(f["input_ids"]) for f in features]
if padding in (True, "longest", PaddingStrategy.LONGEST):
target_length = max(lengths)
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
if max_length is None:
raise ValueError("max_length must be set for 'max_length' padding")
target_length = max_length
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
target_length = None
else:
raise ValueError(f"Unknown padding strategy: {padding}")
# Apply pad_to_multiple_of
if target_length is not None and pad_to_multiple_of is not None:
target_length = (
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
)
# If no padding requested, just stack tensors
do_pad = target_length is not None
# Pad sequences using torch.nn.utils.rnn.pad_sequence
input_ids = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
)
labels = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=IGNORE_INDEX,
)
attention_mask = None
if "attention_mask" in features[0]:
attention_mask = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
batch_first=True,
padding_value=0,
)
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
position_ids = None
if "position_ids" in features[0]:
if self.padding_side == "left":
# Likely not needed, but keeping for now
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
position_ids = torch.nn.utils.rnn.pad_sequence(
[
torch.tensor(x["position_ids"], dtype=torch.long)
for x in features
],
batch_first=True,
padding_value=0,
)
else:
# For right padding, continue the sequence
max_pos_len = max(len(f["position_ids"]) for f in features)
position_ids_list = []
for f in features:
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
if len(pos_seq) < max_pos_len:
# Continue the sequence
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
pad_len = max_pos_len - len(pos_seq)
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
pos_seq = torch.cat([pos_seq, pad_positions])
position_ids_list.append(pos_seq)
position_ids = torch.stack(position_ids_list)
# Ensure all tensors have the same sequence length
# Check attention mask and position ids if they are present
tensor_lengths = [input_ids.size(1), labels.size(1)]
if attention_mask is not None:
tensor_lengths.append(attention_mask.size(1))
if position_ids is not None:
tensor_lengths.append(position_ids.size(1))
max_seq_len = max(tensor_lengths)
# TODO: check if trimming is needed? and correct.
if do_pad and target_length is not None:
max_seq_len = target_length
# Pad all tensors to the same length
if input_ids.size(1) < max_seq_len:
pad_len = max_seq_len - input_ids.size(1)
if self.padding_side == "right":
input_ids = F.pad(
input_ids,
(0, pad_len),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
else:
input_ids = F.pad(
input_ids,
(pad_len, 0),
value=self.pad_token_id if self.pad_token_id is not None else 0,
)
elif input_ids.size(1) > max_seq_len:
input_ids = input_ids[:, :max_seq_len]
if labels.size(1) < max_seq_len:
pad_len = max_seq_len - labels.size(1)
if self.padding_side == "right":
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
else:
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
elif labels.size(1) > max_seq_len:
labels = labels[:, :max_seq_len]
if attention_mask is not None:
if attention_mask.size(1) < max_seq_len:
pad_len = max_seq_len - attention_mask.size(1)
if self.padding_side == "right":
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
else:
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
elif attention_mask.size(1) > max_seq_len:
attention_mask = attention_mask[:, :max_seq_len]
if position_ids is not None:
if position_ids.size(1) < max_seq_len:
pad_len = max_seq_len - position_ids.size(1)
if self.padding_side == "right":
batch_size = position_ids.size(0)
new_position_ids = []
for i in range(batch_size):
seq = position_ids[i]
if len(seq) > 0:
# get last position and pad with sequential values
last_pos = seq[-1].item()
pad_positions = torch.arange(
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
)
new_seq = torch.cat([seq, pad_positions])
else:
new_seq = torch.arange(pad_len, dtype=torch.long)
new_position_ids.append(new_seq)
position_ids = torch.stack(new_position_ids)
else:
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
elif position_ids.size(1) > max_seq_len:
position_ids = position_ids[:, :max_seq_len]
final_batch = {
"input_ids": input_ids,
"labels": labels,
}
if attention_mask is not None:
final_batch["attention_mask"] = attention_mask
if position_ids is not None:
final_batch["position_ids"] = position_ids
# Handle non-sequence fields (raise error)
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
for f in features:
for key in f.keys():
if key not in sequence_fields:
raise NotImplementedError(
f"Non-sequence field {key} not handled yet"
)
# Convert to requested tensor type
if return_tensors is None or return_tensors == "np":
result = {}
for k, v in final_batch.items():
if isinstance(v, torch.Tensor):
result[k] = v.numpy().astype(np.int64)
else:
result[k] = v
return result
if return_tensors == "pt":
return final_batch
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
"""
Convert a list of token IDs to a list of tokens.
Args:
ids: The list of token IDs to convert.
Returns:
The list of tokens.
"""
return [
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
]
def __call__(
self,
text: str | list[str],
add_special_tokens: bool = True,
padding: bool | str = False,
truncation: bool = False,
max_length: int | None = None,
return_tensors: str | None = None,
**kwargs,
) -> dict[str, list[int] | np.ndarray | Tensor]:
"""
Tokenize text and return a dictionary with input_ids and attention_mask.
Args:
text: Input text string or list of strings to tokenize.
add_special_tokens: Whether to add special tokens (BOS/EOS).
padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
truncation: Whether to truncate sequences to max_length.
max_length: Maximum sequence length for truncation/padding.
return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
Returns:
Dictionary with "input_ids" and "attention_mask" keys.
"""
# if kwargs passed, raise error
if kwargs:
raise ValueError(
f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
)
# `np` can work with inhomogeneous shapes but let's not support it until needed.
if (
isinstance(text, list)
and len(text) > 1
and return_tensors in ("pt", "np")
and padding is False
and truncation is False
):
raise ValueError(
"return_tensors='pt' or 'np' requires padding or truncation."
)
# Handle single string input
if isinstance(text, str):
text = [text]
# Encode all texts
# TODO: figure out how to parallelize this
batch_input_ids = []
for single_text in text:
input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
# Handle truncation
if truncation and max_length is not None and len(input_ids) > max_length:
input_ids = input_ids[:max_length]
batch_input_ids.append(input_ids)
# Create attention masks (1 for real tokens, 0 for padding)
attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
# Handle padding
if padding in (True, "longest"):
# Pad to longest sequence in batch
max_len = max(len(input_ids) for input_ids in batch_input_ids)
for i, input_ids in enumerate(batch_input_ids):
pad_length = max_len - len(input_ids)
if pad_length > 0:
if self.padding_side == "right":
batch_input_ids[i] = (
input_ids + [self.pad_token_id] * pad_length
)
attention_masks[i] = attention_masks[i] + [0] * pad_length
else: # left padding
batch_input_ids[i] = [
self.pad_token_id
] * pad_length + input_ids
attention_masks[i] = [0] * pad_length + attention_masks[i]
elif padding == "max_length":
if max_length is None:
raise ValueError(
"max_length must be specified when padding='max_length'"
)
for i, input_ids in enumerate(batch_input_ids):
pad_length = max_length - len(input_ids)
if pad_length > 0:
if self.padding_side == "right":
batch_input_ids[i] = (
input_ids + [self.pad_token_id] * pad_length
)
attention_masks[i] = attention_masks[i] + [0] * pad_length
else: # left padding
batch_input_ids[i] = [
self.pad_token_id
] * pad_length + input_ids
attention_masks[i] = [0] * pad_length + attention_masks[i]
# Prepare result
result = {}
# Handle return tensor format
if return_tensors == "pt":
import torch
result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
elif return_tensors == "np":
result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
elif return_tensors is None:
result["input_ids"] = batch_input_ids
result["attention_mask"] = attention_masks
else:
raise ValueError(
f"Unsupported return_tensors='{return_tensors}'. "
"Only 'pt' and 'np' are supported."
)
# If single input, return single sequences (not batched)
if len(text) == 1 and return_tensors is None:
result["input_ids"] = result["input_ids"][0]
result["attention_mask"] = result["attention_mask"][0]
return result