This commit is contained in:
bursteratom
2024-12-05 15:34:04 -05:00
parent dc055a4ef7
commit 1ad56303b2

View File

@@ -36,11 +36,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
if self.chat_template_type == "pixtral":
return self.__class__.process_rows_pixtral(
examples, self.processor, self.chat_template, self.max_images
)
return self.__class__.process_rows(
examples,
self.processor,
@@ -218,6 +213,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
for example in examples
]
if chat_template_type == "llava":
max_images = 1
images = __class__.process_images(examples, max_images=max_images)
# Tokenize the texts and process the images
@@ -238,51 +235,3 @@ class MultiModalChatDataCollator(DataCollatorMixin):
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch
@staticmethod
def process_rows_pixtral(
examples, processor, chat_template, max_images, length_only=False
):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
__class__.pixtral_chat_conversion(example["messages"]),
chat_template=chat_template,
tokenize=False,
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch