Feat: Add Magistral and mistral-common tokenizer support (#2780)
This commit is contained in:
@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||
LOG.info(
|
||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||
)
|
||||
num_proc = 1
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -189,7 +189,7 @@ class KDStrategyLoader(StrategyLoader):
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
|
||||
return ChatTemplateStrategyWithKD
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
|
||||
@@ -121,6 +121,19 @@ def modify_tokenizer_files(
|
||||
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
|
||||
return tokenizer
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return _load_mistral_common_tokenizer(cfg)
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import ProcessorMixin
|
||||
@@ -15,6 +17,9 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
# Configure the logger
|
||||
LOG = get_logger(__name__)
|
||||
LOG.setLevel("INFO")
|
||||
@@ -81,7 +86,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
conversation,
|
||||
conversation: list[dict],
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
@@ -271,9 +276,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
elif (
|
||||
hasattr(self.tokenizer, "eos_token")
|
||||
and self.tokenizer.eos_token is not None
|
||||
):
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
@@ -796,14 +807,104 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class MistralStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Mistral strategy for chat template.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer: "HFMistralTokenizer",
|
||||
train_on_inputs: bool,
|
||||
sequence_len: int,
|
||||
roles_to_train: list[str] | None = None,
|
||||
train_on_eos: str | None = None,
|
||||
train_on_eot: str | None = None,
|
||||
eot_tokens: list[str] | None = None,
|
||||
split_thinking: bool | None = False,
|
||||
):
|
||||
# Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation
|
||||
# pylint: disable=non-parent-init-called,super-init-not-called
|
||||
PromptTokenizingStrategy.__init__(
|
||||
self, prompter, tokenizer, train_on_inputs, sequence_len
|
||||
)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
# map each role through prompter.roles when it is defined there; otherwise use the role as is
|
||||
self.roles_to_train = [
|
||||
prompter.roles.get(role, role) for role in roles_to_train
|
||||
]
|
||||
|
||||
self.train_on_eos = train_on_eos
|
||||
# Backward compatibility, load from train_on_eos
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
else:
|
||||
# set eot_tokens to the eos_token
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
# Skip the validation that ChatTemplateStrategy calls
|
||||
# TODO: address this in the future with mistral-specific checks
|
||||
# self._validate_eot_and_eos_tokens()
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self) -> bool:
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
||||
"""
|
||||
|
||||
return False
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# mistral-common tokenizer does not support eot_tokens
|
||||
return self.find_first_eos_token(input_ids, start_idx)
|
||||
|
||||
|
||||
class MistralPrompter(ChatTemplatePrompter):
|
||||
"""
|
||||
Mistral prompter for chat template.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"])
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
Load chat template strategy based on configuration.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
def _get_strategy_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralStrategy
|
||||
|
||||
return ChatTemplateStrategy
|
||||
|
||||
def _get_prompter_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralPrompter
|
||||
|
||||
return ChatTemplatePrompter
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
return {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
@@ -829,9 +930,14 @@ class StrategyLoader:
|
||||
else:
|
||||
dataset_config = ds_cfg
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
# mistral-common does not use this, so we pass an empty string
|
||||
chat_template_string = ""
|
||||
else:
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
@@ -857,10 +963,11 @@ class StrategyLoader:
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
strategy_cls = self._get_strategy_cls(cfg)
|
||||
prompter_cls = self._get_prompter_cls(cfg)
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params),
|
||||
prompter_cls(**prompter_params),
|
||||
tokenizer=tokenizer,
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
@@ -70,6 +70,14 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self):
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
Should return False if the tokenizer has unpicklable objects.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
|
||||
567
src/axolotl/utils/mistral_tokenizer.py
Normal file
567
src/axolotl/utils/mistral_tokenizer.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""Wrapper for MistralTokenizer from mistral-common"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from torch import Tensor
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
|
||||
|
||||
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
||||
"""Get the file path from local or HF Hub"""
|
||||
if os.path.exists(path_or_repo_id):
|
||||
maybe_file_path = os.path.join(path_or_repo_id, filename)
|
||||
if os.path.exists(maybe_file_path):
|
||||
return maybe_file_path
|
||||
|
||||
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
|
||||
|
||||
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
|
||||
|
||||
|
||||
class HFMistralTokenizer:
|
||||
"""
|
||||
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
|
||||
and exposes HuggingFace API for special tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
mistral: The mistral-common tokenizer to wrap.
|
||||
name_or_path: The name or path to the tokenizer files or the repo id.
|
||||
"""
|
||||
self._mistral = mistral
|
||||
self._padding_side = "right"
|
||||
self._name_or_path = name_or_path
|
||||
self._tokenizer_path = tokenizer_path
|
||||
|
||||
# Manually switch the request validator to finetuning mode
|
||||
from mistral_common.protocol.instruct.validator import (
|
||||
MistralRequestValidator,
|
||||
ValidationMode,
|
||||
)
|
||||
|
||||
# Check if MistralRequestValidator has a _mode attribute.
|
||||
# This is a private API and may change in the future.
|
||||
# pylint: disable=protected-access
|
||||
if not (
|
||||
hasattr(self._mistral, "_chat_completion_request_validator")
|
||||
and isinstance(
|
||||
self._mistral._chat_completion_request_validator,
|
||||
MistralRequestValidator,
|
||||
)
|
||||
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Unable to switch mistral tokenizer to finetuning mode – "
|
||||
"private API `_chat_completion_request_validator._mode` missing."
|
||||
)
|
||||
|
||||
self._mistral._chat_completion_request_validator._mode = (
|
||||
ValidationMode.finetuning
|
||||
)
|
||||
|
||||
def _load_system_prompt(self, path_or_repo_id: str) -> str:
|
||||
"""Load system prompt from local or HF Hub.
|
||||
|
||||
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
|
||||
not provide one.
|
||||
|
||||
Args:
|
||||
path_or_repo_id: The path to the tokenizer files or the repo id.
|
||||
|
||||
Returns:
|
||||
The system prompt.
|
||||
"""
|
||||
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"System prompt file not found at {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.pad_id
|
||||
|
||||
@property
|
||||
def unk_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.unk_id
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
|
||||
|
||||
@property
|
||||
def padding_side(self) -> str:
|
||||
return self._padding_side
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self._name_or_path
|
||||
|
||||
@property
|
||||
def chat_template(self) -> str | None:
|
||||
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.n_words
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
name_or_path: str,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> "HFMistralTokenizer":
|
||||
"""
|
||||
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
|
||||
|
||||
Args:
|
||||
name_or_path: The path to the tokenizer files or the repo id.
|
||||
revision: The revision of the tokenizer to download.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A HFMistralTokenizer instance.
|
||||
"""
|
||||
if revision:
|
||||
raise NotImplementedError(
|
||||
"Revision not supported yet for mistral-common tokenizer"
|
||||
)
|
||||
|
||||
# only support Tekken tokenizer for now
|
||||
# downloads from HF Hub if not local
|
||||
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
|
||||
|
||||
base = MistralTokenizer.from_file(tokenizer_path)
|
||||
|
||||
return cls(
|
||||
base,
|
||||
name_or_path=name_or_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory: str) -> None:
|
||||
"""
|
||||
Save the Tekken tokenizer file (tekken.json) so that from_pretrained can pick it up again.
|
||||
|
||||
Only Tekken models are supported.
|
||||
|
||||
Args:
|
||||
save_directory: The directory to save the tokenizer files.
|
||||
"""
|
||||
inner = self._mistral.instruct_tokenizer.tokenizer
|
||||
if isinstance(inner, Tekkenizer):
|
||||
# Create the directory and save the model
|
||||
try:
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not os.path.exists(save_directory):
|
||||
raise RuntimeError(f"Failed to create directory: {save_directory}")
|
||||
|
||||
# Verify source file exists
|
||||
if not os.path.exists(self._tokenizer_path):
|
||||
raise FileNotFoundError(
|
||||
f"Source tokenizer file not found: {self._tokenizer_path}"
|
||||
)
|
||||
|
||||
destination_path = os.path.join(save_directory, "tekken.json")
|
||||
copyfile(self._tokenizer_path, destination_path)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to save tokenizer to {save_directory}: {e}. "
|
||||
f"Source path: {self._tokenizer_path}, "
|
||||
f"Directory exists: {os.path.exists(save_directory)}"
|
||||
) from e
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
|
||||
"""
|
||||
Encode a text string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text: The text string to encode.
|
||||
add_special_tokens: Whether to add special tokens to the encoded tokens.
|
||||
|
||||
Returns:
|
||||
A list of token IDs.
|
||||
"""
|
||||
return self._mistral.instruct_tokenizer.tokenizer.encode(
|
||||
text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens,
|
||||
)
|
||||
|
||||
def decode(
|
||||
self, token_ids: int | list[int], skip_special_tokens: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Decode a list of token IDs into a text string.
|
||||
|
||||
Args:
|
||||
token_ids: The int or list of token IDs to decode.
|
||||
skip_special_tokens: Whether to skip special tokens in the decoded text.
|
||||
|
||||
Returns:
|
||||
The decoded text string.
|
||||
"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
||||
if skip_special_tokens:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
||||
|
||||
# to_string returns a string with special tokens
|
||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
||||
|
||||
def _create_mistral_chat_completion_request(
|
||||
self, conversation: list[dict], tools: list[dict] | None = None
|
||||
) -> "ChatCompletionRequest":
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AssistantMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
|
||||
[]
|
||||
)
|
||||
for turn in conversation:
|
||||
role = turn.get("role")
|
||||
|
||||
if role == "user":
|
||||
messages.append(UserMessage(content=turn["content"]))
|
||||
elif role == "assistant":
|
||||
messages.append(
|
||||
AssistantMessage(
|
||||
content=turn.get("content"),
|
||||
tool_calls=turn.get("tool_calls"),
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
messages.append(
|
||||
ToolMessage(
|
||||
content=turn.get("content"),
|
||||
tool_call_id=turn.get("tool_call_id"),
|
||||
name=turn.get("name"),
|
||||
)
|
||||
)
|
||||
elif role == "system":
|
||||
messages.append(SystemMessage(content=turn["content"]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
|
||||
)
|
||||
|
||||
tool_calls: list[Tool] = []
|
||||
if tools:
|
||||
# convert to Tool
|
||||
for tool in tools:
|
||||
if tool["type"] != "function":
|
||||
continue
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
tool_calls.append(
|
||||
Tool(
|
||||
function=Function(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
# set parameters to empty dict if not provided
|
||||
parameters=function.get("parameters", {}),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
|
||||
messages=messages,
|
||||
tools=tool_calls,
|
||||
)
|
||||
|
||||
return chat_completion
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tokenize: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
chat_template: str | None = None, # pylint: disable=unused-argument
|
||||
add_generation_prompt: bool = False, # pylint: disable=unused-argument
|
||||
) -> list[int] | str:
|
||||
if chat_template:
|
||||
raise NotImplementedError("chat_template not supported yet")
|
||||
|
||||
if add_generation_prompt:
|
||||
raise NotImplementedError("add_generation_prompt not supported yet")
|
||||
|
||||
chat_completion: ChatCompletionRequest = (
|
||||
self._create_mistral_chat_completion_request(messages, tools)
|
||||
)
|
||||
|
||||
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
||||
|
||||
if tokenize:
|
||||
return tokens
|
||||
|
||||
return self.decode(tokens)
|
||||
|
||||
def pad(
|
||||
self,
|
||||
features: list[dict[str, list[int] | np.ndarray]],
|
||||
*,
|
||||
padding: bool | str | PaddingStrategy = True,
|
||||
max_length: int | None = None,
|
||||
pad_to_multiple_of: int | None = None,
|
||||
return_tensors: str | None = None, # "np" or "pt"; None behaves like "np"
|
||||
) -> dict[str, np.ndarray | Tensor]:
|
||||
"""
|
||||
HF-style pad method that properly handles all sequence-related features:
|
||||
- pad 'input_ids', 'labels', 'attention_mask' & 'position_ids' to the longest (or to max_length)
|
||||
"""
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
# Check for unsupported fields
|
||||
if any("token_type_ids" in f for f in features):
|
||||
raise ValueError("token_type_ids is not supported by this tokenizer")
|
||||
|
||||
# Determine desired sequence length
|
||||
lengths = [len(f["input_ids"]) for f in features]
|
||||
if padding in (True, "longest", PaddingStrategy.LONGEST):
|
||||
target_length = max(lengths)
|
||||
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
|
||||
if max_length is None:
|
||||
raise ValueError("max_length must be set for 'max_length' padding")
|
||||
target_length = max_length
|
||||
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
|
||||
target_length = None
|
||||
else:
|
||||
raise ValueError(f"Unknown padding strategy: {padding}")
|
||||
|
||||
# Apply pad_to_multiple_of
|
||||
if target_length is not None and pad_to_multiple_of is not None:
|
||||
target_length = (
|
||||
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
|
||||
)
|
||||
|
||||
# target_length is None when padding is disabled; pad_sequence below still pads each field to the longest sequence in the batch
|
||||
do_pad = target_length is not None
|
||||
|
||||
# Pad sequences using torch.nn.utils.rnn.pad_sequence
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=IGNORE_INDEX,
|
||||
)
|
||||
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
|
||||
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
||||
if "position_ids" in features[0]:
|
||||
if self.padding_side == "left":
|
||||
# Likely not needed, but keeping for now
|
||||
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
|
||||
position_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[
|
||||
torch.tensor(x["position_ids"], dtype=torch.long)
|
||||
for x in features
|
||||
],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
else:
|
||||
# For right padding, continue the sequence
|
||||
max_pos_len = max(len(f["position_ids"]) for f in features)
|
||||
position_ids_list = []
|
||||
for f in features:
|
||||
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
|
||||
if len(pos_seq) < max_pos_len:
|
||||
# Continue the sequence
|
||||
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
|
||||
pad_len = max_pos_len - len(pos_seq)
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
pos_seq = torch.cat([pos_seq, pad_positions])
|
||||
position_ids_list.append(pos_seq)
|
||||
position_ids = torch.stack(position_ids_list)
|
||||
else:
|
||||
# Create position_ids if not present
|
||||
seq_len = input_ids.size(1)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, dtype=torch.long)
|
||||
.unsqueeze(0)
|
||||
.expand(input_ids.size(0), -1)
|
||||
)
|
||||
|
||||
# Ensure all tensors have the same sequence length
|
||||
max_seq_len = max(
|
||||
input_ids.size(1),
|
||||
labels.size(1),
|
||||
attention_mask.size(1),
|
||||
position_ids.size(1),
|
||||
)
|
||||
|
||||
# TODO: check whether the trimming below is needed, and whether it is correct.
|
||||
|
||||
if do_pad and target_length is not None:
|
||||
max_seq_len = target_length
|
||||
|
||||
# Pad all tensors to the same length
|
||||
if input_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - input_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(0, pad_len),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
else:
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(pad_len, 0),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
elif input_ids.size(1) > max_seq_len:
|
||||
input_ids = input_ids[:, :max_seq_len]
|
||||
|
||||
if labels.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - labels.size(1)
|
||||
if self.padding_side == "right":
|
||||
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
|
||||
else:
|
||||
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
|
||||
elif labels.size(1) > max_seq_len:
|
||||
labels = labels[:, :max_seq_len]
|
||||
|
||||
if attention_mask.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - attention_mask.size(1)
|
||||
if self.padding_side == "right":
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||
else:
|
||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||
elif attention_mask.size(1) > max_seq_len:
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
|
||||
if position_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - position_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
batch_size = position_ids.size(0)
|
||||
new_position_ids = []
|
||||
for i in range(batch_size):
|
||||
seq = position_ids[i]
|
||||
if len(seq) > 0:
|
||||
# get last position and pad with sequential values
|
||||
last_pos = seq[-1].item()
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
new_seq = torch.cat([seq, pad_positions])
|
||||
else:
|
||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||
new_position_ids.append(new_seq)
|
||||
position_ids = torch.stack(new_position_ids)
|
||||
else:
|
||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||
elif position_ids.size(1) > max_seq_len:
|
||||
position_ids = position_ids[:, :max_seq_len]
|
||||
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
# Handle non-sequence fields (raise error)
|
||||
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
||||
for f in features:
|
||||
for key in f.keys():
|
||||
if key not in sequence_fields:
|
||||
raise NotImplementedError(
|
||||
f"Non-sequence field {key} not handled yet"
|
||||
)
|
||||
|
||||
# Convert to requested tensor type
|
||||
if return_tensors is None or return_tensors == "np":
|
||||
result = {}
|
||||
for k, v in final_batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
result[k] = v.numpy().astype(np.long)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
if return_tensors == "pt":
|
||||
return final_batch
|
||||
|
||||
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
|
||||
|
||||
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
|
||||
"""
|
||||
Convert a list of token IDs to a list of tokens.
|
||||
|
||||
Args:
|
||||
ids: The list of token IDs to convert.
|
||||
|
||||
Returns:
|
||||
The list of tokens.
|
||||
"""
|
||||
return [
|
||||
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
||||
]
|
||||
@@ -1265,6 +1265,68 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tokenizer_use_mistral_common(cls, data):
|
||||
if data.get("tokenizer_use_mistral_common") is None:
|
||||
if any(
|
||||
"magistral" in name.lower()
|
||||
for name in [
|
||||
data.get("base_model", ""),
|
||||
data.get("base_model_config", ""),
|
||||
data.get("tokenizer_config", ""),
|
||||
]
|
||||
):
|
||||
LOG.warning(
|
||||
"tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer."
|
||||
)
|
||||
data["tokenizer_use_mistral_common"] = True
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("tokenizer_use_mistral_common", mode="after")
|
||||
@classmethod
|
||||
def check_mistral_common_import(cls, tokenizer_use_mistral_common):
|
||||
if tokenizer_use_mistral_common:
|
||||
try:
|
||||
import mistral_common # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`."
|
||||
) from exception
|
||||
|
||||
return tokenizer_use_mistral_common
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_mistral_common_incompatible_options(cls, data):
|
||||
if not data.get("tokenizer_use_mistral_common"):
|
||||
return data
|
||||
|
||||
# NOTE: the mistral-common tokenizer does not support editing the tokenizer at the moment
|
||||
|
||||
if data.get("added_tokens_overrides"):
|
||||
raise ValueError(
|
||||
"added_tokens_overrides is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("special_tokens"):
|
||||
raise ValueError(
|
||||
"special_tokens override is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("tokens"):
|
||||
raise ValueError(
|
||||
"tokens override is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("chat_template"):
|
||||
raise ValueError(
|
||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to validate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -18,6 +18,7 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_config: str | None = None
|
||||
tokenizer_use_fast: bool | None = None
|
||||
tokenizer_legacy: bool | None = None
|
||||
tokenizer_use_mistral_common: bool | None = None
|
||||
tokenizer_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user