lint
This commit is contained in:
@@ -1817,6 +1817,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processor"] = self.processor
|
kwargs["processor"] = self.processor
|
||||||
kwargs["chat_template"] = training_args.chat_template
|
kwargs["chat_template"] = training_args.chat_template
|
||||||
|
kwargs["chat_template_type"] = self.cfg.chat_template
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
processor: ProcessorMixin
|
processor: ProcessorMixin
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
|
chat_template_type: Optional[str] = None
|
||||||
packing: bool = False
|
packing: bool = False
|
||||||
max_images: int = -1
|
max_images: int = -1
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
@@ -33,11 +34,37 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
|
if self.chat_template_type == "pixtral":
|
||||||
|
return self.__class__.process_rows_pixtral(
|
||||||
|
examples, self.processor, self.chat_template, self.max_images
|
||||||
|
)
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
return self.__class__.process_rows(
|
||||||
examples, self.processor, self.chat_template, self.max_images
|
examples, self.processor, self.chat_template, self.max_images
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def pixtral_chat_conversion(messages):
    """Convert chat messages into the layout used for the Pixtral chat template.

    Two rewrites are applied:

    * ``user`` messages: every content item of the form
      ``{"type": "text", "text": ...}`` becomes
      ``{"type": "text", "content": ...}``; other items are kept as they are.
    * ``assistant`` messages: a list-valued ``content`` is collapsed to the
      ``"text"`` value of its first item.

    The input is never modified: converted messages are shallow copies. This
    makes the conversion safe to run repeatedly over the same example, whereas
    rewriting in place would fail on the second pass because the ``"text"``
    keys it looks for would already be gone.

    Args:
        messages: a single message dict, or a list of message dicts, each
            with ``"role"`` and ``"content"`` keys.

    Returns:
        A converted message dict when a single message was passed, otherwise
        a new list of converted message dicts.
    """
    is_single_message = not isinstance(messages, list)
    if is_single_message:
        messages = [messages]

    converted = []
    for message in messages:
        # Shallow copy so the caller's message dict is left untouched.
        message = dict(message)
        role = message["role"]
        content = message["content"]

        if role == "user" and isinstance(content, list):
            # Items that are not text, or that were already converted (no
            # "text" key), are passed through unchanged.
            message["content"] = [
                {"type": "text", "content": item["text"]}
                if isinstance(item, dict)
                and item.get("type") == "text"
                and "text" in item
                else item
                for item in content
            ]
        elif role == "assistant" and isinstance(content, list):
            # Only list content is collapsed; a plain string is assumed to be
            # already in its final form.
            message["content"] = content[0]["text"]

        converted.append(message)

    if is_single_message:
        return converted[0]
    return converted
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||||
@@ -81,3 +108,51 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||||
}
|
}
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@staticmethod
def process_rows_pixtral(
    examples, processor, chat_template, max_images, length_only=False
):
    """Build a padded text+image batch using the Pixtral message layout.

    Adapted from the trl ``sft_vlm.py`` example collator.

    Args:
        examples: rows to collate; each row is read through
            ``example["messages"]`` and ``example["images"]``. The rows are
            iterated twice, so a one-shot iterator is not suitable.
        processor: object used for ``apply_chat_template``, for the combined
            text/image call, and for its ``tokenizer`` and ``image_token``
            attributes.
        chat_template: forwarded to ``processor.apply_chat_template``.
        max_images: when greater than zero, each row's ``images`` entry is
            sliced down to at most this many items.
        length_only: when true, return only the per-row token counts.

    Returns:
        ``{"length": [...]}`` when ``length_only`` is true; otherwise the
        processor output with an added ``"labels"`` entry in which padding
        and image-token positions are set to ``-100``.
    """
    # HINT: use `_torch_collate_batch` to stack and pad tensors
    # see also DataCollatorWithFlattening and DefaultDataCollator

    # Render each conversation to a prompt string (tokenize=False); the
    # messages are first rewritten into the Pixtral layout.
    texts = [
        processor.apply_chat_template(
            __class__.pixtral_chat_conversion(example["messages"]),
            chat_template=chat_template,
            tokenize=False,
        )
        for example in examples
    ]
    # A string entry is treated as a path and opened; anything else is
    # forwarded as given.
    images = [
        Image.open(example["images"])
        if isinstance(example["images"], str)
        else example["images"]
        for example in examples
    ]

    if max_images > 0:
        # NOTE(review): slicing presumes each entry is a sequence of images;
        # an entry opened from a path above is a single image object and
        # would presumably not support slicing -- confirm against the
        # dataset format.
        images = [img_batch[:max_images] for img_batch in images]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    if length_only:
        # Each element of batch["input_ids"] is already one row of token ids,
        # so measure it directly instead of indexing it by key. Returning
        # here also skips building labels that would be thrown away.
        # NOTE(review): with padding=True these are presumably padded row
        # lengths -- confirm this is what callers want.
        return {"length": [len(token_ids) for token_ids in batch["input_ids"]]}

    # The labels are the input_ids, and we mask the padding tokens in the
    # loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.image_token
    )
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
|
||||||
|
|||||||
Reference in New Issue
Block a user