diff --git a/examples/devstral/README.md b/examples/devstral/README.md new file mode 100644 index 000000000..9dc5377bc --- /dev/null +++ b/examples/devstral/README.md @@ -0,0 +1,69 @@ +# Finetune Devstral with Axolotl + +Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking. + +The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of upto 128k tokens. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0+) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install the latest mistral-common from source +pip3 uninstall mistral-common +pip3 install git+https://github.com/mistralai/mistral-common.git@039465d + +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/devstral/devstral-small-qlora.yml +``` + +This config uses about 21GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) +- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Devstral Blog](https://mistral.ai/news/devstral) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, Multi-modal, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml new file mode 100644 index 000000000..d2c5930e3 --- /dev/null +++ b/examples/devstral/devstral-small-qlora.yml @@ -0,0 +1,64 @@ +base_model: mistralai/Devstral-Small-2505 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +load_in_8bit: false +load_in_4bit: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_ratio: 0.05 +evals_per_epoch: 4 +saves_per_epoch: 1 + +weight_decay: 0.0 +special_tokens: diff --git a/examples/magistral/README.md b/examples/magistral/README.md index a2b09ab70..0c39c061b 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git cd axolotl pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja -pip3 install --no-build-isolation -e '.[flash-attn,mistral]' +pip3 install --no-build-isolation -e '.[flash-attn]' ``` -2. Download the example config: - -```bash -axolotl fetch examples -``` - -3. Run the finetuning example: +2. Run the finetuning example: ```bash axolotl train examples/magistral/magistral-small-qlora.yaml @@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀 - For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`. - You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. - Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). -- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). ## Optimization Guides @@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀 We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. -The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet. +In addition, we do not support overriding tokens yet. ## Related Resources diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 28182b16f..7c112c59e 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset): features = dataset.features.keys() num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) - # Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common) - if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True): - LOG.info( - "Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)" - ) - num_proc = 1 - map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 6d2a048b2..a9d26a650 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): for message in messages: transformed_message = self.transform_message(message) - turn = { - **transformed_message, - "training": message.get(self.prompter.message_field_training), - "training_detail": message.get( - self.prompter.message_field_training_detail - ), - } + turn = transformed_message + + training = message.get(self.prompter.message_field_training) + training_detail = message.get(self.prompter.message_field_training_detail) + if training is not None: + turn["training"] = training + if training_detail is not None: + turn["training_detail"] = training_detail turns.append(turn) @@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy): # TODO: address this in the future with mistral-specific checks # self._validate_eot_and_eos_tokens() - @property - def supports_multiprocessing(self) -> bool: - """ - Whether this tokenizing strategy supports multiprocessing. - mistral_common tokenizers cannot be pickled for multiprocessing. - """ - - return False - def find_first_eot_token(self, input_ids, start_idx): """Find the first EOT token in the input_ids starting from start_idx.""" # mistral-common tokenizer does not support eot_tokens diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index aae778ae8..9ca645de3 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -70,14 +70,6 @@ class PromptTokenizingStrategy(abc.ABC): def supports_batched(self): return False - @property - def supports_multiprocessing(self): - """ - Whether this tokenizing strategy supports multiprocessing. - Should return False if the tokenizer has unpicklable objects. - """ - return True - def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index a28f360be..25a871b2b 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -108,7 +108,7 @@ class DataCollatorForSeq2Seq: pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=return_tensors, ) - if not has_attn_mask: + if not has_attn_mask and "attention_mask" in features: del features["attention_mask"] # prepare decoder_input_ids diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py index 1ba824938..95c87a822 100644 --- a/src/axolotl/utils/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral_tokenizer.py @@ -3,10 +3,11 @@ import math import os from shutil import copyfile -from typing import TYPE_CHECKING, Optional +from typing import Optional import numpy as np from huggingface_hub import hf_hub_download +from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer from torch import Tensor @@ -14,9 +15,6 @@ from transformers.utils import PaddingStrategy from axolotl.utils.collators.core import IGNORE_INDEX -if TYPE_CHECKING: - from mistral_common.protocol.instruct.request import ChatCompletionRequest - def _get_file_path(path_or_repo_id: str, filename: str) -> str: """Get the file path from local or HF Hub""" @@ -259,75 +257,6 @@ class HFMistralTokenizer: token_ids, special_token_policy=SpecialTokenPolicy.KEEP ) - def _create_mistral_chat_completion_request( - self, conversation: list[dict], tools: list[dict] | None = None - ) -> "ChatCompletionRequest": - from mistral_common.protocol.instruct.messages import ( - AssistantMessage, - SystemMessage, - ToolMessage, - UserMessage, - ) - from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.protocol.instruct.tool_calls import Function, Tool - - messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = ( - [] - ) - for turn in conversation: - role = turn.get("role") - - if role == "user": - messages.append(UserMessage(content=turn["content"])) - elif role == "assistant": - messages.append( - AssistantMessage( - content=turn.get("content"), - tool_calls=turn.get("tool_calls"), - ) - ) - elif role == "tool": - messages.append( - ToolMessage( - content=turn.get("content"), - tool_call_id=turn.get("tool_call_id"), - name=turn.get("name"), - ) - ) - elif role == "system": - messages.append(SystemMessage(content=turn["content"])) - else: - raise ValueError( - f"Unknown role for use with mistral-common tokenizer: {turn['role']}" - ) - - tool_calls: list[Tool] = [] - if tools: - # convert to Tool - for tool in tools: - if tool["type"] != "function": - continue - - function = tool["function"] - - tool_calls.append( - Tool( - function=Function( - name=function["name"], - description=function["description"], - # set parameters to empty dict if not provided - parameters=function.get("parameters", {}), - ) - ) - ) - - chat_completion: ChatCompletionRequest = ChatCompletionRequest( - messages=messages, - tools=tool_calls, - ) - - return chat_completion - def apply_chat_template( self, messages: list[dict], @@ -342,8 +271,8 @@ class HFMistralTokenizer: if add_generation_prompt: raise NotImplementedError("add_generation_prompt not supported yet") - chat_completion: ChatCompletionRequest = ( - self._create_mistral_chat_completion_request(messages, tools) + chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai( + messages, tools ) tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens @@ -408,13 +337,16 @@ class HFMistralTokenizer: padding_value=IGNORE_INDEX, ) - attention_mask = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features], - batch_first=True, - padding_value=0, - ) + attention_mask = None + if "attention_mask" in features[0]: + attention_mask = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features], + batch_first=True, + padding_value=0, + ) # Handle position_ids - pad with sequential values for right padding, 0s for left padding + position_ids = None if "position_ids" in features[0]: if self.padding_side == "left": # Likely not needed, but keeping for now @@ -443,22 +375,15 @@ class HFMistralTokenizer: pos_seq = torch.cat([pos_seq, pad_positions]) position_ids_list.append(pos_seq) position_ids = torch.stack(position_ids_list) - else: - # Create position_ids if not present - seq_len = input_ids.size(1) - position_ids = ( - torch.arange(seq_len, dtype=torch.long) - .unsqueeze(0) - .expand(input_ids.size(0), -1) - ) # Ensure all tensors have the same sequence length - max_seq_len = max( - input_ids.size(1), - labels.size(1), - attention_mask.size(1), - position_ids.size(1), - ) + # Check attention mask and position ids if they are present + tensor_lengths = [input_ids.size(1), labels.size(1)] + if attention_mask is not None: + tensor_lengths.append(attention_mask.size(1)) + if position_ids is not None: + tensor_lengths.append(position_ids.size(1)) + max_seq_len = max(tensor_lengths) # TODO: check if trimming is needed? and correct. @@ -492,44 +417,48 @@ class HFMistralTokenizer: elif labels.size(1) > max_seq_len: labels = labels[:, :max_seq_len] - if attention_mask.size(1) < max_seq_len: - pad_len = max_seq_len - attention_mask.size(1) - if self.padding_side == "right": - attention_mask = F.pad(attention_mask, (0, pad_len), value=0) - else: - attention_mask = F.pad(attention_mask, (pad_len, 0), value=0) - elif attention_mask.size(1) > max_seq_len: - attention_mask = attention_mask[:, :max_seq_len] + if attention_mask is not None: + if attention_mask.size(1) < max_seq_len: + pad_len = max_seq_len - attention_mask.size(1) + if self.padding_side == "right": + attention_mask = F.pad(attention_mask, (0, pad_len), value=0) + else: + attention_mask = F.pad(attention_mask, (pad_len, 0), value=0) + elif attention_mask.size(1) > max_seq_len: + attention_mask = attention_mask[:, :max_seq_len] - if position_ids.size(1) < max_seq_len: - pad_len = max_seq_len - position_ids.size(1) - if self.padding_side == "right": - batch_size = position_ids.size(0) - new_position_ids = [] - for i in range(batch_size): - seq = position_ids[i] - if len(seq) > 0: - # get last position and pad with sequential values - last_pos = seq[-1].item() - pad_positions = torch.arange( - last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long - ) - new_seq = torch.cat([seq, pad_positions]) - else: - new_seq = torch.arange(pad_len, dtype=torch.long) - new_position_ids.append(new_seq) - position_ids = torch.stack(new_position_ids) - else: - position_ids = F.pad(position_ids, (pad_len, 0), value=0) - elif position_ids.size(1) > max_seq_len: - position_ids = position_ids[:, :max_seq_len] + if position_ids is not None: + if position_ids.size(1) < max_seq_len: + pad_len = max_seq_len - position_ids.size(1) + if self.padding_side == "right": + batch_size = position_ids.size(0) + new_position_ids = [] + for i in range(batch_size): + seq = position_ids[i] + if len(seq) > 0: + # get last position and pad with sequential values + last_pos = seq[-1].item() + pad_positions = torch.arange( + last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long + ) + new_seq = torch.cat([seq, pad_positions]) + else: + new_seq = torch.arange(pad_len, dtype=torch.long) + new_position_ids.append(new_seq) + position_ids = torch.stack(new_position_ids) + else: + position_ids = F.pad(position_ids, (pad_len, 0), value=0) + elif position_ids.size(1) > max_seq_len: + position_ids = position_ids[:, :max_seq_len] final_batch = { "input_ids": input_ids, "labels": labels, - "attention_mask": attention_mask, - "position_ids": position_ids, } + if attention_mask is not None: + final_batch["attention_mask"] = attention_mask + if position_ids is not None: + final_batch["position_ids"] = position_ids # Handle non-sequence fields (raise error) sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"} @@ -545,7 +474,7 @@ class HFMistralTokenizer: result = {} for k, v in final_batch.items(): if isinstance(v, torch.Tensor): - result[k] = v.numpy().astype(np.long) + result[k] = v.numpy().astype(np.int64) else: result[k] = v return result diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index d440565d2..60b14d652 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -164,6 +164,14 @@ def fixture_magistral_tokenizer(): return tokenizer +@pytest.fixture(name="devstral_tokenizer") +def fixture_devstral_tokenizer(): + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer + + tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505") + return tokenizer + + @pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja") def fixture_mistralv03_chat_template_jinja_w_system() -> str: return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n' diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 3c60a15c2..dcf5138d3 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -3,32 +3,50 @@ import unittest from typing import TYPE_CHECKING +import pytest + if TYPE_CHECKING: from axolotl.utils.mistral_tokenizer import HFMistralTokenizer -def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): +# fmt: off +@pytest.mark.parametrize( + ("tokenizer_str", "assistant_toolcall_ids"), + ( + ("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2)), + ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)), + ) +) +# fmt: on +def test_mistral_chat_template( + tokenizer_str: str, + assistant_toolcall_ids: tuple[int, ...], + request: pytest.FixtureRequest, +): + """Test chat template with the Magistral/Devstral tokenizer""" # pylint: disable=duplicate-code from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy - # check bos, eos, pad, unk are accessible properties - assert magistral_tokenizer.bos_token_id == 1 - assert magistral_tokenizer.eos_token_id == 2 - assert magistral_tokenizer.pad_token_id == 11 - assert magistral_tokenizer.unk_token_id == 0 + tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str) - assert magistral_tokenizer.pad_token == "" - assert magistral_tokenizer.eos_token == "" - assert magistral_tokenizer.bos_token == "" - assert magistral_tokenizer.unk_token == "" + # check bos, eos, pad, unk are accessible properties + assert tokenizer.bos_token_id == 1 + assert tokenizer.eos_token_id == 2 + assert tokenizer.pad_token_id == 11 + assert tokenizer.unk_token_id == 0 + + assert tokenizer.pad_token == "" + assert tokenizer.eos_token == "" + assert tokenizer.bos_token == "" + assert tokenizer.unk_token == "" strategy = MistralStrategy( MistralPrompter( - magistral_tokenizer, + tokenizer, chat_template=None, message_property_mappings={"role": "role", "content": "content"}, ), - tokenizer=magistral_tokenizer, + tokenizer=tokenizer, train_on_inputs=False, train_on_eos="turn", sequence_len=512, @@ -219,7 +237,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): 1, # bos 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt 3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user - 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + *assistant_toolcall_ids, # assistant tool calling 7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant 2 # eos @@ -229,7 +247,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): -100, # bos -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt - 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + *assistant_toolcall_ids, # assistant tool calling -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant 2 # eos @@ -237,7 +255,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): # fmt: on # test chat template with tokenize=False - res = magistral_tokenizer.apply_chat_template( + res = tokenizer.apply_chat_template( [ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great, thank you!"}, @@ -248,7 +266,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): assert res == "[INST]Hello, how are you?[/INST]I'm doing great, thank you!" # test encode - res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True) + res = tokenizer.encode("Hello, how are you?", add_special_tokens=True) assert res == [ 1, # bos 22177, # Hello @@ -261,16 +279,16 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): ] # test decode no skip special tokens - decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False) + decoded_res = tokenizer.decode(res, skip_special_tokens=False) assert decoded_res == "Hello, how are you?" # test decode skip special tokens - decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True) + decoded_res = tokenizer.decode(res, skip_special_tokens=True) assert decoded_res == "Hello, how are you?" # test encode no special tokens - res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False) + res = tokenizer.encode("Hello, how are you?", add_special_tokens=False) assert res == [ 22177, # Hello 1044, # , @@ -281,10 +299,452 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): ] # test convert ids to tokens - res = magistral_tokenizer.convert_ids_to_tokens(res) + res = tokenizer.convert_ids_to_tokens(res) # spacing are needed as we are converting without decoding assert res == ["Hello", ",", " how", " are", " you", "?"] +def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"): + """Test the MistralTokenizer pad method""" + from axolotl.utils.collators.core import IGNORE_INDEX + + magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id + + # Test padding with input_ids and labels only + features = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8], "labels": [9, 10]}, + ] + + result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt") + + # Check that input_ids are padded correctly + assert result["input_ids"].shape == (2, 3) + assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]] + + # Check that labels are padded correctly + assert result["labels"].shape == (2, 3) + assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]] + + # Check that attention_mask and position_ids are NOT created + assert "attention_mask" not in result + assert "position_ids" not in result + + # Test padding with attention_mask + features_with_attention = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]}, + {"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_attention, padding=True, return_tensors="pt" + ) + + # Check that attention_mask is padded correctly + assert result["attention_mask"].shape == (2, 3) + assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]] + + # Test padding with position_ids + features_with_position = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]}, + {"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_position, padding=True, return_tensors="pt" + ) + + # Check that position_ids are padded correctly (continuing sequence) + assert result["position_ids"].shape == (2, 3) + assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]] + + # Test padding with all fields + features_all = [ + { + "input_ids": [1, 2, 3], + "labels": [4, 5, 6], + "attention_mask": [1, 1, 1], + "position_ids": [0, 1, 2], + }, + { + "input_ids": [7, 8], + "labels": [9, 10], + "attention_mask": [1, 1], + "position_ids": [0, 1], + }, + ] + + result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt") + + # All fields should be present and correctly padded + assert "input_ids" in result + assert "labels" in result + assert "attention_mask" in result + assert "position_ids" in result + + # Test padding with all sequences same length + features_same_length = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8, 9], "labels": [10, 11, 12]}, + ] + + result = magistral_tokenizer.pad( + features_same_length, padding=True, return_tensors="pt" + ) + + # Check match when no padding is needed + assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"] + assert result["labels"][0].tolist() == features_same_length[0]["labels"] + + assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"] + assert result["labels"][1].tolist() == features_same_length[1]["labels"] + + # Test padding with max_length parameter + result = magistral_tokenizer.pad( + features, padding="max_length", max_length=5, return_tensors="pt" + ) + + # Should pad to max_length + assert result["input_ids"].shape == (2, 5) + assert result["labels"].shape == (2, 5) + + # Test numpy return type + result = magistral_tokenizer.pad(features, padding=True, return_tensors="np") + + # Should return numpy arrays + import numpy as np + + assert isinstance(result["input_ids"], np.ndarray) + assert isinstance(result["labels"], np.ndarray) + + # Test unsupported field rejection + features_unsupported = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]}, + ] + + with pytest.raises(NotImplementedError, match="unsupported_field"): + magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt") + + # Test token_type_ids rejection + features_token_type = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]}, + ] + + with pytest.raises(ValueError, match="token_type_ids is not supported"): + magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt") + + +def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): + """Test tool calling with the Magistral tokenizer""" + from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + + strategy = MistralStrategy( + MistralPrompter( + magistral_tokenizer, + chat_template=None, + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=magistral_tokenizer, + train_on_inputs=False, + train_on_eos="turn", + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Test basic tool calling with single function + basic_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What's the weather like in San Francisco?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "get_weather", + "arguments": { + "location": "San Francisco, CA", + }, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "get_weather", + "content": "Sunny, 72°F", + }, + { + "role": "assistant", + "content": "The weather in San Francisco is sunny and 72°F.", + }, + ], + } + + res = strategy.tokenize_prompt(basic_tool_calling) + + # Basic validation + assert "input_ids" in res + assert "labels" in res + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + # Decode and verify structure + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}' + in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded + assert "The weather in San Francisco is sunny and 72°F." in decoded + + # Test multiple tool calls in sequence + multi_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "add_numbers", + "description": "Add two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "multiply_numbers", + "description": "Multiply two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "First number"}, + "y": {"type": "number", "description": "Second number"}, + }, + "required": ["x", "y"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Add 5 and 3, then multiply the result by 2", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "add_numbers", + "arguments": {"a": 5, "b": 3}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "add_numbers", + "content": "8", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call23456", + "type": "function", + "function": { + "name": "multiply_numbers", + "arguments": {"x": 8, "y": 2}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call23456", + "name": "multiply_numbers", + "content": "16", + }, + { + "role": "assistant", + "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.", + }, + ], + } + + res = strategy.tokenize_prompt(multi_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}' in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded + assert ( + '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}' + in decoded + ) + assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded + assert ( + "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16." + in decoded + ) + + # Test tool calling with system message + system_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "search_database", + "description": "Search for information in database", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + }, + }, + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant with access to a database.", + }, + { + "role": "user", + "content": "Find information about Python programming", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "search123", + "type": "function", + "function": { + "name": "search_database", + "arguments": {"query": "Python programming"}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "search123", + "name": "search_database", + "content": "Python is a high-level programming language known for its simplicity.", + }, + { + "role": "assistant", + "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.", + }, + ], + } + + res = strategy.tokenize_prompt(system_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + + assert ( + '[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + + # Test error handling - missing tool response + incomplete_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_time", + "description": "Get current time", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What time is it?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "time12345", + "type": "function", + "function": { + "name": "get_time", + "arguments": {}, + }, + } + ], + }, + { + "role": "assistant", + "content": "The current time is 12:00 PM.", + }, + ], + } + + from mistral_common.exceptions import InvalidMessageStructureException + + try: + strategy.tokenize_prompt(incomplete_tool_calling) + except InvalidMessageStructureException as e: + assert "Not the same number of function calls and responses" in str(e) + + if __name__ == "__main__": unittest.main()