From 8ae5a2311b4912f283f248dafb10d88c0770cd97 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 2 Jul 2025 19:07:18 +0700
Subject: [PATCH] feat: update handling for mistraltokenizer decode and
 multiprocessing pickling fix (#2790)

* feat: update handling for mistraltokenizer decode

* fix: update mistral common package version

* fix: to use correct release

* fix triton path

---------

Co-authored-by: Wing Lian
---
 requirements.txt                       |  2 +-
 src/axolotl/utils/mistral_tokenizer.py | 11 +++++++----
 tests/conftest.py                      |  6 +++---
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 1fc3a9ff7..10ac04a66 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -68,4 +68,4 @@
 schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3
-mistral-common==1.6.0
+mistral-common==1.6.3
diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py
index 3ccf39bb0..1ba824938 100644
--- a/src/axolotl/utils/mistral_tokenizer.py
+++ b/src/axolotl/utils/mistral_tokenizer.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
 
 import numpy as np
 from huggingface_hub import hf_hub_download
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
 from torch import Tensor
 from transformers.utils import PaddingStrategy
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
             token_ids = [token_ids]
 
         if skip_special_tokens:
-            return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
+            return self._mistral.instruct_tokenizer.tokenizer.decode(
+                token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
+            )
 
-        # to_string returns a string with special tokens
-        return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
+        return self._mistral.instruct_tokenizer.tokenizer.decode(
+            token_ids, special_token_policy=SpecialTokenPolicy.KEEP
+        )
 
     def _create_mistral_chat_completion_request(
         self, conversation: list[dict], tools: list[dict] | None = None
diff --git a/tests/conftest.py b/tests/conftest.py
index b8dff2477..bbe2d10ee 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,7 +10,7 @@ import shutil
 import sys
 import tempfile
 import time
-from pathlib import Path
+from pathlib import Path, PosixPath
 from typing import Generator
 
 import datasets
@@ -424,8 +424,8 @@ def temp_dir() -> Generator[str, None, None]:
 
 
 @pytest.fixture(scope="function", autouse=True)
-def unique_triton_cache_dir(temp_dir):
-    os.environ["TRITON_CACHE_DIR"] = temp_dir + "/~.triton/cache"
+def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
+    os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
 
 
 @pytest.fixture(scope="function", autouse=True)