lint
This commit is contained in:
@@ -42,43 +42,17 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
return self.__class__.process_rows(
|
||||||
examples, self.processor, self.chat_template, self.max_images
|
examples,
|
||||||
|
self.processor,
|
||||||
|
self.chat_template,
|
||||||
|
self.max_images,
|
||||||
|
chat_template_type=self.chat_template_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pixtral_chat_conversion(messages):
|
def preprocess(examples: list[dict]) -> list[dict]:
|
||||||
is_single_message = not isinstance(messages, list)
|
|
||||||
if is_single_message:
|
|
||||||
messages = [messages]
|
|
||||||
|
|
||||||
for i, message in enumerate(messages):
|
|
||||||
if message["role"] == "user":
|
|
||||||
for j, content in enumerate(message["content"]):
|
|
||||||
if "type" in content and content["type"] == "text":
|
|
||||||
messages[i]["content"][j] = {
|
|
||||||
"type": "text",
|
|
||||||
"content": content["text"],
|
|
||||||
}
|
|
||||||
|
|
||||||
if message["role"] == "assistant":
|
|
||||||
messages[i]["content"] = message["content"][0]["text"]
|
|
||||||
|
|
||||||
if is_single_message:
|
|
||||||
return messages[0]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
|
||||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
|
||||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
|
||||||
|
|
||||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
|
||||||
# use this as a starting point
|
|
||||||
|
|
||||||
def _preprocess(examples: list[dict]) -> list[dict]:
|
|
||||||
"""
|
"""
|
||||||
Preprocess conversation examples to ensure consistent format.
|
Preprocess conversation examples to ensure consistent format.
|
||||||
|
|
||||||
Converts different conversation formats to OpenAI format with 'messages'.
|
Converts different conversation formats to OpenAI format with 'messages'.
|
||||||
Supports two formats:
|
Supports two formats:
|
||||||
1. OpenAI format with 'messages'
|
1. OpenAI format with 'messages'
|
||||||
@@ -86,7 +60,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
examples: list of conversation dictionaries
|
examples: list of conversation dictionaries
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict in OpenAI format with 'messages' key
|
dict in OpenAI format with 'messages' key
|
||||||
|
|
||||||
@@ -134,7 +107,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
|
|
||||||
return processed_examples
|
return processed_examples
|
||||||
|
|
||||||
def _process_images(examples, max_images):
|
@staticmethod
|
||||||
|
def process_images(examples, max_images):
|
||||||
"""
|
"""
|
||||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||||
|
|
||||||
@@ -186,10 +160,57 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pixtral_chat_conversion(messages):
|
||||||
|
is_single_message = not isinstance(messages, list)
|
||||||
|
if is_single_message:
|
||||||
|
messages = [messages]
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
if message["role"] == "user":
|
||||||
|
for j, content in enumerate(message["content"]):
|
||||||
|
if "type" in content and content["type"] == "text":
|
||||||
|
messages[i]["content"][j] = {
|
||||||
|
"type": "text",
|
||||||
|
"content": content["text"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
messages[i]["content"] = message["content"][0]["text"]
|
||||||
|
|
||||||
|
if is_single_message:
|
||||||
|
return messages[0]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_rows(
|
||||||
|
examples,
|
||||||
|
processor,
|
||||||
|
chat_template,
|
||||||
|
max_images,
|
||||||
|
length_only=False,
|
||||||
|
chat_template_type=None,
|
||||||
|
):
|
||||||
|
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||||
|
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||||
|
|
||||||
|
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||||
|
# use this as a starting point
|
||||||
|
|
||||||
# Preprocess the examples
|
# Preprocess the examples
|
||||||
examples = _preprocess(examples)
|
examples = __class__.preprocess(examples)
|
||||||
|
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
|
if chat_template_type == "pixtral":
|
||||||
|
texts = [
|
||||||
|
processor.apply_chat_template(
|
||||||
|
__class__.pixtral_chat_conversion(example["messages"]),
|
||||||
|
chat_template=chat_template,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
else:
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(
|
processor.apply_chat_template(
|
||||||
example["messages"], chat_template=chat_template, tokenize=False
|
example["messages"], chat_template=chat_template, tokenize=False
|
||||||
@@ -197,7 +218,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
|
|
||||||
images = _process_images(examples, max_images=max_images)
|
images = __class__.process_images(examples, max_images=max_images)
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user