Feat: add devstral model support (#2880) [skip ci]

* fix: do not add training and training_detail fields by default

* fix: magistral docs

* fix: stop pad from adding new fields and use built-in from_openai

* feat: try enabling multiprocessing

* fix: check for keys before deleting attn_mask

* feat: add mistral pad test

* feat: add tool calling test

* feat: add devstral tokenizer tests

* fix: comma format

* chore: remove unused supports_multiprocessing as tokenizer is picklable now

* chore: update magistral doc

* feat: add devstral readme and example

* chore: refactor error handling
NanoCode012 (committed by GitHub)
2025-07-08 22:01:19 +07:00
parent 78bff4925e
commit 8c6a6ea6eb
10 changed files with 690 additions and 189 deletions

View File

@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
         features = dataset.features.keys()
         num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
-        # Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
-        if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
-            LOG.info(
-                "Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
-            )
-            num_proc = 1
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True

View File

@@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         for message in messages:
             transformed_message = self.transform_message(message)
-            turn = {
-                **transformed_message,
-                "training": message.get(self.prompter.message_field_training),
-                "training_detail": message.get(
-                    self.prompter.message_field_training_detail
-                ),
-            }
+            turn = transformed_message
+            training = message.get(self.prompter.message_field_training)
+            training_detail = message.get(self.prompter.message_field_training_detail)
+            if training is not None:
+                turn["training"] = training
+            if training_detail is not None:
+                turn["training_detail"] = training_detail
             turns.append(turn)
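
The behavioral difference in miniature (message contents are illustrative): the old dict literal always materialized both keys, with None when the dataset carried no flags, while the new code only sets a key when a value is actually present.

message = {"role": "user", "content": "hi"}  # no training flags in the dataset
transformed_message = {"role": "user", "content": "hi"}

# before: keys always present, even when the value is None
old_turn = {
    **transformed_message,
    "training": message.get("training"),
    "training_detail": message.get("training_detail"),
}
assert old_turn["training"] is None  # key exists, value None

# after: keys added only when a value is present
new_turn = dict(transformed_message)
if message.get("training") is not None:
    new_turn["training"] = message["training"]
assert "training" not in new_turn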
@@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy):
         # TODO: address this in the future with mistral-specific checks
         # self._validate_eot_and_eos_tokens()
-    @property
-    def supports_multiprocessing(self) -> bool:
-        """
-        Whether this tokenizing strategy supports multiprocessing.
-
-        mistral_common tokenizers cannot be pickled for multiprocessing.
-        """
-        return False
-
     def find_first_eot_token(self, input_ids, start_idx):
         """Find the first EOT token in the input_ids starting from start_idx."""
         # mistral-common tokenizer does not support eot_tokens

View File

@@ -70,14 +70,6 @@ class PromptTokenizingStrategy(abc.ABC):
     def supports_batched(self):
         return False
-    @property
-    def supports_multiprocessing(self):
-        """
-        Whether this tokenizing strategy supports multiprocessing.
-
-        Should return False if the tokenizer has unpicklable objects.
-        """
-        return True
-
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:

View File

@@ -108,7 +108,7 @@ class DataCollatorForSeq2Seq:
             pad_to_multiple_of=self.pad_to_multiple_of,
             return_tensors=return_tensors,
         )
-        if not has_attn_mask:
+        if not has_attn_mask and "attention_mask" in features:
             del features["attention_mask"]

         # prepare decoder_input_ids
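
The extra membership check matters because the mistral_common path can hand the collator features that never contained an attention_mask, and del on a missing key raises. Illustrated on a toy dict:

features = {"input_ids": [[1, 2, 3]]}   # no attention_mask was ever produced
if "attention_mask" in features:        # guarded delete: a no-op here
    del features["attention_mask"]      # unguarded, this line would raise KeyError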

View File

@@ -3,10 +3,11 @@
 import math
 import os
 from shutil import copyfile
-from typing import TYPE_CHECKING, Optional
+from typing import Optional
 import numpy as np
 from huggingface_hub import hf_hub_download
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
 from torch import Tensor
@@ -14,9 +15,6 @@ from transformers.utils import PaddingStrategy
 from axolotl.utils.collators.core import IGNORE_INDEX

-if TYPE_CHECKING:
-    from mistral_common.protocol.instruct.request import ChatCompletionRequest
-
 def _get_file_path(path_or_repo_id: str, filename: str) -> str:
     """Get the file path from local or HF Hub"""
@@ -259,75 +257,6 @@ class HFMistralTokenizer:
             token_ids, special_token_policy=SpecialTokenPolicy.KEEP
         )
-    def _create_mistral_chat_completion_request(
-        self, conversation: list[dict], tools: list[dict] | None = None
-    ) -> "ChatCompletionRequest":
-        from mistral_common.protocol.instruct.messages import (
-            AssistantMessage,
-            SystemMessage,
-            ToolMessage,
-            UserMessage,
-        )
-        from mistral_common.protocol.instruct.request import ChatCompletionRequest
-        from mistral_common.protocol.instruct.tool_calls import Function, Tool
-
-        messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
-            []
-        )
-        for turn in conversation:
-            role = turn.get("role")
-            if role == "user":
-                messages.append(UserMessage(content=turn["content"]))
-            elif role == "assistant":
-                messages.append(
-                    AssistantMessage(
-                        content=turn.get("content"),
-                        tool_calls=turn.get("tool_calls"),
-                    )
-                )
-            elif role == "tool":
-                messages.append(
-                    ToolMessage(
-                        content=turn.get("content"),
-                        tool_call_id=turn.get("tool_call_id"),
-                        name=turn.get("name"),
-                    )
-                )
-            elif role == "system":
-                messages.append(SystemMessage(content=turn["content"]))
-            else:
-                raise ValueError(
-                    f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
-                )
-
-        tool_calls: list[Tool] = []
-        if tools:
-            # convert to Tool
-            for tool in tools:
-                if tool["type"] != "function":
-                    continue
-                function = tool["function"]
-                tool_calls.append(
-                    Tool(
-                        function=Function(
-                            name=function["name"],
-                            description=function["description"],
-                            # set parameters to empty dict if not provided
-                            parameters=function.get("parameters", {}),
-                        )
-                    )
-                )
-
-        chat_completion: ChatCompletionRequest = ChatCompletionRequest(
-            messages=messages,
-            tools=tool_calls,
-        )
-        return chat_completion
-
     def apply_chat_template(
         self,
         messages: list[dict],
@@ -342,8 +271,8 @@ class HFMistralTokenizer:
         if add_generation_prompt:
             raise NotImplementedError("add_generation_prompt not supported yet")
-        chat_completion: ChatCompletionRequest = (
-            self._create_mistral_chat_completion_request(messages, tools)
+        chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai(
+            messages, tools
         )
         tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
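
ChatCompletionRequest.from_openai takes the OpenAI-style message dicts directly and performs the role dispatch and tool conversion that the deleted helper did by hand. A small usage sketch (the conversation is illustrative, and the v3 Tekken tokenizer choice is an assumption):

from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

messages = [
    {"role": "system", "content": "You are a coding assistant."},
    {"role": "user", "content": "Refactor this function to be iterative."},
]
request = ChatCompletionRequest.from_openai(messages, None)  # no tools

tokenizer = MistralTokenizer.v3(is_tekken=True)  # assumed tokenizer version
tokens = tokenizer.encode_chat_completion(request).tokens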
@@ -408,13 +337,16 @@ class HFMistralTokenizer:
             padding_value=IGNORE_INDEX,
         )
-        attention_mask = torch.nn.utils.rnn.pad_sequence(
-            [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
-            batch_first=True,
-            padding_value=0,
-        )
+        attention_mask = None
+        if "attention_mask" in features[0]:
+            attention_mask = torch.nn.utils.rnn.pad_sequence(
+                [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
+                batch_first=True,
+                padding_value=0,
+            )

         # Handle position_ids - pad with sequential values for right padding, 0s for left padding
+        position_ids = None
         if "position_ids" in features[0]:
             if self.padding_side == "left":
                 # Likely not needed, but keeping for now
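
pad_sequence is what levels the ragged per-example lists, and it always pads on the right with the given value; the new None branch simply skips it when the feature is absent. On toy data:

import torch

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)
# padded:
# tensor([[1, 2, 3],
#         [4, 5, 0]])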
@@ -443,22 +375,15 @@ class HFMistralTokenizer:
                     pos_seq = torch.cat([pos_seq, pad_positions])
                     position_ids_list.append(pos_seq)
                 position_ids = torch.stack(position_ids_list)
-        else:
-            # Create position_ids if not present
-            seq_len = input_ids.size(1)
-            position_ids = (
-                torch.arange(seq_len, dtype=torch.long)
-                .unsqueeze(0)
-                .expand(input_ids.size(0), -1)
-            )

         # Ensure all tensors have the same sequence length
-        max_seq_len = max(
-            input_ids.size(1),
-            labels.size(1),
-            attention_mask.size(1),
-            position_ids.size(1),
-        )
+        # Check attention mask and position ids if they are present
+        tensor_lengths = [input_ids.size(1), labels.size(1)]
+        if attention_mask is not None:
+            tensor_lengths.append(attention_mask.size(1))
+        if position_ids is not None:
+            tensor_lengths.append(position_ids.size(1))
+        max_seq_len = max(tensor_lengths)

         # TODO: check if trimming is needed? and correct.
@@ -492,44 +417,48 @@ class HFMistralTokenizer:
         elif labels.size(1) > max_seq_len:
             labels = labels[:, :max_seq_len]
-        if attention_mask.size(1) < max_seq_len:
-            pad_len = max_seq_len - attention_mask.size(1)
-            if self.padding_side == "right":
-                attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
-            else:
-                attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
-        elif attention_mask.size(1) > max_seq_len:
-            attention_mask = attention_mask[:, :max_seq_len]
+        if attention_mask is not None:
+            if attention_mask.size(1) < max_seq_len:
+                pad_len = max_seq_len - attention_mask.size(1)
+                if self.padding_side == "right":
+                    attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
+                else:
+                    attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
+            elif attention_mask.size(1) > max_seq_len:
+                attention_mask = attention_mask[:, :max_seq_len]

-        if position_ids.size(1) < max_seq_len:
-            pad_len = max_seq_len - position_ids.size(1)
-            if self.padding_side == "right":
-                batch_size = position_ids.size(0)
-                new_position_ids = []
-                for i in range(batch_size):
-                    seq = position_ids[i]
-                    if len(seq) > 0:
-                        # get last position and pad with sequential values
-                        last_pos = seq[-1].item()
-                        pad_positions = torch.arange(
-                            last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
-                        )
-                        new_seq = torch.cat([seq, pad_positions])
-                    else:
-                        new_seq = torch.arange(pad_len, dtype=torch.long)
-                    new_position_ids.append(new_seq)
-                position_ids = torch.stack(new_position_ids)
-            else:
-                position_ids = F.pad(position_ids, (pad_len, 0), value=0)
-        elif position_ids.size(1) > max_seq_len:
-            position_ids = position_ids[:, :max_seq_len]
+        if position_ids is not None:
+            if position_ids.size(1) < max_seq_len:
+                pad_len = max_seq_len - position_ids.size(1)
+                if self.padding_side == "right":
+                    batch_size = position_ids.size(0)
+                    new_position_ids = []
+                    for i in range(batch_size):
+                        seq = position_ids[i]
+                        if len(seq) > 0:
+                            # get last position and pad with sequential values
+                            last_pos = seq[-1].item()
+                            pad_positions = torch.arange(
+                                last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
+                            )
+                            new_seq = torch.cat([seq, pad_positions])
+                        else:
+                            new_seq = torch.arange(pad_len, dtype=torch.long)
+                        new_position_ids.append(new_seq)
+                    position_ids = torch.stack(new_position_ids)
+                else:
+                    position_ids = F.pad(position_ids, (pad_len, 0), value=0)
+            elif position_ids.size(1) > max_seq_len:
+                position_ids = position_ids[:, :max_seq_len]

         final_batch = {
             "input_ids": input_ids,
             "labels": labels,
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
         }
+        if attention_mask is not None:
+            final_batch["attention_mask"] = attention_mask
+        if position_ids is not None:
+            final_batch["position_ids"] = position_ids

         # Handle non-sequence fields (raise error)
         sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
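
The right-padding branch above extends position ids sequentially instead of repeating zeros, keeping positions monotone through the pad region. Worked on one row:

import torch

seq = torch.tensor([0, 1, 2])      # positions of a 3-token example
pad_len = 2                        # batch max length is 5
last_pos = seq[-1].item()          # 2
pad_positions = torch.arange(last_pos + 1, last_pos + 1 + pad_len)
padded = torch.cat([seq, pad_positions])
# padded: tensor([0, 1, 2, 3, 4]) -- the count continues instead of restarting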
@@ -545,7 +474,7 @@ class HFMistralTokenizer:
         result = {}
         for k, v in final_batch.items():
             if isinstance(v, torch.Tensor):
-                result[k] = v.numpy().astype(np.long)
+                result[k] = v.numpy().astype(np.int64)
             else:
                 result[k] = v
         return result
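
The dtype fix is real: np.long was removed in NumPy 1.24 (and reintroduced with a different meaning in 2.0), so np.int64 is the explicit, portable spelling that round-trips cleanly from torch.long:

import numpy as np
import torch

t = torch.tensor([1, 2, 3], dtype=torch.long)
arr = t.numpy().astype(np.int64)  # np.long raises AttributeError on NumPy 1.24-1.26
assert arr.dtype == np.int64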