This commit is contained in:
bursteratom
2024-12-06 16:06:57 -05:00
parent 89dae7dc6d
commit 3c07b6d6b1

View File

@@ -225,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
if chat_template_type == "qwen2_vl":
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
else:
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels