Feat: Add Magistral and mistral-common tokenizer support (#2780)
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
|
||||
@@ -27,6 +27,8 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
|
||||
tokenizer_use_mistral_common:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
@@ -1,3 +1,71 @@
|
||||
# Coming Soon!
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Watch this space for configs for fine-tuning [Magistral Small 2506](https://huggingface.co/mistralai/Magistral-Small-2506).
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
||||
```
|
||||
|
||||
2. Download the example config:
|
||||
|
||||
```bash
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
|
||||
63
examples/magistral/magistral-small-qlora.yaml
Normal file
63
examples/magistral/magistral-small-qlora.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: mistralai/Magistral-Small-2506
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
@@ -67,3 +67,5 @@ schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.6.0
|
||||
|
||||
@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||
LOG.info(
|
||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||
)
|
||||
num_proc = 1
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -189,7 +189,7 @@ class KDStrategyLoader(StrategyLoader):
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
|
||||
return ChatTemplateStrategyWithKD
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
|
||||
@@ -121,6 +121,19 @@ def modify_tokenizer_files(
|
||||
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
|
||||
return tokenizer
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return _load_mistral_common_tokenizer(cfg)
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import ProcessorMixin
|
||||
@@ -15,6 +17,9 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
# Configure the logger
|
||||
LOG = get_logger(__name__)
|
||||
LOG.setLevel("INFO")
|
||||
@@ -81,7 +86,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
conversation,
|
||||
conversation: list[dict],
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
@@ -271,9 +276,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
elif (
|
||||
hasattr(self.tokenizer, "eos_token")
|
||||
and self.tokenizer.eos_token is not None
|
||||
):
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
@@ -796,14 +807,104 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class MistralStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Mistral strategy for chat template.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer: "HFMistralTokenizer",
|
||||
train_on_inputs: bool,
|
||||
sequence_len: int,
|
||||
roles_to_train: list[str] | None = None,
|
||||
train_on_eos: str | None = None,
|
||||
train_on_eot: str | None = None,
|
||||
eot_tokens: list[str] | None = None,
|
||||
split_thinking: bool | None = False,
|
||||
):
|
||||
# Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation
|
||||
# pylint: disable=non-parent-init-called,super-init-not-called
|
||||
PromptTokenizingStrategy.__init__(
|
||||
self, prompter, tokenizer, train_on_inputs, sequence_len
|
||||
)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
# map roles if exist in prompter.roles else use the role as is
|
||||
self.roles_to_train = [
|
||||
prompter.roles.get(role, role) for role in roles_to_train
|
||||
]
|
||||
|
||||
self.train_on_eos = train_on_eos
|
||||
# Backward compatibility, load from train_on_eos
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
else:
|
||||
# set eot_tokens to the eos_token
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
# Skip the validation that ChatTemplateStrategy calls
|
||||
# TODO: address this in the future with mistral-specific checks
|
||||
# self._validate_eot_and_eos_tokens()
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self) -> bool:
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
||||
"""
|
||||
|
||||
return False
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# mistral-common tokenizer does not support eot_tokens
|
||||
return self.find_first_eos_token(input_ids, start_idx)
|
||||
|
||||
|
||||
class MistralPrompter(ChatTemplatePrompter):
|
||||
"""
|
||||
Mistral prompter for chat template.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"])
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
Load chat template strategy based on configuration.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
def _get_strategy_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralStrategy
|
||||
|
||||
return ChatTemplateStrategy
|
||||
|
||||
def _get_prompter_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralPrompter
|
||||
|
||||
return ChatTemplatePrompter
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
return {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
@@ -829,9 +930,14 @@ class StrategyLoader:
|
||||
else:
|
||||
dataset_config = ds_cfg
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
# mistral-common does not use this, so we pass an empty string
|
||||
chat_template_string = ""
|
||||
else:
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
@@ -857,10 +963,11 @@ class StrategyLoader:
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
strategy_cls = self._get_strategy_cls(cfg)
|
||||
prompter_cls = self._get_prompter_cls(cfg)
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params),
|
||||
prompter_cls(**prompter_params),
|
||||
tokenizer=tokenizer,
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
@@ -70,6 +70,14 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self):
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
Should return False if the tokenizer has unpicklable objects.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
|
||||
567
src/axolotl/utils/mistral_tokenizer.py
Normal file
567
src/axolotl/utils/mistral_tokenizer.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""Wrapper for MistralTokenizer from mistral-common"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from torch import Tensor
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
|
||||
|
||||
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
||||
"""Get the file path from local or HF Hub"""
|
||||
if os.path.exists(path_or_repo_id):
|
||||
maybe_file_path = os.path.join(path_or_repo_id, filename)
|
||||
if os.path.exists(maybe_file_path):
|
||||
return maybe_file_path
|
||||
|
||||
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
|
||||
|
||||
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
|
||||
|
||||
|
||||
class HFMistralTokenizer:
|
||||
"""
|
||||
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
|
||||
and exposes HuggingFace API for special tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
mistral: The mistral-common tokenizer to wrap.
|
||||
name_or_path: The name or path to the tokenizer files or the repo id.
|
||||
"""
|
||||
self._mistral = mistral
|
||||
self._padding_side = "right"
|
||||
self._name_or_path = name_or_path
|
||||
self._tokenizer_path = tokenizer_path
|
||||
|
||||
# Manual set to training mode
|
||||
from mistral_common.protocol.instruct.validator import (
|
||||
MistralRequestValidator,
|
||||
ValidationMode,
|
||||
)
|
||||
|
||||
# Check if MistralRequestValidator has a _mode attribute.
|
||||
# This is a private API and may change in the future.
|
||||
# pylint: disable=protected-access
|
||||
if not (
|
||||
hasattr(self._mistral, "_chat_completion_request_validator")
|
||||
and isinstance(
|
||||
self._mistral._chat_completion_request_validator,
|
||||
MistralRequestValidator,
|
||||
)
|
||||
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Unable to switch mistral tokenizer to finetuning mode – "
|
||||
"private API `_chat_completion_request_validator._mode` missing."
|
||||
)
|
||||
|
||||
self._mistral._chat_completion_request_validator._mode = (
|
||||
ValidationMode.finetuning
|
||||
)
|
||||
|
||||
def _load_system_prompt(self, path_or_repo_id: str) -> str:
|
||||
"""Load system prompt from local or HF Hub.
|
||||
|
||||
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
|
||||
not provide one.
|
||||
|
||||
Args:
|
||||
path_or_repo_id: The path to the tokenizer files or the repo id.
|
||||
|
||||
Returns:
|
||||
The system prompt.
|
||||
"""
|
||||
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"System prompt file not found at {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.pad_id
|
||||
|
||||
@property
|
||||
def unk_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.unk_id
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
|
||||
|
||||
@property
|
||||
def padding_side(self) -> str:
|
||||
return self._padding_side
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self._name_or_path
|
||||
|
||||
@property
|
||||
def chat_template(self) -> str | None:
|
||||
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.n_words
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
name_or_path: str,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> "HFMistralTokenizer":
|
||||
"""
|
||||
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
|
||||
|
||||
Args:
|
||||
path_or_repo_id: The path to the tokenizer files or the repo id.
|
||||
revision: The revision of the tokenizer to download.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A HFMistralTokenizer instance.
|
||||
"""
|
||||
if revision:
|
||||
raise NotImplementedError(
|
||||
"Revision not supported yet for mistral-common tokenizer"
|
||||
)
|
||||
|
||||
# only support Tekken tokenizer for now
|
||||
# downloads from HF Hub if not local
|
||||
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
|
||||
|
||||
base = MistralTokenizer.from_file(tokenizer_path)
|
||||
|
||||
return cls(
|
||||
base,
|
||||
name_or_path=name_or_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory: str) -> None:
|
||||
"""
|
||||
Save the Tekken/SentencePiece model file so that from_pretrained can pick it up again.
|
||||
|
||||
Only Tekken models are supported.
|
||||
|
||||
Args:
|
||||
save_directory: The directory to save the tokenizer files.
|
||||
"""
|
||||
inner = self._mistral.instruct_tokenizer.tokenizer
|
||||
if isinstance(inner, Tekkenizer):
|
||||
# Create the directory and save the model
|
||||
try:
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not os.path.exists(save_directory):
|
||||
raise RuntimeError(f"Failed to create directory: {save_directory}")
|
||||
|
||||
# Verify source file exists
|
||||
if not os.path.exists(self._tokenizer_path):
|
||||
raise FileNotFoundError(
|
||||
f"Source tokenizer file not found: {self._tokenizer_path}"
|
||||
)
|
||||
|
||||
destination_path = os.path.join(save_directory, "tekken.json")
|
||||
copyfile(self._tokenizer_path, destination_path)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to save tokenizer to {save_directory}: {e}. "
|
||||
f"Source path: {self._tokenizer_path}, "
|
||||
f"Directory exists: {os.path.exists(save_directory)}"
|
||||
) from e
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
|
||||
"""
|
||||
Encode a text string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text: The text string to encode.
|
||||
add_special_tokens: Whether to add special tokens to the encoded tokens.
|
||||
|
||||
Returns:
|
||||
A list of token IDs.
|
||||
"""
|
||||
return self._mistral.instruct_tokenizer.tokenizer.encode(
|
||||
text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens,
|
||||
)
|
||||
|
||||
def decode(
|
||||
self, token_ids: int | list[int], skip_special_tokens: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Decode a list of token IDs into a text string.
|
||||
|
||||
Args:
|
||||
token_ids: The int or list of token IDs to decode.
|
||||
skip_special_tokens: Whether to skip special tokens in the decoded text.
|
||||
|
||||
Returns:
|
||||
The decoded text string.
|
||||
"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
||||
if skip_special_tokens:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
||||
|
||||
# to_string returns a string with special tokens
|
||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
||||
|
||||
def _create_mistral_chat_completion_request(
|
||||
self, conversation: list[dict], tools: list[dict] | None = None
|
||||
) -> "ChatCompletionRequest":
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AssistantMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
|
||||
[]
|
||||
)
|
||||
for turn in conversation:
|
||||
role = turn.get("role")
|
||||
|
||||
if role == "user":
|
||||
messages.append(UserMessage(content=turn["content"]))
|
||||
elif role == "assistant":
|
||||
messages.append(
|
||||
AssistantMessage(
|
||||
content=turn.get("content"),
|
||||
tool_calls=turn.get("tool_calls"),
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
messages.append(
|
||||
ToolMessage(
|
||||
content=turn.get("content"),
|
||||
tool_call_id=turn.get("tool_call_id"),
|
||||
name=turn.get("name"),
|
||||
)
|
||||
)
|
||||
elif role == "system":
|
||||
messages.append(SystemMessage(content=turn["content"]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
|
||||
)
|
||||
|
||||
tool_calls: list[Tool] = []
|
||||
if tools:
|
||||
# convert to Tool
|
||||
for tool in tools:
|
||||
if tool["type"] != "function":
|
||||
continue
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
tool_calls.append(
|
||||
Tool(
|
||||
function=Function(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
# set parameters to empty dict if not provided
|
||||
parameters=function.get("parameters", {}),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
|
||||
messages=messages,
|
||||
tools=tool_calls,
|
||||
)
|
||||
|
||||
return chat_completion
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tokenize: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
chat_template: str | None = None, # pylint: disable=unused-argument
|
||||
add_generation_prompt: bool = False, # pylint: disable=unused-argument
|
||||
) -> list[int] | str:
|
||||
if chat_template:
|
||||
raise NotImplementedError("chat_template not supported yet")
|
||||
|
||||
if add_generation_prompt:
|
||||
raise NotImplementedError("add_generation_prompt not supported yet")
|
||||
|
||||
chat_completion: ChatCompletionRequest = (
|
||||
self._create_mistral_chat_completion_request(messages, tools)
|
||||
)
|
||||
|
||||
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
||||
|
||||
if tokenize:
|
||||
return tokens
|
||||
|
||||
return self.decode(tokens)
|
||||
|
||||
def pad(
|
||||
self,
|
||||
features: list[dict[str, list[int] | np.ndarray]],
|
||||
*,
|
||||
padding: bool | str | PaddingStrategy = True,
|
||||
max_length: int | None = None,
|
||||
pad_to_multiple_of: int | None = None,
|
||||
return_tensors: str | None = None, # "np", "pt", or "tf"
|
||||
) -> dict[str, np.ndarray | Tensor]:
|
||||
"""
|
||||
HF-style pad method that properly handles all sequence-related features:
|
||||
- pad 'input_ids' & 'labels' to the longest (or to max_length)
|
||||
"""
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
# Check for unsupported fields
|
||||
if any("token_type_ids" in f for f in features):
|
||||
raise ValueError("token_type_ids is not supported by this tokenizer")
|
||||
|
||||
# Determine desired sequence length
|
||||
lengths = [len(f["input_ids"]) for f in features]
|
||||
if padding in (True, "longest", PaddingStrategy.LONGEST):
|
||||
target_length = max(lengths)
|
||||
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
|
||||
if max_length is None:
|
||||
raise ValueError("max_length must be set for 'max_length' padding")
|
||||
target_length = max_length
|
||||
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
|
||||
target_length = None
|
||||
else:
|
||||
raise ValueError(f"Unknown padding strategy: {padding}")
|
||||
|
||||
# Apply pad_to_multiple_of
|
||||
if target_length is not None and pad_to_multiple_of is not None:
|
||||
target_length = (
|
||||
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
|
||||
)
|
||||
|
||||
# If no padding requested, just stack tensors
|
||||
do_pad = target_length is not None
|
||||
|
||||
# Pad sequences using torch.nn.utils.rnn.pad_sequence
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=IGNORE_INDEX,
|
||||
)
|
||||
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
|
||||
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
||||
if "position_ids" in features[0]:
|
||||
if self.padding_side == "left":
|
||||
# Likely not needed, but keeping for now
|
||||
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
|
||||
position_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[
|
||||
torch.tensor(x["position_ids"], dtype=torch.long)
|
||||
for x in features
|
||||
],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
else:
|
||||
# For right padding, continue the sequence
|
||||
max_pos_len = max(len(f["position_ids"]) for f in features)
|
||||
position_ids_list = []
|
||||
for f in features:
|
||||
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
|
||||
if len(pos_seq) < max_pos_len:
|
||||
# Continue the sequence
|
||||
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
|
||||
pad_len = max_pos_len - len(pos_seq)
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
pos_seq = torch.cat([pos_seq, pad_positions])
|
||||
position_ids_list.append(pos_seq)
|
||||
position_ids = torch.stack(position_ids_list)
|
||||
else:
|
||||
# Create position_ids if not present
|
||||
seq_len = input_ids.size(1)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, dtype=torch.long)
|
||||
.unsqueeze(0)
|
||||
.expand(input_ids.size(0), -1)
|
||||
)
|
||||
|
||||
# Ensure all tensors have the same sequence length
|
||||
max_seq_len = max(
|
||||
input_ids.size(1),
|
||||
labels.size(1),
|
||||
attention_mask.size(1),
|
||||
position_ids.size(1),
|
||||
)
|
||||
|
||||
# TODO: check if trimming is needed? and correct.
|
||||
|
||||
if do_pad and target_length is not None:
|
||||
max_seq_len = target_length
|
||||
|
||||
# Pad all tensors to the same length
|
||||
if input_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - input_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(0, pad_len),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
else:
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(pad_len, 0),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
elif input_ids.size(1) > max_seq_len:
|
||||
input_ids = input_ids[:, :max_seq_len]
|
||||
|
||||
if labels.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - labels.size(1)
|
||||
if self.padding_side == "right":
|
||||
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
|
||||
else:
|
||||
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
|
||||
elif labels.size(1) > max_seq_len:
|
||||
labels = labels[:, :max_seq_len]
|
||||
|
||||
if attention_mask.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - attention_mask.size(1)
|
||||
if self.padding_side == "right":
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||
else:
|
||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||
elif attention_mask.size(1) > max_seq_len:
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
|
||||
if position_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - position_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
batch_size = position_ids.size(0)
|
||||
new_position_ids = []
|
||||
for i in range(batch_size):
|
||||
seq = position_ids[i]
|
||||
if len(seq) > 0:
|
||||
# get last position and pad with sequential values
|
||||
last_pos = seq[-1].item()
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
new_seq = torch.cat([seq, pad_positions])
|
||||
else:
|
||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||
new_position_ids.append(new_seq)
|
||||
position_ids = torch.stack(new_position_ids)
|
||||
else:
|
||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||
elif position_ids.size(1) > max_seq_len:
|
||||
position_ids = position_ids[:, :max_seq_len]
|
||||
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
# Handle non-sequence fields (raise error)
|
||||
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
||||
for f in features:
|
||||
for key in f.keys():
|
||||
if key not in sequence_fields:
|
||||
raise NotImplementedError(
|
||||
f"Non-sequence field {key} not handled yet"
|
||||
)
|
||||
|
||||
# Convert to requested tensor type
|
||||
if return_tensors is None or return_tensors == "np":
|
||||
result = {}
|
||||
for k, v in final_batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
result[k] = v.numpy().astype(np.long)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
if return_tensors == "pt":
|
||||
return final_batch
|
||||
|
||||
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
|
||||
|
||||
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
|
||||
"""
|
||||
Convert a list of token IDs to a list of tokens.
|
||||
|
||||
Args:
|
||||
ids: The list of token IDs to convert.
|
||||
|
||||
Returns:
|
||||
The list of tokens.
|
||||
"""
|
||||
return [
|
||||
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
||||
]
|
||||
@@ -1265,6 +1265,68 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tokenizer_use_mistral_common(cls, data):
|
||||
if data.get("tokenizer_use_mistral_common") is None:
|
||||
if any(
|
||||
"magistral" in name.lower()
|
||||
for name in [
|
||||
data.get("base_model", ""),
|
||||
data.get("base_model_config", ""),
|
||||
data.get("tokenizer_config", ""),
|
||||
]
|
||||
):
|
||||
LOG.warning(
|
||||
"tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer."
|
||||
)
|
||||
data["tokenizer_use_mistral_common"] = True
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("tokenizer_use_mistral_common", mode="after")
|
||||
@classmethod
|
||||
def check_mistral_common_import(cls, tokenizer_use_mistral_common):
|
||||
if tokenizer_use_mistral_common:
|
||||
try:
|
||||
import mistral_common # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`."
|
||||
) from exception
|
||||
|
||||
return tokenizer_use_mistral_common
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_mistral_common_incompatible_options(cls, data):
|
||||
if not data.get("tokenizer_use_mistral_common"):
|
||||
return data
|
||||
|
||||
# NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment
|
||||
|
||||
if data.get("added_tokens_overrides"):
|
||||
raise ValueError(
|
||||
"added_tokens_overrides is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("special_tokens"):
|
||||
raise ValueError(
|
||||
"special_tokens override is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("tokens"):
|
||||
raise ValueError(
|
||||
"tokens override is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("chat_template"):
|
||||
raise ValueError(
|
||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -18,6 +18,7 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_config: str | None = None
|
||||
tokenizer_use_fast: bool | None = None
|
||||
tokenizer_legacy: bool | None = None
|
||||
tokenizer_use_mistral_common: bool | None = None
|
||||
tokenizer_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
)
|
||||
|
||||
@@ -150,6 +150,14 @@ def fixture_gemma2_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="magistral_tokenizer")
|
||||
def fixture_magistral_tokenizer():
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
||||
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
||||
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
||||
|
||||
290
tests/prompt_strategies/test_chat_templates_mistral.py
Normal file
290
tests/prompt_strategies/test_chat_templates_mistral.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Test chat templates for mistral-common wrapper tokenizer"""
|
||||
|
||||
import unittest
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
|
||||
def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
|
||||
# pylint: disable=duplicate-code
|
||||
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
||||
|
||||
# check bos, eos, pad, unk are accessible properties
|
||||
assert magistral_tokenizer.bos_token_id == 1
|
||||
assert magistral_tokenizer.eos_token_id == 2
|
||||
assert magistral_tokenizer.pad_token_id == 11
|
||||
assert magistral_tokenizer.unk_token_id == 0
|
||||
|
||||
assert magistral_tokenizer.pad_token == "<pad>"
|
||||
assert magistral_tokenizer.eos_token == "</s>"
|
||||
assert magistral_tokenizer.bos_token == "<s>"
|
||||
assert magistral_tokenizer.unk_token == "<unk>"
|
||||
|
||||
strategy = MistralStrategy(
|
||||
MistralPrompter(
|
||||
magistral_tokenizer,
|
||||
chat_template=None,
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=magistral_tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="turn",
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
|
||||
# test chat template masking without system prompt
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
3, # [INST]
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
4, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, # [INST]
|
||||
-100, # Hello
|
||||
-100, # ,
|
||||
-100, # how
|
||||
-100, # are
|
||||
-100, # you
|
||||
-100, # ?
|
||||
-100, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
# test chat template masking with system prompt
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
17, # [SYSTEM_PROMPT]
|
||||
4568, # You
|
||||
1584, # are
|
||||
1261, # a
|
||||
20351, # helpful
|
||||
27089, # assistant
|
||||
1046, # .
|
||||
18, # [/SYSTEM_PROMPT]
|
||||
3, # [INST]
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
4, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, # [SYSTEM_PROMPT]
|
||||
-100, # You
|
||||
-100, # are
|
||||
-100, # a
|
||||
-100, # helpful
|
||||
-100, # assistant
|
||||
-100, # .
|
||||
-100, # [/SYSTEM_PROMPT]
|
||||
-100, # [INST]
|
||||
-100, # Hello
|
||||
-100, # ,
|
||||
-100, # how
|
||||
-100, # are
|
||||
-100, # you
|
||||
-100, # ?
|
||||
-100, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
# test chat template with tools
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "The number to find multiples of.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "The upper limit for the multiples.",
|
||||
},
|
||||
},
|
||||
"required": ["number", "limit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call12345",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"arguments": {
|
||||
"number": 16,
|
||||
"limit": 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call12345",
|
||||
"name": "multiples",
|
||||
"content": "1,2",
|
||||
},
|
||||
{"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
|
||||
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
|
||||
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
|
||||
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
|
||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||
2 # eos
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
|
||||
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
|
||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||
2 # eos
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# test chat template with tokenize=False
|
||||
res = magistral_tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
|
||||
|
||||
# test encode
|
||||
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True)
|
||||
assert res == [
|
||||
1, # bos
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
2, # eos
|
||||
]
|
||||
|
||||
# test decode no skip special tokens
|
||||
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False)
|
||||
|
||||
assert decoded_res == "<s>Hello, how are you?</s>"
|
||||
|
||||
# test decode skip special tokens
|
||||
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True)
|
||||
assert decoded_res == "Hello, how are you?"
|
||||
|
||||
# test encode no special tokens
|
||||
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False)
|
||||
assert res == [
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
]
|
||||
|
||||
# test convert ids to tokens
|
||||
res = magistral_tokenizer.convert_ids_to_tokens(res)
|
||||
# spacing are needed as we are converting without decoding
|
||||
assert res == ["Hello", ",", " how", " are", " you", "?"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user