diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 799428088..0e5338704 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -225,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin): labels = batch["input_ids"].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 # # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids( - processor.image_token - ) + if chat_template_type == "qwen2_vl": + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + else: + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) labels[labels == image_token_id] = -100 batch["labels"] = labels