feat: update handling for mistraltokenizer decode and multiprocessing pickling fix (#2790)

* feat: update handling for mistraltokenizer decode

* fix: update mistral common package version

* fix: to use correct release

* fix triton path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2025-07-02 19:07:18 +07:00
committed by GitHub
parent 6383630155
commit 8ae5a2311b
3 changed files with 11 additions and 8 deletions

View File

@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.0
mistral-common==1.6.3

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import numpy as np
from huggingface_hub import hf_hub_download
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
from torch import Tensor
from transformers.utils import PaddingStrategy
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
token_ids = [token_ids]
if skip_special_tokens:
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
)
# to_string returns a string with special tokens
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
return self._mistral.instruct_tokenizer.tokenizer.decode(
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
def _create_mistral_chat_completion_request(
self, conversation: list[dict], tools: list[dict] | None = None

View File

@@ -10,7 +10,7 @@ import shutil
import sys
import tempfile
import time
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Generator
import datasets
@@ -424,8 +424,8 @@ def temp_dir() -> Generator[str, None, None]:
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir):
os.environ["TRITON_CACHE_DIR"] = temp_dir + "/~.triton/cache"
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
@pytest.fixture(scope="function", autouse=True)