diff --git a/README.md b/README.md index f7765c475..ef5523898 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## 🎉 Latest Updates +- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! - 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. diff --git a/docs/config.qmd b/docs/config.qmd index 2ca236708..d146b4c84 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -27,6 +27,8 @@ trust_remote_code: tokenizer_use_fast: # Whether to use the legacy tokenizer setting, defaults to True tokenizer_legacy: +# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer. +tokenizer_use_mistral_common: # Resize the model embeddings when new tokens are added to multiples of 32 # This is reported to improve training speed on some models resize_token_embeddings_to_32x: diff --git a/examples/magistral/README.md b/examples/magistral/README.md index 172d9ac93..a2b09ab70 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -1,3 +1,71 @@ -# Coming Soon! +# Finetune Magistral Small with Axolotl -Watch this space for configs for fine-tuning [Magistral Small 2506](https://huggingface.co/mistralai/Magistral-Small-2506). 
+Magistral Small is a 24B-parameter open-source model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking. + +MistralAI has also released a proprietary medium-sized version called Magistral Medium. + +Thanks to the team at MistralAI for giving us early access to prepare for this release. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Magistral support is currently only available on nightly, so either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main with pip: + +```bash +# Ensure you have PyTorch installed (PyTorch 2.6.0 recommended) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn,mistral]' +``` + +2. Download the example config: + +```bash +axolotl fetch examples +``` + +3. Run the finetuning example: + +```bash +axolotl train examples/magistral/magistral-small-qlora.yaml +``` + +This config uses about 24GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). 
+ +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, Multi-modal, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml new file mode 100644 index 000000000..e3e746f22 --- /dev/null +++ b/examples/magistral/magistral-small-qlora.yaml @@ -0,0 +1,63 @@ +base_model: mistralai/Magistral-Small-2506 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: 
+wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 diff --git a/requirements.txt b/requirements.txt index 3af94421d..cf8caba00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -67,3 +67,5 @@ schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.3 + +mistral-common==1.6.0 diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 7c112c59e..28182b16f 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset): features = dataset.features.keys() num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) + # Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common) + if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True): + LOG.info( + "Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)" + ) + num_proc = 1 + map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index eb067cd04..7c99a9c3d 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -189,7 +189,7 @@ class KDStrategyLoader(StrategyLoader): Load ChatTemplateStrategy with KD support using StrategyLoader. 
""" - def _get_strategy_cls(self): + def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument return ChatTemplateStrategyWithKD def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 5a174186d..4f9a60a69 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -121,6 +121,19 @@ def modify_tokenizer_files( def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" + + def _load_mistral_common_tokenizer(cfg: DictDefault): + """Load mistral-common tokenizer""" + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer + + # Load the HF-compatible wrapper around MistralTokenizer + tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config) + + return tokenizer + + if cfg.tokenizer_use_mistral_common: + return _load_mistral_common_tokenizer(cfg) + model_config = load_model_config(cfg) tokenizer_kwargs = {} use_fast = True # this is the default diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 1fee0f7f6..4a358928e 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -2,8 +2,10 @@ HF Chat Templates prompt strategy """ +# pylint: disable=too-many-lines + from collections import defaultdict -from typing import Any, Dict, List, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Set, Union from pydantic import BaseModel from transformers import ProcessorMixin @@ -15,6 +17,9 @@ from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import DatasetConfig +if TYPE_CHECKING: + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer + # Configure the logger LOG = get_logger(__name__) LOG.setLevel("INFO") @@ -81,7 +86,7 @@ 
class ChatTemplatePrompter(Prompter): def build_prompt( self, - conversation, + conversation: list[dict], add_generation_prompt=False, images=None, tools=None, @@ -271,9 +276,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos # Default to eos_token if eot_tokens not provided - self.eot_tokens = ( - eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token] - ) + self.eot_tokens = [] + if eot_tokens is not None: + self.eot_tokens = eot_tokens + elif ( + hasattr(self.tokenizer, "eos_token") + and self.tokenizer.eos_token is not None + ): + self.eot_tokens = [self.tokenizer.eos_token] + self.split_thinking = split_thinking self.images = "images" @@ -796,14 +807,104 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): ) +class MistralStrategy(ChatTemplateStrategy): + """ + Mistral strategy for chat template. + """ + + def __init__( + self, + prompter: "ChatTemplatePrompter", + tokenizer: "HFMistralTokenizer", + train_on_inputs: bool, + sequence_len: int, + roles_to_train: list[str] | None = None, + train_on_eos: str | None = None, + train_on_eot: str | None = None, + eot_tokens: list[str] | None = None, + split_thinking: bool | None = False, + ): + # Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation + # pylint: disable=non-parent-init-called,super-init-not-called + PromptTokenizingStrategy.__init__( + self, prompter, tokenizer, train_on_inputs, sequence_len + ) + self.prompter: ChatTemplatePrompter = prompter + + self.roles_to_train = [] + if roles_to_train: + # map roles if exist in prompter.roles else use the role as is + self.roles_to_train = [ + prompter.roles.get(role, role) for role in roles_to_train + ] + + self.train_on_eos = train_on_eos + # Backward compatibility, load from train_on_eos + self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos + + # Default to eos_token if eot_tokens 
not provided + self.eot_tokens = [] + if eot_tokens is not None: + self.eot_tokens = eot_tokens + else: + # set eot_tokens to the eos_token + self.eot_tokens = [self.tokenizer.eos_token] + + self.split_thinking = split_thinking + + self.images = "images" + + LOG.debug( + f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}" + ) + + # Skip the validation that ChatTemplateStrategy calls + # TODO: address this in the future with mistral-specific checks + # self._validate_eot_and_eos_tokens() + + @property + def supports_multiprocessing(self) -> bool: + """ + Whether this tokenizing strategy supports multiprocessing. + mistral_common tokenizers cannot be pickled for multiprocessing. + """ + + return False + + def find_first_eot_token(self, input_ids, start_idx): + """Find the first EOT token in the input_ids starting from start_idx.""" + # mistral-common tokenizer does not support eot_tokens + return self.find_first_eos_token(input_ids, start_idx) + + +class MistralPrompter(ChatTemplatePrompter): + """ + Mistral prompter for chat template. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"]) + + class StrategyLoader: """ Load chat template strategy based on configuration. 
""" - def _get_strategy_cls(self): + def _get_strategy_cls(self, cfg): + if cfg.tokenizer_use_mistral_common: + return MistralStrategy + return ChatTemplateStrategy + def _get_prompter_cls(self, cfg): + if cfg.tokenizer_use_mistral_common: + return MistralPrompter + + return ChatTemplatePrompter + def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): return { "train_on_inputs": cfg.train_on_inputs, @@ -829,9 +930,14 @@ class StrategyLoader: else: dataset_config = ds_cfg - chat_template_string = get_chat_template_from_config( - cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer - ) + if cfg.tokenizer_use_mistral_common: + # mistral-common does not use this, so we pass an empty string + chat_template_string = "" + else: + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { @@ -857,10 +963,11 @@ class StrategyLoader: } strategy_params = self._get_strategy_params(cfg, dataset_config) - strategy_cls = self._get_strategy_cls() + strategy_cls = self._get_strategy_cls(cfg) + prompter_cls = self._get_prompter_cls(cfg) strategy = strategy_cls( - ChatTemplatePrompter(**prompter_params), + prompter_cls(**prompter_params), tokenizer=tokenizer, **strategy_params, ) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 9ca645de3..aae778ae8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -70,6 +70,14 @@ class PromptTokenizingStrategy(abc.ABC): def supports_batched(self): return False + @property + def supports_multiprocessing(self): + """ + Whether this tokenizing strategy supports multiprocessing. + Should return False if the tokenizer has unpicklable objects. 
+ """ + return True + def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py new file mode 100644 index 000000000..3ccf39bb0 --- /dev/null +++ b/src/axolotl/utils/mistral_tokenizer.py @@ -0,0 +1,567 @@ +"""Wrapper for MistralTokenizer from mistral-common""" + +import math +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Optional + +import numpy as np +from huggingface_hub import hf_hub_download +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.tekken import Tekkenizer +from torch import Tensor +from transformers.utils import PaddingStrategy + +from axolotl.utils.collators.core import IGNORE_INDEX + +if TYPE_CHECKING: + from mistral_common.protocol.instruct.request import ChatCompletionRequest + + +def _get_file_path(path_or_repo_id: str, filename: str) -> str: + """Get the file path from local or HF Hub""" + if os.path.exists(path_or_repo_id): + maybe_file_path = os.path.join(path_or_repo_id, filename) + if os.path.exists(maybe_file_path): + return maybe_file_path + + raise FileNotFoundError(f"File not found at {path_or_repo_id}") + + return hf_hub_download(repo_id=path_or_repo_id, filename=filename) + + +class HFMistralTokenizer: + """ + Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer + and exposes HuggingFace API for special tokens. + """ + + def __init__( + self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str + ): + """ + Args: + mistral: The mistral-common tokenizer to wrap. + name_or_path: The name or path to the tokenizer files or the repo id. 
+ """ + self._mistral = mistral + self._padding_side = "right" + self._name_or_path = name_or_path + self._tokenizer_path = tokenizer_path + + # Manual set to training mode + from mistral_common.protocol.instruct.validator import ( + MistralRequestValidator, + ValidationMode, + ) + + # Check if MistralRequestValidator has a _mode attribute. + # This is a private API and may change in the future. + # pylint: disable=protected-access + if not ( + hasattr(self._mistral, "_chat_completion_request_validator") + and isinstance( + self._mistral._chat_completion_request_validator, + MistralRequestValidator, + ) + and hasattr(self._mistral._chat_completion_request_validator, "_mode") + ): + raise RuntimeError( + "Unable to switch mistral tokenizer to finetuning mode – " + "private API `_chat_completion_request_validator._mode` missing." + ) + + self._mistral._chat_completion_request_validator._mode = ( + ValidationMode.finetuning + ) + + def _load_system_prompt(self, path_or_repo_id: str) -> str: + """Load system prompt from local or HF Hub. + + Note: Unused for now as we don't want to explicitly set the system prompt if a user does + not provide one. + + Args: + path_or_repo_id: The path to the tokenizer files or the repo id. + + Returns: + The system prompt. 
+ """ + file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt") + + if not os.path.exists(file_path): + raise FileNotFoundError(f"System prompt file not found at {file_path}") + + with open(file_path, "r", encoding="utf-8") as file: + return file.read() + + @property + def bos_token_id(self) -> int: + return self._mistral.instruct_tokenizer.tokenizer.bos_id + + @property + def eos_token_id(self) -> int: + return self._mistral.instruct_tokenizer.tokenizer.eos_id + + @property + def pad_token_id(self) -> int: + return self._mistral.instruct_tokenizer.tokenizer.pad_id + + @property + def unk_token_id(self) -> int: + return self._mistral.instruct_tokenizer.tokenizer.unk_id + + @property + def bos_token(self) -> str: + return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id) + + @property + def eos_token(self) -> str: + return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id) + + @property + def pad_token(self) -> str: + return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id) + + @property + def unk_token(self) -> str: + return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id) + + @property + def padding_side(self) -> str: + return self._padding_side + + @property + def name_or_path(self) -> str: + return self._name_or_path + + @property + def chat_template(self) -> str | None: + """Chat template is not supported. Dummy method to satisfy HuggingFace API.""" + return None + + def __len__(self) -> int: + return self._mistral.instruct_tokenizer.tokenizer.n_words + + @classmethod + def from_pretrained( + cls, + name_or_path: str, + *, + revision: Optional[str] = None, + **kwargs, # pylint: disable=unused-argument + ) -> "HFMistralTokenizer": + """ + Load a mistral tekken tokenizer from a local file or HF Hub and wrap it. + + Args: + name_or_path: The local path to the tokenizer files or the HF Hub repo id. + revision: Not supported yet; a NotImplementedError is raised if set. 
+ kwargs: Additional keyword arguments. + + Returns: + A HFMistralTokenizer instance. + """ + if revision: + raise NotImplementedError( + "Revision not supported yet for mistral-common tokenizer" + ) + + # only support Tekken tokenizer for now + # downloads from HF Hub if not local + tokenizer_path = _get_file_path(name_or_path, "tekken.json") + + base = MistralTokenizer.from_file(tokenizer_path) + + return cls( + base, + name_or_path=name_or_path, + tokenizer_path=tokenizer_path, + ) + + def save_pretrained(self, save_directory: str) -> None: + """ + Save the Tekken/SentencePiece model file so that from_pretrained can pick it up again. + + Only Tekken models are supported. + + Args: + save_directory: The directory to save the tokenizer files. + """ + inner = self._mistral.instruct_tokenizer.tokenizer + if isinstance(inner, Tekkenizer): + # Create the directory and save the model + try: + os.makedirs(save_directory, exist_ok=True) + + # Verify directory was created + if not os.path.exists(save_directory): + raise RuntimeError(f"Failed to create directory: {save_directory}") + + # Verify source file exists + if not os.path.exists(self._tokenizer_path): + raise FileNotFoundError( + f"Source tokenizer file not found: {self._tokenizer_path}" + ) + + destination_path = os.path.join(save_directory, "tekken.json") + copyfile(self._tokenizer_path, destination_path) + + except Exception as e: + raise RuntimeError( + f"Failed to save tokenizer to {save_directory}: {e}. " + f"Source path: {self._tokenizer_path}, " + f"Directory exists: {os.path.exists(save_directory)}" + ) from e + + else: + raise RuntimeError(f"Unknown tokenizer type: {type(inner)}") + + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + """ + Encode a text string into a list of token IDs. + + Args: + text: The text string to encode. + add_special_tokens: Whether to add special tokens to the encoded tokens. + + Returns: + A list of token IDs. 
+ """ + return self._mistral.instruct_tokenizer.tokenizer.encode( + text, + bos=add_special_tokens, + eos=add_special_tokens, + ) + + def decode( + self, token_ids: int | list[int], skip_special_tokens: bool = False + ) -> str: + """ + Decode a list of token IDs into a text string. + + Args: + token_ids: The int or list of token IDs to decode. + skip_special_tokens: Whether to skip special tokens in the decoded text. + + Returns: + The decoded text string. + """ + if isinstance(token_ids, int): + token_ids = [token_ids] + + if skip_special_tokens: + return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids) + + # to_string returns a string with special tokens + return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids) + + def _create_mistral_chat_completion_request( + self, conversation: list[dict], tools: list[dict] | None = None + ) -> "ChatCompletionRequest": + from mistral_common.protocol.instruct.messages import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, + ) + from mistral_common.protocol.instruct.request import ChatCompletionRequest + from mistral_common.protocol.instruct.tool_calls import Function, Tool + + messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = ( + [] + ) + for turn in conversation: + role = turn.get("role") + + if role == "user": + messages.append(UserMessage(content=turn["content"])) + elif role == "assistant": + messages.append( + AssistantMessage( + content=turn.get("content"), + tool_calls=turn.get("tool_calls"), + ) + ) + elif role == "tool": + messages.append( + ToolMessage( + content=turn.get("content"), + tool_call_id=turn.get("tool_call_id"), + name=turn.get("name"), + ) + ) + elif role == "system": + messages.append(SystemMessage(content=turn["content"])) + else: + raise ValueError( + f"Unknown role for use with mistral-common tokenizer: {role}" + ) + + tool_calls: list[Tool] = [] + if tools: + # convert to Tool + for tool in tools: + if tool["type"] 
!= "function": + continue + + function = tool["function"] + + tool_calls.append( + Tool( + function=Function( + name=function["name"], + description=function["description"], + # set parameters to empty dict if not provided + parameters=function.get("parameters", {}), + ) + ) + ) + + chat_completion: ChatCompletionRequest = ChatCompletionRequest( + messages=messages, + tools=tool_calls, + ) + + return chat_completion + + def apply_chat_template( + self, + messages: list[dict], + tokenize: bool = True, + tools: list[dict] | None = None, + chat_template: str | None = None, # pylint: disable=unused-argument + add_generation_prompt: bool = False, # pylint: disable=unused-argument + ) -> list[int] | str: + if chat_template: + raise NotImplementedError("chat_template not supported yet") + + if add_generation_prompt: + raise NotImplementedError("add_generation_prompt not supported yet") + + chat_completion: ChatCompletionRequest = ( + self._create_mistral_chat_completion_request(messages, tools) + ) + + tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens + + if tokenize: + return tokens + + return self.decode(tokens) + + def pad( + self, + features: list[dict[str, list[int] | np.ndarray]], + *, + padding: bool | str | PaddingStrategy = True, + max_length: int | None = None, + pad_to_multiple_of: int | None = None, + return_tensors: str | None = None, # "np", "pt", or "tf" + ) -> dict[str, np.ndarray | Tensor]: + """ + HF-style pad method that properly handles all sequence-related features: + - pad 'input_ids' & 'labels' to the longest (or to max_length) + """ + import torch + from torch.nn import functional as F + + # Check for unsupported fields + if any("token_type_ids" in f for f in features): + raise ValueError("token_type_ids is not supported by this tokenizer") + + # Determine desired sequence length + lengths = [len(f["input_ids"]) for f in features] + if padding in (True, "longest", PaddingStrategy.LONGEST): + target_length = 
max(lengths) + elif padding in ("max_length", PaddingStrategy.MAX_LENGTH): + if max_length is None: + raise ValueError("max_length must be set for 'max_length' padding") + target_length = max_length + elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD): + target_length = None + else: + raise ValueError(f"Unknown padding strategy: {padding}") + + # Apply pad_to_multiple_of + if target_length is not None and pad_to_multiple_of is not None: + target_length = ( + math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of + ) + + # If no padding requested, just stack tensors + do_pad = target_length is not None + + # Pad sequences using torch.nn.utils.rnn.pad_sequence + input_ids = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x["input_ids"], dtype=torch.long) for x in features], + batch_first=True, + padding_value=self.pad_token_id if self.pad_token_id is not None else 0, + ) + + labels = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x["labels"], dtype=torch.long) for x in features], + batch_first=True, + padding_value=IGNORE_INDEX, + ) + + attention_mask = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features], + batch_first=True, + padding_value=0, + ) + + # Handle position_ids - pad with sequential values for right padding, 0s for left padding + if "position_ids" in features[0]: + if self.padding_side == "left": + # Likely not needed, but keeping for now + # For left padding, we'll pad with 0s using pad_sequence, then handle manually + position_ids = torch.nn.utils.rnn.pad_sequence( + [ + torch.tensor(x["position_ids"], dtype=torch.long) + for x in features + ], + batch_first=True, + padding_value=0, + ) + else: + # For right padding, continue the sequence + max_pos_len = max(len(f["position_ids"]) for f in features) + position_ids_list = [] + for f in features: + pos_seq = torch.tensor(f["position_ids"], dtype=torch.long) + if len(pos_seq) < max_pos_len: + # Continue the sequence + 
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1 + pad_len = max_pos_len - len(pos_seq) + pad_positions = torch.arange( + last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long + ) + pos_seq = torch.cat([pos_seq, pad_positions]) + position_ids_list.append(pos_seq) + position_ids = torch.stack(position_ids_list) + else: + # Create position_ids if not present + seq_len = input_ids.size(1) + position_ids = ( + torch.arange(seq_len, dtype=torch.long) + .unsqueeze(0) + .expand(input_ids.size(0), -1) + ) + + # Ensure all tensors have the same sequence length + max_seq_len = max( + input_ids.size(1), + labels.size(1), + attention_mask.size(1), + position_ids.size(1), + ) + + # TODO: check if trimming is needed? and correct. + + if do_pad and target_length is not None: + max_seq_len = target_length + + # Pad all tensors to the same length + if input_ids.size(1) < max_seq_len: + pad_len = max_seq_len - input_ids.size(1) + if self.padding_side == "right": + input_ids = F.pad( + input_ids, + (0, pad_len), + value=self.pad_token_id if self.pad_token_id is not None else 0, + ) + else: + input_ids = F.pad( + input_ids, + (pad_len, 0), + value=self.pad_token_id if self.pad_token_id is not None else 0, + ) + elif input_ids.size(1) > max_seq_len: + input_ids = input_ids[:, :max_seq_len] + + if labels.size(1) < max_seq_len: + pad_len = max_seq_len - labels.size(1) + if self.padding_side == "right": + labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX) + else: + labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX) + elif labels.size(1) > max_seq_len: + labels = labels[:, :max_seq_len] + + if attention_mask.size(1) < max_seq_len: + pad_len = max_seq_len - attention_mask.size(1) + if self.padding_side == "right": + attention_mask = F.pad(attention_mask, (0, pad_len), value=0) + else: + attention_mask = F.pad(attention_mask, (pad_len, 0), value=0) + elif attention_mask.size(1) > max_seq_len: + attention_mask = attention_mask[:, :max_seq_len] + + if position_ids.size(1) 
+ < max_seq_len: + pad_len = max_seq_len - position_ids.size(1) + if self.padding_side == "right": + batch_size = position_ids.size(0) + new_position_ids = [] + for i in range(batch_size): + seq = position_ids[i] + if len(seq) > 0: + # get last position and pad with sequential values + last_pos = seq[-1].item() + pad_positions = torch.arange( + last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long + ) + new_seq = torch.cat([seq, pad_positions]) + else: + new_seq = torch.arange(pad_len, dtype=torch.long) + new_position_ids.append(new_seq) + position_ids = torch.stack(new_position_ids) + else: + position_ids = F.pad(position_ids, (pad_len, 0), value=0) + elif position_ids.size(1) > max_seq_len: + position_ids = position_ids[:, :max_seq_len] + + final_batch = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + # Handle non-sequence fields (raise error) + sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"} + for f in features: + for key in f.keys(): + if key not in sequence_fields: + raise NotImplementedError( + f"Non-sequence field {key} not handled yet" + ) + + # Convert to requested tensor type + if return_tensors is None or return_tensors == "np": + result = {} + for k, v in final_batch.items(): + if isinstance(v, torch.Tensor): + result[k] = v.numpy().astype(np.int64) # torch.long is int64 + else: + result[k] = v + return result + + if return_tensors == "pt": + return final_batch + + raise ValueError(f"Unsupported return_tensors='{return_tensors}'") + + def convert_ids_to_tokens(self, ids: list[int]) -> list[str]: + """ + Convert a list of token IDs to a list of tokens. + + Args: + ids: The list of token IDs to convert. + + Returns: + The list of tokens. 
+ """ + return [ + self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids + ] diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index dad6aac62..505d39858 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1265,6 +1265,68 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_tokenizer_use_mistral_common(cls, data): + if data.get("tokenizer_use_mistral_common") is None: + if any( + "magistral" in name.lower() + for name in [ + data.get("base_model") or "", + data.get("base_model_config") or "", + data.get("tokenizer_config") or "", + ] + ): + LOG.warning( + "tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer." + ) + data["tokenizer_use_mistral_common"] = True + + return data + + @field_validator("tokenizer_use_mistral_common", mode="after") + @classmethod + def check_mistral_common_import(cls, tokenizer_use_mistral_common): + if tokenizer_use_mistral_common: + try: + import mistral_common # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: + raise ImportError( + "mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`. 
+ ) from exception + + return tokenizer_use_mistral_common + + @model_validator(mode="before") + @classmethod + def check_mistral_common_incompatible_options(cls, data): + if not data.get("tokenizer_use_mistral_common"): + return data + + # NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment + + if data.get("added_tokens_overrides"): + raise ValueError( + "added_tokens_overrides is not supported with mistral-common tokenizer" + ) + + if data.get("special_tokens"): + raise ValueError( + "special_tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("tokens"): + raise ValueError( + "tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("chat_template"): + raise ValueError( + "Setting chat_template is not supported with mistral-common tokenizer" + ) + + return data + class AxolotlConfigWCapabilities(AxolotlInputConfig): """wrapper to valdiate gpu capabilities with the configured options""" diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 57f5ae309..aafb52152 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -18,6 +18,7 @@ class ModelInputConfig(BaseModel): tokenizer_config: str | None = None tokenizer_use_fast: bool | None = None tokenizer_legacy: bool | None = None + tokenizer_use_mistral_common: bool | None = None tokenizer_type: str | None = Field( default=None, json_schema_extra={"description": "transformers tokenizer class"} ) diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index fe59e00d8..98488a988 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -150,6 +150,14 @@ def fixture_gemma2_tokenizer(): return tokenizer +@pytest.fixture(name="magistral_tokenizer") +def fixture_magistral_tokenizer(): + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer + + tokenizer = 
HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506") + return tokenizer + + @pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja") def fixture_mistralv03_chat_template_jinja_w_system() -> str: return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] 
" + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n' diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py new file mode 100644 index 000000000..3c60a15c2 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -0,0 +1,290 @@ +"""Test chat templates for mistral-common wrapper tokenizer""" + +import unittest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from axolotl.utils.mistral_tokenizer 
import HFMistralTokenizer + + +def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): + # pylint: disable=duplicate-code + from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + + # check bos, eos, pad, unk are accessible properties + assert magistral_tokenizer.bos_token_id == 1 + assert magistral_tokenizer.eos_token_id == 2 + assert magistral_tokenizer.pad_token_id == 11 + assert magistral_tokenizer.unk_token_id == 0 + + assert magistral_tokenizer.pad_token == "<pad>" + assert magistral_tokenizer.eos_token == "</s>" + assert magistral_tokenizer.bos_token == "<s>" + assert magistral_tokenizer.unk_token == "<unk>" + + strategy = MistralStrategy( + MistralPrompter( + magistral_tokenizer, + chat_template=None, + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=magistral_tokenizer, + train_on_inputs=False, + train_on_eos="turn", + sequence_len=512, + roles_to_train=["assistant"], + ) + + # test chat template masking without system prompt + res = strategy.tokenize_prompt( + { + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + } + ) + + assert res["input_ids"] == [ + 1, # bos + 3, # [INST] + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + 4, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # </s> + ] + + assert res["labels"] == [ + -100, # bos + -100, # [INST] + -100, # Hello + -100, # , + -100, # how + -100, # are + -100, # you + -100, # ? + -100, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! 
+ 2, # + ] + + # test chat template masking with system prompt + res = strategy.tokenize_prompt( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + } + ) + + assert res["input_ids"] == [ + 1, # bos + 17, # [SYSTEM_PROMPT] + 4568, # You + 1584, # are + 1261, # a + 20351, # helpful + 27089, # assistant + 1046, # . + 18, # [/SYSTEM_PROMPT] + 3, # [INST] + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + 4, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + assert res["labels"] == [ + -100, # bos + -100, # [SYSTEM_PROMPT] + -100, # You + -100, # are + -100, # a + -100, # helpful + -100, # assistant + -100, # . + -100, # [/SYSTEM_PROMPT] + -100, # [INST] + -100, # Hello + -100, # , + -100, # how + -100, # are + -100, # you + -100, # ? + -100, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + # test chat template with tools + res = strategy.tokenize_prompt( + { + "tools": [ + { + "type": "function", + "function": { + "name": "multiples", + "description": "Generates a list of all the multiples of a number that are less than a given limit.", + "parameters": { + "type": "object", + "properties": { + "number": { + "type": "integer", + "description": "The number to find multiples of.", + }, + "limit": { + "type": "integer", + "description": "The upper limit for the multiples.", + }, + }, + "required": ["number", "limit"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? 
I want some ideas on decorations, food, and activities that will make the party unforgettable!", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "multiples", + "arguments": { + "number": 16, + "limit": 2, + }, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "multiples", + "content": "1,2", + }, + {"role": "assistant", "content": "The multiples of 16 is 1 and 2."}, + ], + } + ) + + # fmt: off + assert res["input_ids"] == [ + 1, # bos + 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt + 3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user + 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + 7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result + 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant + 2 # eos + ] + + assert res["labels"] == [ + -100, 
# bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt + 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result + 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant + 2 # eos + ] + # fmt: on + + # test chat template with tokenize=False + res = magistral_tokenizer.apply_chat_template( + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ], + tokenize=False, + ) + + assert res == "[INST]Hello, how are you?[/INST]I'm doing great, thank you!" + + # test encode + res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True) + assert res == [ + 1, # bos + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? 
+ 2, # eos + ] + + # test decode no skip special tokens + decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False) + + assert decoded_res == "<s>Hello, how are you?</s>" + + # test decode skip special tokens + decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True) + assert decoded_res == "Hello, how are you?" + + # test encode no special tokens + res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False) + assert res == [ + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + ] + + # test convert ids to tokens + res = magistral_tokenizer.convert_ids_to_tokens(res) + # leading spaces are kept because pieces are converted without decoding + assert res == ["Hello", ",", " how", " are", " you", "?"] + + +if __name__ == "__main__": + unittest.main()