diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py
index 95c87a822..33c08db46 100644
--- a/src/axolotl/utils/mistral_tokenizer.py
+++ b/src/axolotl/utils/mistral_tokenizer.py
@@ -497,3 +497,131 @@ class HFMistralTokenizer:
         return [
             self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
         ]
+
+    def __call__(
+        self,
+        text: str | list[str],
+        add_special_tokens: bool = True,
+        padding: bool | str = False,
+        truncation: bool = False,
+        max_length: int | None = None,
+        return_tensors: str | None = None,
+        **kwargs,
+    ) -> dict[str, list[int] | np.ndarray | Tensor]:
+        """
+        Tokenize text and return a dictionary with input_ids and attention_mask.
+
+        Args:
+            text: Input text string or list of strings to tokenize.
+            add_special_tokens: Whether to add special tokens (BOS/EOS).
+            padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
+            truncation: Whether to truncate sequences to max_length.
+            max_length: Maximum sequence length for truncation/padding.
+            return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
+
+        Returns:
+            Dictionary with "input_ids" and "attention_mask" keys.
+        """
+        # if kwargs passed, raise error
+        if kwargs:
+            raise ValueError(
+                f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
+            )
+
+        # `np` can work with inhomogeneous shapes but let's not support it until needed.
+        if (
+            isinstance(text, list)
+            and len(text) > 1
+            and return_tensors in ("pt", "np")
+            and padding is False
+            and truncation is False
+        ):
+            raise ValueError(
+                "return_tensors='pt' or 'np' requires padding or truncation."
+            )
+
+        # Handle single string input
+        if isinstance(text, str):
+            text = [text]
+
+        # Encode all texts
+        # TODO: figure out how to parallelize this
+        batch_input_ids = []
+        for single_text in text:
+            input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
+
+            # Handle truncation
+            if truncation and max_length is not None and len(input_ids) > max_length:
+                input_ids = input_ids[:max_length]
+
+            batch_input_ids.append(input_ids)
+
+        # Create attention masks (1 for real tokens, 0 for padding)
+        attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
+
+        # Handle padding
+        if padding in (True, "longest"):
+            # Pad to longest sequence in batch
+            max_len = max(len(input_ids) for input_ids in batch_input_ids)
+
+            for i, input_ids in enumerate(batch_input_ids):
+                pad_length = max_len - len(input_ids)
+                if pad_length > 0:
+                    if self.padding_side == "right":
+                        batch_input_ids[i] = (
+                            input_ids + [self.pad_token_id] * pad_length
+                        )
+                        attention_masks[i] = attention_masks[i] + [0] * pad_length
+                    else:  # left padding
+                        batch_input_ids[i] = [
+                            self.pad_token_id
+                        ] * pad_length + input_ids
+                        attention_masks[i] = [0] * pad_length + attention_masks[i]
+
+        elif padding == "max_length":
+            if max_length is None:
+                raise ValueError(
+                    "max_length must be specified when padding='max_length'"
+                )
+
+            for i, input_ids in enumerate(batch_input_ids):
+                pad_length = max_length - len(input_ids)
+                if pad_length > 0:
+                    if self.padding_side == "right":
+                        batch_input_ids[i] = (
+                            input_ids + [self.pad_token_id] * pad_length
+                        )
+                        attention_masks[i] = attention_masks[i] + [0] * pad_length
+                    else:  # left padding
+                        batch_input_ids[i] = [
+                            self.pad_token_id
+                        ] * pad_length + input_ids
+                        attention_masks[i] = [0] * pad_length + attention_masks[i]
+
+        # Prepare result
+        result = {}
+
+        # Handle return tensor format
+        if return_tensors == "pt":
+            import torch
+
+            result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
+            result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
+        elif return_tensors == "np":
+            result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
+            result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
+        elif return_tensors is None:
+            result["input_ids"] = batch_input_ids
+            result["attention_mask"] = attention_masks
+        else:
+            raise ValueError(
+                f"Unsupported return_tensors='{return_tensors}'. "
+                "Only 'pt' and 'np' are supported."
+            )
+
+        # If single input, return single sequences (not batched)
+        if len(text) == 1 and return_tensors is None:
+            result["input_ids"] = result["input_ids"][0]
+            result["attention_mask"] = result["attention_mask"][0]
+
+        return result
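
The new `__call__` mirrors the Hugging Face tokenizer calling convention. A minimal usage sketch, not part of the patch, assuming `tokenizer` is an `HFMistralTokenizer` instance (e.g. the `magistral_tokenizer` fixture used in the tests below):

# Single string with the defaults: flat Python lists, as with HF tokenizers.
out = tokenizer("Hello, how are you?")
assert out["input_ids"] == tokenizer.encode("Hello, how are you?", add_special_tokens=True)
assert all(m == 1 for m in out["attention_mask"])

# A batch of more than one text must be padded (or truncated) before tensors can be built.
batch = tokenizer(
    ["Hello world", "How are you?"],
    padding=True,  # pad to the longest sequence in the batch
    return_tensors="pt",  # torch.long tensors of shape (batch, longest)
)
assert batch["input_ids"].shape == batch["attention_mask"].shape
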
result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long) + result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long) + elif return_tensors == "np": + result["input_ids"] = np.array(batch_input_ids, dtype=np.int64) + result["attention_mask"] = np.array(attention_masks, dtype=np.int64) + elif return_tensors is None: + result["input_ids"] = batch_input_ids + result["attention_mask"] = attention_masks + else: + raise ValueError( + f"Unsupported return_tensors='{return_tensors}'. " + "Only 'pt' and 'np' are supported." + ) + + # If single input, return single sequences (not batched) + if len(text) == 1 and return_tensors is None: + result["input_ids"] = result["input_ids"][0] + result["attention_mask"] = result["attention_mask"][0] + + return result diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index f26ed0838..8e3f494b1 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING import pytest if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer @@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): assert "Not the same number of function calls and responses" in str(e) +def test_magistral_tokenizer_call_method( + magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer" +): + """Test the __call__ method behavior matches HuggingFace standards""" + from copy import deepcopy + + import numpy as np + import torch + + hf_tokenizer = deepcopy(llama3_tokenizer) + hf_tokenizer.pad_token = hf_tokenizer.eos_token + + test_text = "Hello, how are you?" 
+ batch_texts = ["Hello world", "How are you?"] + + # Test single string with return_tensors=None + hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None) + mistral_result: dict[str, list[int]] = magistral_tokenizer( + test_text, return_tensors=None + ) + + assert isinstance(mistral_result, dict) + assert set(mistral_result.keys()) == {"input_ids", "attention_mask"} + assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list + assert isinstance( + mistral_result["attention_mask"], type(hf_result["attention_mask"]) + ) + assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"]) + assert np.all(mistral_result["attention_mask"]) + assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array + + # Test single string with return_tensors='pt' + hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt") + mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer( + test_text, return_tensors="pt" + ) + + # Check structure and types + assert isinstance(mistral_result_pt["input_ids"], torch.Tensor) + assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor) + + # Check shapes match (don't compare token dimension) + assert len(hf_result_pt["input_ids"].shape) == len( + mistral_result_pt["input_ids"].shape + ) + assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0] + assert ( + mistral_result_pt["attention_mask"].shape + == mistral_result_pt["input_ids"].shape + ) + assert torch.all(mistral_result_pt["attention_mask"] == 1) + + # Test batch input with padding + hf_batch: dict[str, torch.Tensor] = hf_tokenizer( + batch_texts, return_tensors="pt", padding=True + ) + mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer( + batch_texts, return_tensors="pt", padding=True + ) + + # Check batch behavior + assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape) + assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0] + assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape + assert torch.any( + mistral_batch["attention_mask"][0] == 0 + ) # padding in shorter sequence + assert torch.all( + mistral_batch["attention_mask"][1] == 1 + ) # no padding in longer sequence + + # Test numpy tensors + mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer( + test_text, return_tensors="np" + ) + assert isinstance(mistral_result_np["input_ids"], np.ndarray) + assert isinstance(mistral_result_np["attention_mask"], np.ndarray) + + # Test consistency with encode() + encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True) + called: dict[str, torch.Tensor] = magistral_tokenizer( + test_text, return_tensors="pt" + ) + assert encoded == called["input_ids"][0].tolist() + + # Test Error handling + with pytest.raises(ValueError, match="Unsupported kwargs"): + magistral_tokenizer(test_text, unsupported_param=True) + + with pytest.raises( + ValueError, match="return_tensors='pt' or 'np' requires padding or truncation" + ): + magistral_tokenizer(batch_texts, return_tensors="pt") + + if __name__ == "__main__": unittest.main()