lint
This commit is contained in:
@@ -225,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
labels = batch["input_ids"].clone()
|
labels = batch["input_ids"].clone()
|
||||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||||
# Ignore the image token index in the loss computation (model specific)
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
if chat_template_type == "qwen2_vl":
|
||||||
processor.image_token
|
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
|
||||||
)
|
else:
|
||||||
|
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.image_token
|
||||||
|
)
|
||||||
labels[labels == image_token_id] = -100
|
labels[labels == image_token_id] = -100
|
||||||
batch["labels"] = labels
|
batch["labels"] = labels
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user