Compare commits
3 Commits
version-de
...
mm_mc_chat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5c01c11d8 | ||
|
|
00ebf2faf9 | ||
|
|
641e84188b |
@@ -1,5 +1,6 @@
|
|||||||
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
||||||
|
|
||||||
|
import ast
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -75,6 +76,49 @@ class ProcessingStrategy:
|
|||||||
result["messages"] = messages
|
result["messages"] = messages
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def convert_multiple_choice_to_multimedia_messages(
|
||||||
|
messages: dict,
|
||||||
|
) -> list[dict]:
|
||||||
|
|
||||||
|
def construct_prompt(sample):
|
||||||
|
question = sample["question"]
|
||||||
|
options = sample["options"]
|
||||||
|
if isinstance(options, str):
|
||||||
|
options = ast.literal_eval(options)
|
||||||
|
|
||||||
|
example = ""
|
||||||
|
start_chr = "A"
|
||||||
|
prediction_range = []
|
||||||
|
index2ans = {}
|
||||||
|
for option in options:
|
||||||
|
prediction_range.append(start_chr)
|
||||||
|
example += f"({start_chr}) {option}\n"
|
||||||
|
index2ans[start_chr] = option
|
||||||
|
start_chr = chr(ord(start_chr) + 1)
|
||||||
|
|
||||||
|
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
|
||||||
|
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
||||||
|
|
||||||
|
return empty_prompt
|
||||||
|
|
||||||
|
new_messages = []
|
||||||
|
|
||||||
|
user_content = construct_prompt(messages)
|
||||||
|
assistant_response = messages["answer"]
|
||||||
|
|
||||||
|
new_messages.append(
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": user_content}]}
|
||||||
|
)
|
||||||
|
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": assistant_response}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
||||||
"""Convert regular messages format to Messages format with content type"""
|
"""Convert regular messages format to Messages format with content type"""
|
||||||
|
|
||||||
@@ -106,39 +150,51 @@ class ProcessingStrategy:
|
|||||||
|
|
||||||
processed_examples = []
|
processed_examples = []
|
||||||
for example in examples:
|
for example in examples:
|
||||||
if not ("messages" in example or "conversations" in example):
|
if not (
|
||||||
|
"messages" in example
|
||||||
|
or "conversations" in example
|
||||||
|
or "question" in example
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only `messages` and `conversations` message keys are currently supported."
|
"Only `messages`, `conversations`, and `question` message keys are currently supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
processed_example = None
|
processed_example = None
|
||||||
if "messages" in example: # OpenAI format
|
if "messages" in example: # OpenAI format
|
||||||
processed_example = example
|
processed_example = example
|
||||||
|
# convert regular messages format to Messages format with content type
|
||||||
|
# for compatibility with apply_chat_template
|
||||||
|
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||||
|
processed_example["messages"]
|
||||||
|
)
|
||||||
|
elif "question" in example: # Multiple choice format
|
||||||
|
processed_example = {}
|
||||||
|
processed_example["messages"] = (
|
||||||
|
convert_multiple_choice_to_multimedia_messages(example)
|
||||||
|
)
|
||||||
else: # Legacy format
|
else: # Legacy format
|
||||||
processed_example = convert_legacy_format(example)
|
processed_example = convert_legacy_format(example)
|
||||||
|
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||||
# convert regular messages format to Messages format with content type
|
processed_example["messages"]
|
||||||
# for compatibility with apply_chat_template
|
)
|
||||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
|
||||||
processed_example["messages"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# find the image key if it exists
|
# find the image key if it exists
|
||||||
possible_image_keys = ["images", "image"]
|
|
||||||
image_key = None
|
|
||||||
for key in possible_image_keys:
|
|
||||||
if key in processed_example:
|
|
||||||
image_key = key
|
|
||||||
break
|
|
||||||
|
|
||||||
# if the image key exists, add the image to the first message
|
image_keys = []
|
||||||
if image_key is not None:
|
for key in example.keys():
|
||||||
# TODO: check if it's normal to be single image only for common datasets
|
if "image" in key:
|
||||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
image_keys.append(key)
|
||||||
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
|
||||||
image_value = processed_example[image_key][0]
|
for im_key in image_keys:
|
||||||
|
if example[im_key] is None:
|
||||||
|
continue
|
||||||
|
if isinstance(example[im_key], list):
|
||||||
|
if len(example[im_key]) == 0:
|
||||||
|
continue
|
||||||
|
image_value = example[im_key][0]
|
||||||
|
else:
|
||||||
|
image_value = example[im_key]
|
||||||
|
|
||||||
# Handle image loading (Image, url, path, base64)
|
|
||||||
image_value = load_image(image_value)
|
image_value = load_image(image_value)
|
||||||
|
|
||||||
if self.image_size is not None:
|
if self.image_size is not None:
|
||||||
@@ -163,33 +219,12 @@ class ProcessingStrategy:
|
|||||||
color=padding_color,
|
color=padding_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look for any image type in the first message
|
processed_example["messages"][0]["content"].append(
|
||||||
# some dataset have an {type: "image"} in the first message
|
{
|
||||||
ind_to_add = None
|
"type": "image",
|
||||||
|
"image": image_value,
|
||||||
for i, content in enumerate(
|
}
|
||||||
processed_example["messages"][0]["content"]
|
)
|
||||||
):
|
|
||||||
# Usually datasets created with image columns, don't have it in the messages itself
|
|
||||||
if content["type"] == "image" and all(
|
|
||||||
k not in content for k in ["image", "url", "path", "base64"]
|
|
||||||
):
|
|
||||||
ind_to_add = i
|
|
||||||
break
|
|
||||||
|
|
||||||
# If an image type is found, add the image to that index
|
|
||||||
if ind_to_add is not None:
|
|
||||||
processed_example["messages"][0]["content"][ind_to_add][
|
|
||||||
"image"
|
|
||||||
] = image_value
|
|
||||||
else:
|
|
||||||
# if no image type is found, add it to end of the first message
|
|
||||||
processed_example["messages"][0]["content"].append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"image": image_value,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_examples.append(processed_example)
|
processed_examples.append(processed_example)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user