Feat: add devstral model support (#2880) [skip ci]
* fix: do not add training and training_detail block by default * fixed: magistral docs * fix: address pad adding new fields and use built-in from_openai * feat: try enable multiprocessing * fix: check for keys before deleting attn_mask * feat: add mistral pad test * feat: add tool calling test * feat: add devstral tokenizer tests * fix: comma format * chore: remove unused support_preprocessing as tokenizer is pickable now * chore: update magistral doc * feat: add devstral readme and example * chore: refactor error handling
This commit is contained in:
@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||
LOG.info(
|
||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||
)
|
||||
num_proc = 1
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
for message in messages:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
**transformed_message,
|
||||
"training": message.get(self.prompter.message_field_training),
|
||||
"training_detail": message.get(
|
||||
self.prompter.message_field_training_detail
|
||||
),
|
||||
}
|
||||
turn = transformed_message
|
||||
|
||||
training = message.get(self.prompter.message_field_training)
|
||||
training_detail = message.get(self.prompter.message_field_training_detail)
|
||||
if training is not None:
|
||||
turn["training"] = training
|
||||
if training_detail is not None:
|
||||
turn["training_detail"] = training_detail
|
||||
|
||||
turns.append(turn)
|
||||
|
||||
@@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy):
|
||||
# TODO: address this in the future with mistral-specific checks
|
||||
# self._validate_eot_and_eos_tokens()
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self) -> bool:
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
||||
"""
|
||||
|
||||
return False
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# mistral-common tokenizer does not support eot_tokens
|
||||
|
||||
@@ -70,14 +70,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self):
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
Should return False if the tokenizer has unpicklable objects.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
|
||||
@@ -108,7 +108,7 @@ class DataCollatorForSeq2Seq:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if not has_attn_mask:
|
||||
if not has_attn_mask and "attention_mask" in features:
|
||||
del features["attention_mask"]
|
||||
|
||||
# prepare decoder_input_ids
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
import math
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||
from torch import Tensor
|
||||
@@ -14,9 +15,6 @@ from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
|
||||
|
||||
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
||||
"""Get the file path from local or HF Hub"""
|
||||
@@ -259,75 +257,6 @@ class HFMistralTokenizer:
|
||||
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||
)
|
||||
|
||||
def _create_mistral_chat_completion_request(
|
||||
self, conversation: list[dict], tools: list[dict] | None = None
|
||||
) -> "ChatCompletionRequest":
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AssistantMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
|
||||
[]
|
||||
)
|
||||
for turn in conversation:
|
||||
role = turn.get("role")
|
||||
|
||||
if role == "user":
|
||||
messages.append(UserMessage(content=turn["content"]))
|
||||
elif role == "assistant":
|
||||
messages.append(
|
||||
AssistantMessage(
|
||||
content=turn.get("content"),
|
||||
tool_calls=turn.get("tool_calls"),
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
messages.append(
|
||||
ToolMessage(
|
||||
content=turn.get("content"),
|
||||
tool_call_id=turn.get("tool_call_id"),
|
||||
name=turn.get("name"),
|
||||
)
|
||||
)
|
||||
elif role == "system":
|
||||
messages.append(SystemMessage(content=turn["content"]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
|
||||
)
|
||||
|
||||
tool_calls: list[Tool] = []
|
||||
if tools:
|
||||
# convert to Tool
|
||||
for tool in tools:
|
||||
if tool["type"] != "function":
|
||||
continue
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
tool_calls.append(
|
||||
Tool(
|
||||
function=Function(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
# set parameters to empty dict if not provided
|
||||
parameters=function.get("parameters", {}),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
|
||||
messages=messages,
|
||||
tools=tool_calls,
|
||||
)
|
||||
|
||||
return chat_completion
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -342,8 +271,8 @@ class HFMistralTokenizer:
|
||||
if add_generation_prompt:
|
||||
raise NotImplementedError("add_generation_prompt not supported yet")
|
||||
|
||||
chat_completion: ChatCompletionRequest = (
|
||||
self._create_mistral_chat_completion_request(messages, tools)
|
||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai(
|
||||
messages, tools
|
||||
)
|
||||
|
||||
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
||||
@@ -408,13 +337,16 @@ class HFMistralTokenizer:
|
||||
padding_value=IGNORE_INDEX,
|
||||
)
|
||||
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
attention_mask = None
|
||||
if "attention_mask" in features[0]:
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
|
||||
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
||||
position_ids = None
|
||||
if "position_ids" in features[0]:
|
||||
if self.padding_side == "left":
|
||||
# Likely not needed, but keeping for now
|
||||
@@ -443,22 +375,15 @@ class HFMistralTokenizer:
|
||||
pos_seq = torch.cat([pos_seq, pad_positions])
|
||||
position_ids_list.append(pos_seq)
|
||||
position_ids = torch.stack(position_ids_list)
|
||||
else:
|
||||
# Create position_ids if not present
|
||||
seq_len = input_ids.size(1)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, dtype=torch.long)
|
||||
.unsqueeze(0)
|
||||
.expand(input_ids.size(0), -1)
|
||||
)
|
||||
|
||||
# Ensure all tensors have the same sequence length
|
||||
max_seq_len = max(
|
||||
input_ids.size(1),
|
||||
labels.size(1),
|
||||
attention_mask.size(1),
|
||||
position_ids.size(1),
|
||||
)
|
||||
# Check attention mask and position ids if they are present
|
||||
tensor_lengths = [input_ids.size(1), labels.size(1)]
|
||||
if attention_mask is not None:
|
||||
tensor_lengths.append(attention_mask.size(1))
|
||||
if position_ids is not None:
|
||||
tensor_lengths.append(position_ids.size(1))
|
||||
max_seq_len = max(tensor_lengths)
|
||||
|
||||
# TODO: check if trimming is needed? and correct.
|
||||
|
||||
@@ -492,44 +417,48 @@ class HFMistralTokenizer:
|
||||
elif labels.size(1) > max_seq_len:
|
||||
labels = labels[:, :max_seq_len]
|
||||
|
||||
if attention_mask.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - attention_mask.size(1)
|
||||
if self.padding_side == "right":
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||
else:
|
||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||
elif attention_mask.size(1) > max_seq_len:
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - attention_mask.size(1)
|
||||
if self.padding_side == "right":
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||
else:
|
||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||
elif attention_mask.size(1) > max_seq_len:
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
|
||||
if position_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - position_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
batch_size = position_ids.size(0)
|
||||
new_position_ids = []
|
||||
for i in range(batch_size):
|
||||
seq = position_ids[i]
|
||||
if len(seq) > 0:
|
||||
# get last position and pad with sequential values
|
||||
last_pos = seq[-1].item()
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
new_seq = torch.cat([seq, pad_positions])
|
||||
else:
|
||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||
new_position_ids.append(new_seq)
|
||||
position_ids = torch.stack(new_position_ids)
|
||||
else:
|
||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||
elif position_ids.size(1) > max_seq_len:
|
||||
position_ids = position_ids[:, :max_seq_len]
|
||||
if position_ids is not None:
|
||||
if position_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - position_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
batch_size = position_ids.size(0)
|
||||
new_position_ids = []
|
||||
for i in range(batch_size):
|
||||
seq = position_ids[i]
|
||||
if len(seq) > 0:
|
||||
# get last position and pad with sequential values
|
||||
last_pos = seq[-1].item()
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
new_seq = torch.cat([seq, pad_positions])
|
||||
else:
|
||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||
new_position_ids.append(new_seq)
|
||||
position_ids = torch.stack(new_position_ids)
|
||||
else:
|
||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||
elif position_ids.size(1) > max_seq_len:
|
||||
position_ids = position_ids[:, :max_seq_len]
|
||||
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
if attention_mask is not None:
|
||||
final_batch["attention_mask"] = attention_mask
|
||||
if position_ids is not None:
|
||||
final_batch["position_ids"] = position_ids
|
||||
|
||||
# Handle non-sequence fields (raise error)
|
||||
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
||||
@@ -545,7 +474,7 @@ class HFMistralTokenizer:
|
||||
result = {}
|
||||
for k, v in final_batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
result[k] = v.numpy().astype(np.long)
|
||||
result[k] = v.numpy().astype(np.int64)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user