feat: update handling for mistraltokenizer decode and multiprocessing pickling fix (#2790)
* feat: update handling for mistraltokenizer decode
* fix: update mistral common package version
* fix: to use correct release
* fix triton path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.6.0
|
||||
mistral-common==1.6.3
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||
from torch import Tensor
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
|
||||
token_ids = [token_ids]
|
||||
|
||||
if skip_special_tokens:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
|
||||
)
|
||||
|
||||
# to_string returns a string with special tokens
|
||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||
)
|
||||
|
||||
def _create_mistral_chat_completion_request(
|
||||
self, conversation: list[dict], tools: list[dict] | None = None
|
||||
|
||||
@@ -10,7 +10,7 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PosixPath
|
||||
from typing import Generator
|
||||
|
||||
import datasets
|
||||
@@ -424,8 +424,8 @@ def temp_dir() -> Generator[str, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def unique_triton_cache_dir(temp_dir):
|
||||
os.environ["TRITON_CACHE_DIR"] = temp_dir + "/~.triton/cache"
|
||||
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
|
||||
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
|
||||
Reference in New Issue
Block a user