feat: add call method to mistral tokenizer wrapper (#2898)

NanoCode012 authored 2025-07-15 09:33:35 +07:00, committed by GitHub
parent a061446540, commit 354eaaf0d3
2 changed files with 225 additions and 0 deletions


@@ -497,3 +497,131 @@ class HFMistralTokenizer:
        return [
            self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
        ]

    def __call__(
        self,
        text: str | list[str],
        add_special_tokens: bool = True,
        padding: bool | str = False,
        truncation: bool = False,
        max_length: int | None = None,
        return_tensors: str | None = None,
        **kwargs,
    ) -> dict[str, list[int] | np.ndarray | Tensor]:
        """
        Tokenize text and return a dictionary with input_ids and attention_mask.

        Args:
            text: Input text string or list of strings to tokenize.
            add_special_tokens: Whether to add special tokens (BOS/EOS).
            padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
            truncation: Whether to truncate sequences to max_length.
            max_length: Maximum sequence length for truncation/padding.
            return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).

        Returns:
            Dictionary with "input_ids" and "attention_mask" keys.
        """
        # If any unsupported kwargs are passed, raise an error
        if kwargs:
            raise ValueError(
                f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
            )

        # `np` can work with inhomogeneous shapes but let's not support it until needed.
        if (
            isinstance(text, list)
            and len(text) > 1
            and return_tensors in ("pt", "np")
            and padding is False
            and truncation is False
        ):
            raise ValueError(
                "return_tensors='pt' or 'np' requires padding or truncation."
            )

        # Handle single string input
        if isinstance(text, str):
            text = [text]

        # Encode all texts
        # TODO: figure out how to parallelize this
        batch_input_ids = []
        for single_text in text:
            input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)

            # Handle truncation
            if truncation and max_length is not None and len(input_ids) > max_length:
                input_ids = input_ids[:max_length]

            batch_input_ids.append(input_ids)

        # Create attention masks (1 for real tokens, 0 for padding)
        attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]

        # Handle padding
        if padding in (True, "longest"):
            # Pad to longest sequence in batch
            max_len = max(len(input_ids) for input_ids in batch_input_ids)
            for i, input_ids in enumerate(batch_input_ids):
                pad_length = max_len - len(input_ids)
                if pad_length > 0:
                    if self.padding_side == "right":
                        batch_input_ids[i] = (
                            input_ids + [self.pad_token_id] * pad_length
                        )
                        attention_masks[i] = attention_masks[i] + [0] * pad_length
                    else:  # left padding
                        batch_input_ids[i] = [
                            self.pad_token_id
                        ] * pad_length + input_ids
                        attention_masks[i] = [0] * pad_length + attention_masks[i]
        elif padding == "max_length":
            if max_length is None:
                raise ValueError(
                    "max_length must be specified when padding='max_length'"
                )
            for i, input_ids in enumerate(batch_input_ids):
                pad_length = max_length - len(input_ids)
                if pad_length > 0:
                    if self.padding_side == "right":
                        batch_input_ids[i] = (
                            input_ids + [self.pad_token_id] * pad_length
                        )
                        attention_masks[i] = attention_masks[i] + [0] * pad_length
                    else:  # left padding
                        batch_input_ids[i] = [
                            self.pad_token_id
                        ] * pad_length + input_ids
                        attention_masks[i] = [0] * pad_length + attention_masks[i]

        # Prepare result
        result = {}

        # Handle return tensor format
        if return_tensors == "pt":
            import torch

            result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
            result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
        elif return_tensors == "np":
            result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
            result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
        elif return_tensors is None:
            result["input_ids"] = batch_input_ids
            result["attention_mask"] = attention_masks
        else:
            raise ValueError(
                f"Unsupported return_tensors='{return_tensors}'. "
                "Only 'pt' and 'np' are supported."
            )

        # If single input, return single sequences (not batched)
        if len(text) == 1 and return_tensors is None:
            result["input_ids"] = result["input_ids"][0]
            result["attention_mask"] = result["attention_mask"][0]

        return result
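
Below is a minimal usage sketch of the new `__call__` method (not part of the commit). It assumes `tokenizer` is an already-constructed `HFMistralTokenizer`, since the diff does not show how the wrapper is instantiated, and the printed token ids are illustrative only. Note that `return_tensors="pt"`/"np" with a batch of unequal-length inputs requires padding or truncation, because tensors and arrays must be rectangular.

# Single string with return_tensors=None: plain Python lists,
# unwrapped to a single (unbatched) sequence.
out = tokenizer("Hello world")
print(out["input_ids"])       # e.g. [1, 22557, 2294] (ids depend on the model)
print(out["attention_mask"])  # [1, 1, 1]

# Batch of two strings, padded to the longest sequence and returned
# as PyTorch tensors; the shorter row gets pad ids and zero mask entries
# on the side selected by tokenizer.padding_side.
batch = tokenizer(
    ["short", "a somewhat longer input"],
    padding="longest",
    return_tensors="pt",
)
print(batch["input_ids"].shape)    # torch.Size([2, max_len])
print(batch["attention_mask"][0])  # zeros mark the padded positions

# Fixed-length output: truncate anything over 16 tokens, pad anything
# under it, and return NumPy arrays.
batch_np = tokenizer(
    ["short", "a somewhat longer input"],
    padding="max_length",
    truncation=True,
    max_length=16,
    return_tensors="np",
)
print(batch_np["input_ids"].shape)  # (2, 16)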