feat: update handling for MistralTokenizer decode and multiprocessing pickling fix (#2790)

* feat: update handling for MistralTokenizer decode

* fix: update mistral common package version

* fix: use the correct release

* fix: correct the Triton cache dir path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2025-07-02 19:07:18 +07:00
committed by GitHub
parent 6383630155
commit 8ae5a2311b
3 changed files with 11 additions and 8 deletions

View File

@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3
mistral-common==1.6.0 mistral-common==1.6.3

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor from torch import Tensor
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
token_ids = [token_ids] token_ids = [token_ids]
if skip_special_tokens: if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids) return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
# to_string returns a string with special tokens return self._mistral.instruct_tokenizer.tokenizer.decode(
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids) token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def _create_mistral_chat_completion_request( def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -10,7 +10,7 @@ import shutil
import sys import sys
import tempfile import tempfile
import time import time
from pathlib import Path from pathlib import Path, PosixPath
from typing import Generator from typing import Generator
import datasets import datasets
@@ -424,8 +424,8 @@ def temp_dir() -> Generator[str, None, None]:
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir): def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = temp_dir + "/~.triton/cache" os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)