Compare commits
3 Commits
uv-first
...
mm_mc_chat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5c01c11d8 | ||
|
|
00ebf2faf9 | ||
|
|
641e84188b |
@@ -1,5 +1,6 @@
|
||||
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
||||
|
||||
import ast
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
@@ -75,6 +76,49 @@ class ProcessingStrategy:
|
||||
result["messages"] = messages
|
||||
return result
|
||||
|
||||
def convert_multiple_choice_to_multimedia_messages(
|
||||
messages: dict,
|
||||
) -> list[dict]:
|
||||
|
||||
def construct_prompt(sample):
|
||||
question = sample["question"]
|
||||
options = sample["options"]
|
||||
if isinstance(options, str):
|
||||
options = ast.literal_eval(options)
|
||||
|
||||
example = ""
|
||||
start_chr = "A"
|
||||
prediction_range = []
|
||||
index2ans = {}
|
||||
for option in options:
|
||||
prediction_range.append(start_chr)
|
||||
example += f"({start_chr}) {option}\n"
|
||||
index2ans[start_chr] = option
|
||||
start_chr = chr(ord(start_chr) + 1)
|
||||
|
||||
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
|
||||
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
||||
|
||||
return empty_prompt
|
||||
|
||||
new_messages = []
|
||||
|
||||
user_content = construct_prompt(messages)
|
||||
assistant_response = messages["answer"]
|
||||
|
||||
new_messages.append(
|
||||
{"role": "user", "content": [{"type": "text", "text": user_content}]}
|
||||
)
|
||||
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": assistant_response}],
|
||||
}
|
||||
)
|
||||
|
||||
return new_messages
|
||||
|
||||
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
||||
"""Convert regular messages format to Messages format with content type"""
|
||||
|
||||
@@ -106,39 +150,51 @@ class ProcessingStrategy:
|
||||
|
||||
processed_examples = []
|
||||
for example in examples:
|
||||
if not ("messages" in example or "conversations" in example):
|
||||
if not (
|
||||
"messages" in example
|
||||
or "conversations" in example
|
||||
or "question" in example
|
||||
):
|
||||
raise ValueError(
|
||||
"Only `messages` and `conversations` message keys are currently supported."
|
||||
"Only `messages`, `conversations`, and `question` message keys are currently supported."
|
||||
)
|
||||
|
||||
processed_example = None
|
||||
if "messages" in example: # OpenAI format
|
||||
processed_example = example
|
||||
# convert regular messages format to Messages format with content type
|
||||
# for compatibility with apply_chat_template
|
||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||
processed_example["messages"]
|
||||
)
|
||||
elif "question" in example: # Multiple choice format
|
||||
processed_example = {}
|
||||
processed_example["messages"] = (
|
||||
convert_multiple_choice_to_multimedia_messages(example)
|
||||
)
|
||||
else: # Legacy format
|
||||
processed_example = convert_legacy_format(example)
|
||||
|
||||
# convert regular messages format to Messages format with content type
|
||||
# for compatibility with apply_chat_template
|
||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||
processed_example["messages"]
|
||||
)
|
||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||
processed_example["messages"]
|
||||
)
|
||||
|
||||
# find the image key if it exists
|
||||
possible_image_keys = ["images", "image"]
|
||||
image_key = None
|
||||
for key in possible_image_keys:
|
||||
if key in processed_example:
|
||||
image_key = key
|
||||
break
|
||||
|
||||
# if the image key exists, add the image to the first message
|
||||
if image_key is not None:
|
||||
# TODO: check if it's normal to be single image only for common datasets
|
||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
||||
image_value = processed_example[image_key][0]
|
||||
image_keys = []
|
||||
for key in example.keys():
|
||||
if "image" in key:
|
||||
image_keys.append(key)
|
||||
|
||||
for im_key in image_keys:
|
||||
if example[im_key] is None:
|
||||
continue
|
||||
if isinstance(example[im_key], list):
|
||||
if len(example[im_key]) == 0:
|
||||
continue
|
||||
image_value = example[im_key][0]
|
||||
else:
|
||||
image_value = example[im_key]
|
||||
|
||||
# Handle image loading (Image, url, path, base64)
|
||||
image_value = load_image(image_value)
|
||||
|
||||
if self.image_size is not None:
|
||||
@@ -163,33 +219,12 @@ class ProcessingStrategy:
|
||||
color=padding_color,
|
||||
)
|
||||
|
||||
# Look for any image type in the first message
|
||||
# some dataset have an {type: "image"} in the first message
|
||||
ind_to_add = None
|
||||
|
||||
for i, content in enumerate(
|
||||
processed_example["messages"][0]["content"]
|
||||
):
|
||||
# Usually datasets created with image columns, don't have it in the messages itself
|
||||
if content["type"] == "image" and all(
|
||||
k not in content for k in ["image", "url", "path", "base64"]
|
||||
):
|
||||
ind_to_add = i
|
||||
break
|
||||
|
||||
# If an image type is found, add the image to that index
|
||||
if ind_to_add is not None:
|
||||
processed_example["messages"][0]["content"][ind_to_add][
|
||||
"image"
|
||||
] = image_value
|
||||
else:
|
||||
# if no image type is found, add it to end of the first message
|
||||
processed_example["messages"][0]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": image_value,
|
||||
}
|
||||
)
|
||||
processed_example["messages"][0]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": image_value,
|
||||
}
|
||||
)
|
||||
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user