From 3c07b6d6b1c47466680fda21412ada1bac003c18 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Fri, 6 Dec 2024 16:06:57 -0500 Subject: [PATCH] lint --- src/axolotl/utils/collators/mm_chat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 799428088..0e5338704 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -225,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin): labels = batch["input_ids"].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 # # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids( - processor.image_token - ) + if chat_template_type == "qwen2_vl": + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + else: + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) labels[labels == image_token_id] = -100 batch["labels"] = labels