This commit is contained in:
bursteratom
2024-12-05 14:59:51 -05:00
parent 169116a50f
commit dc055a4ef7

View File

@@ -42,43 +42,17 @@ class MultiModalChatDataCollator(DataCollatorMixin):
) )
return self.__class__.process_rows( return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images examples,
self.processor,
self.chat_template,
self.max_images,
chat_template_type=self.chat_template_type,
) )
@staticmethod @staticmethod
def pixtral_chat_conversion(messages): def preprocess(examples: list[dict]) -> list[dict]:
is_single_message = not isinstance(messages, list)
if is_single_message:
messages = [messages]
for i, message in enumerate(messages):
if message["role"] == "user":
for j, content in enumerate(message["content"]):
if "type" in content and content["type"] == "text":
messages[i]["content"][j] = {
"type": "text",
"content": content["text"],
}
if message["role"] == "assistant":
messages[i]["content"] = message["content"][0]["text"]
if is_single_message:
return messages[0]
return messages
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
""" """
Preprocess conversation examples to ensure consistent format. Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'. Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats: Supports two formats:
1. OpenAI format with 'messages' 1. OpenAI format with 'messages'
@@ -86,7 +60,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
Args: Args:
examples: list of conversation dictionaries examples: list of conversation dictionaries
Returns: Returns:
dict in OpenAI format with 'messages' key dict in OpenAI format with 'messages' key
@@ -134,7 +107,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
return processed_examples return processed_examples
def _process_images(examples, max_images): @staticmethod
def process_images(examples, max_images):
""" """
Process images from examples, ensuring consistency in image presence and applying max_images limit. Process images from examples, ensuring consistency in image presence and applying max_images limit.
@@ -186,10 +160,57 @@ class MultiModalChatDataCollator(DataCollatorMixin):
return images return images
@staticmethod
def pixtral_chat_conversion(messages):
    """
    Convert OpenAI-style chat messages into the layout used with the pixtral chat template.

    For ``user`` messages whose content is a list, every part with
    ``type == "text"`` is rewritten from ``{"type": "text", "text": ...}`` to
    ``{"type": "text", "content": ...}``; all other parts are passed through
    as-is. For ``assistant`` messages whose content is not already a string,
    the content is collapsed to the ``"text"`` value of its first part.
    Messages with any other role are passed through.

    The input is never modified: each converted message is a new dict (and a
    new content list for user messages). Input that is already in the
    converted layout comes back equal to itself, so calling this function
    repeatedly on the same data is safe.

    Args:
        messages: a single message dict, or a list of message dicts, each
            with ``"role"`` and ``"content"`` keys.

    Returns:
        A converted message dict when a single (non-list) message was passed
        in, otherwise a list of converted message dicts.
    """
    is_single_message = not isinstance(messages, list)
    if is_single_message:
        messages = [messages]

    converted = []
    for message in messages:
        # Shallow copy so the caller's dict is left untouched.
        new_message = dict(message)
        role = message["role"]
        content = message["content"]

        if role == "user" and isinstance(content, list):
            # Build a fresh list instead of assigning into the caller's list.
            # Parts lacking a "text" key (e.g. already converted) pass through.
            new_message["content"] = [
                {"type": "text", "content": part["text"]}
                if isinstance(part, dict)
                and part.get("type") == "text"
                and "text" in part
                else part
                for part in content
            ]
        elif role == "assistant" and not isinstance(content, str):
            # Only the first content part's text is kept for assistant turns.
            new_message["content"] = content[0]["text"]

        converted.append(new_message)

    return converted[0] if is_single_message else converted
@staticmethod
def process_rows(
examples,
processor,
chat_template,
max_images,
length_only=False,
chat_template_type=None,
):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Preprocess the examples # Preprocess the examples
examples = _preprocess(examples) examples = __class__.preprocess(examples)
# Get the texts and images, and apply the chat template # Get the texts and images, and apply the chat template
if chat_template_type == "pixtral":
texts = [
processor.apply_chat_template(
__class__.pixtral_chat_conversion(example["messages"]),
chat_template=chat_template,
tokenize=False,
)
for example in examples
]
else:
texts = [ texts = [
processor.apply_chat_template( processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False example["messages"], chat_template=chat_template, tokenize=False
@@ -197,7 +218,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
for example in examples for example in examples
] ]
images = _process_images(examples, max_images=max_images) images = __class__.process_images(examples, max_images=max_images)
# Tokenize the texts and process the images # Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True) batch = processor(text=texts, images=images, return_tensors="pt", padding=True)