feat: add call method to mistral tokenizer wrapper (#2898)
This commit is contained in:
@@ -497,3 +497,131 @@ class HFMistralTokenizer:
|
|||||||
return [
|
return [
|
||||||
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: str | list[str],
|
||||||
|
add_special_tokens: bool = True,
|
||||||
|
padding: bool | str = False,
|
||||||
|
truncation: bool = False,
|
||||||
|
max_length: int | None = None,
|
||||||
|
return_tensors: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, list[int] | np.ndarray | Tensor]:
|
||||||
|
"""
|
||||||
|
Tokenize text and return a dictionary with input_ids and attention_mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text string or list of strings to tokenize.
|
||||||
|
add_special_tokens: Whether to add special tokens (BOS/EOS).
|
||||||
|
padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
|
||||||
|
truncation: Whether to truncate sequences to max_length.
|
||||||
|
max_length: Maximum sequence length for truncation/padding.
|
||||||
|
return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with "input_ids" and "attention_mask" keys.
|
||||||
|
"""
|
||||||
|
# if kwargs passed, raise error
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
|
||||||
|
)
|
||||||
|
|
||||||
|
# `np` can work with inhomogeneous shapes but let's not support it until needed.
|
||||||
|
if (
|
||||||
|
isinstance(text, list)
|
||||||
|
and len(text) > 1
|
||||||
|
and return_tensors in ("pt", "np")
|
||||||
|
and padding is False
|
||||||
|
and truncation is False
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"return_tensors='pt' or 'np' requires padding or truncation."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle single string input
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
# Encode all texts
|
||||||
|
# TODO: figure out how to parallelize this
|
||||||
|
batch_input_ids = []
|
||||||
|
for single_text in text:
|
||||||
|
input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
|
||||||
|
|
||||||
|
# Handle truncation
|
||||||
|
if truncation and max_length is not None and len(input_ids) > max_length:
|
||||||
|
input_ids = input_ids[:max_length]
|
||||||
|
|
||||||
|
batch_input_ids.append(input_ids)
|
||||||
|
|
||||||
|
# Create attention masks (1 for real tokens, 0 for padding)
|
||||||
|
attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
|
||||||
|
|
||||||
|
# Handle padding
|
||||||
|
if padding in (True, "longest"):
|
||||||
|
# Pad to longest sequence in batch
|
||||||
|
max_len = max(len(input_ids) for input_ids in batch_input_ids)
|
||||||
|
|
||||||
|
for i, input_ids in enumerate(batch_input_ids):
|
||||||
|
pad_length = max_len - len(input_ids)
|
||||||
|
if pad_length > 0:
|
||||||
|
if self.padding_side == "right":
|
||||||
|
batch_input_ids[i] = (
|
||||||
|
input_ids + [self.pad_token_id] * pad_length
|
||||||
|
)
|
||||||
|
attention_masks[i] = attention_masks[i] + [0] * pad_length
|
||||||
|
else: # left padding
|
||||||
|
batch_input_ids[i] = [
|
||||||
|
self.pad_token_id
|
||||||
|
] * pad_length + input_ids
|
||||||
|
attention_masks[i] = [0] * pad_length + attention_masks[i]
|
||||||
|
|
||||||
|
elif padding == "max_length":
|
||||||
|
if max_length is None:
|
||||||
|
raise ValueError(
|
||||||
|
"max_length must be specified when padding='max_length'"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, input_ids in enumerate(batch_input_ids):
|
||||||
|
pad_length = max_length - len(input_ids)
|
||||||
|
if pad_length > 0:
|
||||||
|
if self.padding_side == "right":
|
||||||
|
batch_input_ids[i] = (
|
||||||
|
input_ids + [self.pad_token_id] * pad_length
|
||||||
|
)
|
||||||
|
attention_masks[i] = attention_masks[i] + [0] * pad_length
|
||||||
|
else: # left padding
|
||||||
|
batch_input_ids[i] = [
|
||||||
|
self.pad_token_id
|
||||||
|
] * pad_length + input_ids
|
||||||
|
attention_masks[i] = [0] * pad_length + attention_masks[i]
|
||||||
|
|
||||||
|
# Prepare result
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# Handle return tensor format
|
||||||
|
if return_tensors == "pt":
|
||||||
|
import torch
|
||||||
|
|
||||||
|
result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
|
||||||
|
result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
||||||
|
elif return_tensors == "np":
|
||||||
|
result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
|
||||||
|
result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
|
||||||
|
elif return_tensors is None:
|
||||||
|
result["input_ids"] = batch_input_ids
|
||||||
|
result["attention_mask"] = attention_masks
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported return_tensors='{return_tensors}'. "
|
||||||
|
"Only 'pt' and 'np' are supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
# If single input, return single sequences (not batched)
|
||||||
|
if len(text) == 1 and return_tensors is None:
|
||||||
|
result["input_ids"] = result["input_ids"][0]
|
||||||
|
result["attention_mask"] = result["attention_mask"][0]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
|
|||||||
assert "Not the same number of function calls and responses" in str(e)
|
assert "Not the same number of function calls and responses" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_magistral_tokenizer_call_method(
    magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
    """Check that calling the Mistral wrapper mirrors a HuggingFace tokenizer call.

    Only structure is compared against the HuggingFace reference (key set,
    container types, rank and batch dimension); token ids themselves are not
    compared because the two tokenizers use different vocabularies.
    """
    from copy import deepcopy

    import numpy as np
    import torch

    # Work on a copy so the shared fixture is not mutated by the pad token.
    reference = deepcopy(llama3_tokenizer)
    reference.pad_token = reference.eos_token

    sentence = "Hello, how are you?"
    sentences = ["Hello world", "How are you?"]

    # --- single string, plain python lists ---
    ref_plain: dict[str, list[int]] = reference(sentence, return_tensors=None)
    got_plain: dict[str, list[int]] = magistral_tokenizer(
        sentence, return_tensors=None
    )

    assert isinstance(got_plain, dict)
    assert set(got_plain.keys()) == {"input_ids", "attention_mask"}
    for key in ("input_ids", "attention_mask"):
        assert isinstance(got_plain[key], type(ref_plain[key]))  # list
    assert len(got_plain["input_ids"]) == len(got_plain["attention_mask"])
    assert np.all(got_plain["attention_mask"])
    assert len(np.array(got_plain["input_ids"]).shape) == 1  # 1D array

    # --- single string, PyTorch tensors ---
    ref_pt: dict[str, torch.Tensor] = reference(sentence, return_tensors="pt")
    got_pt: dict[str, torch.Tensor] = magistral_tokenizer(
        sentence, return_tensors="pt"
    )

    # Structure and types
    assert isinstance(got_pt["input_ids"], torch.Tensor)
    assert isinstance(got_pt["attention_mask"], torch.Tensor)

    # Same rank and batch size (the token dimension is allowed to differ)
    ref_shape = ref_pt["input_ids"].shape
    got_shape = got_pt["input_ids"].shape
    assert len(ref_shape) == len(got_shape)
    assert ref_shape[0] == got_shape[0]
    assert got_pt["attention_mask"].shape == got_shape
    assert torch.all(got_pt["attention_mask"] == 1)

    # --- batch input with padding ---
    ref_batch: dict[str, torch.Tensor] = reference(
        sentences, return_tensors="pt", padding=True
    )
    got_batch: dict[str, torch.Tensor] = magistral_tokenizer(
        sentences, return_tensors="pt", padding=True
    )

    ref_batch_shape = ref_batch["input_ids"].shape
    got_batch_shape = got_batch["input_ids"].shape
    assert len(ref_batch_shape) == len(got_batch_shape)
    assert ref_batch_shape[0] == got_batch_shape[0]
    assert got_batch["attention_mask"].shape == got_batch_shape

    shorter_mask = got_batch["attention_mask"][0]
    longer_mask = got_batch["attention_mask"][1]
    assert torch.any(shorter_mask == 0)  # padding in shorter sequence
    assert torch.all(longer_mask == 1)  # no padding in longer sequence

    # --- NumPy tensors ---
    got_np: dict[str, np.ndarray] = magistral_tokenizer(sentence, return_tensors="np")
    assert isinstance(got_np["input_ids"], np.ndarray)
    assert isinstance(got_np["attention_mask"], np.ndarray)

    # --- __call__ must agree with encode() ---
    expected_ids: list[int] = magistral_tokenizer.encode(
        sentence, add_special_tokens=True
    )
    via_call: dict[str, torch.Tensor] = magistral_tokenizer(
        sentence, return_tensors="pt"
    )
    assert expected_ids == via_call["input_ids"][0].tolist()

    # --- error handling ---
    with pytest.raises(ValueError, match="Unsupported kwargs"):
        magistral_tokenizer(sentence, unsupported_param=True)

    with pytest.raises(
        ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
    ):
        magistral_tokenizer(sentences, return_tensors="pt")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user