feat: add call method to mistral tokenizer wrapper (#2898)
This commit is contained in:
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
|
||||
@@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
|
||||
assert "Not the same number of function calls and responses" in str(e)
|
||||
|
||||
|
||||
def test_magistral_tokenizer_call_method(
|
||||
magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
|
||||
):
|
||||
"""Test the __call__ method behavior matches HuggingFace standards"""
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
hf_tokenizer = deepcopy(llama3_tokenizer)
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
|
||||
test_text = "Hello, how are you?"
|
||||
batch_texts = ["Hello world", "How are you?"]
|
||||
|
||||
# Test single string with return_tensors=None
|
||||
hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)
|
||||
mistral_result: dict[str, list[int]] = magistral_tokenizer(
|
||||
test_text, return_tensors=None
|
||||
)
|
||||
|
||||
assert isinstance(mistral_result, dict)
|
||||
assert set(mistral_result.keys()) == {"input_ids", "attention_mask"}
|
||||
assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list
|
||||
assert isinstance(
|
||||
mistral_result["attention_mask"], type(hf_result["attention_mask"])
|
||||
)
|
||||
assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"])
|
||||
assert np.all(mistral_result["attention_mask"])
|
||||
assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array
|
||||
|
||||
# Test single string with return_tensors='pt'
|
||||
hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt")
|
||||
mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(
|
||||
test_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Check structure and types
|
||||
assert isinstance(mistral_result_pt["input_ids"], torch.Tensor)
|
||||
assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor)
|
||||
|
||||
# Check shapes match (don't compare token dimension)
|
||||
assert len(hf_result_pt["input_ids"].shape) == len(
|
||||
mistral_result_pt["input_ids"].shape
|
||||
)
|
||||
assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0]
|
||||
assert (
|
||||
mistral_result_pt["attention_mask"].shape
|
||||
== mistral_result_pt["input_ids"].shape
|
||||
)
|
||||
assert torch.all(mistral_result_pt["attention_mask"] == 1)
|
||||
|
||||
# Test batch input with padding
|
||||
hf_batch: dict[str, torch.Tensor] = hf_tokenizer(
|
||||
batch_texts, return_tensors="pt", padding=True
|
||||
)
|
||||
mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(
|
||||
batch_texts, return_tensors="pt", padding=True
|
||||
)
|
||||
|
||||
# Check batch behavior
|
||||
assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape)
|
||||
assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
|
||||
assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
|
||||
assert torch.any(
|
||||
mistral_batch["attention_mask"][0] == 0
|
||||
) # padding in shorter sequence
|
||||
assert torch.all(
|
||||
mistral_batch["attention_mask"][1] == 1
|
||||
) # no padding in longer sequence
|
||||
|
||||
# Test numpy tensors
|
||||
mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(
|
||||
test_text, return_tensors="np"
|
||||
)
|
||||
assert isinstance(mistral_result_np["input_ids"], np.ndarray)
|
||||
assert isinstance(mistral_result_np["attention_mask"], np.ndarray)
|
||||
|
||||
# Test consistency with encode()
|
||||
encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)
|
||||
called: dict[str, torch.Tensor] = magistral_tokenizer(
|
||||
test_text, return_tensors="pt"
|
||||
)
|
||||
assert encoded == called["input_ids"][0].tolist()
|
||||
|
||||
# Test Error handling
|
||||
with pytest.raises(ValueError, match="Unsupported kwargs"):
|
||||
magistral_tokenizer(test_text, unsupported_param=True)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
|
||||
):
|
||||
magistral_tokenizer(batch_texts, return_tensors="pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user