lint
This commit is contained in:
@@ -36,11 +36,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||
) -> dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
if self.chat_template_type == "pixtral":
|
||||
return self.__class__.process_rows_pixtral(
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
)
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples,
|
||||
self.processor,
|
||||
@@ -218,6 +213,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
for example in examples
|
||||
]
|
||||
|
||||
if chat_template_type == "llava":
|
||||
max_images = 1
|
||||
images = __class__.process_images(examples, max_images=max_images)
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
@@ -238,51 +235,3 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||
}
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def process_rows_pixtral(
|
||||
examples, processor, chat_template, max_images, length_only=False
|
||||
):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||
|
||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||
# use this as a starting point
|
||||
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
__class__.pixtral_chat_conversion(example["messages"]),
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [
|
||||
Image.open(example["images"])
|
||||
if isinstance(example["images"], str)
|
||||
else example["images"]
|
||||
for example in examples
|
||||
]
|
||||
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
if length_only:
|
||||
return {
|
||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||
}
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user