Files
axolotl/src/axolotl/utils/collators/mm_chat.py

93 lines
2.7 KiB
Python

"""
Collators for multi-modal chat messages and packing
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
from axolotl.processing_strategies import ProcessingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
"""
Collator for multi-modal chat messages
"""
tokenizer: PreTrainedTokenizerBase
processing_strategy: ProcessingStrategy
packing: bool = False
return_tensors: str = "pt"
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
def __post_init__(self):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
return self.process_rows(examples)
def process_rows(
self,
examples: list[dict],
) -> dict[str, Tensor]:
# Preprocess the examples
examples = self.processing_strategy(examples)
# Initialize batch
batch: dict[str, Any] = {}
# Process each example
for example in examples:
# Apply chat template to process the example
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
# TODO: Check if need handling for len(input_ids) > sequence_len
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"]
)
return final_batch