feat: update handling for mistraltokenizer decode and multiprocessing pickling fix (#2790)
* feat: update handling for mistraltokenizer decode
* fix: update mistral common package version
* fix: to use correct release
* fix triton path

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.6.0
|
mistral-common==1.6.3
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
|
|||||||
token_ids = [token_ids]
|
token_ids = [token_ids]
|
||||||
|
|
||||||
if skip_special_tokens:
|
if skip_special_tokens:
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
|
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
|
||||||
# to_string returns a string with special tokens
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||||
|
)
|
||||||
|
|
||||||
def _create_mistral_chat_completion_request(
|
def _create_mistral_chat_completion_request(
|
||||||
self, conversation: list[dict], tools: list[dict] | None = None
|
self, conversation: list[dict], tools: list[dict] | None = None
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path, PosixPath
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -424,8 +424,8 @@ def temp_dir() -> Generator[str, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def unique_triton_cache_dir(temp_dir):
|
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
|
||||||
os.environ["TRITON_CACHE_DIR"] = temp_dir + "/~.triton/cache"
|
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user