diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 7e9223baa..772fe818a 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -36,11 +36,6 @@ class MultiModalChatDataCollator(DataCollatorMixin): self, examples: list[Union[list[int], Any, dict[str, Any]]] ) -> dict[str, Any]: # Handle dict or lists with proper padding and conversion to tensor. - if self.chat_template_type == "pixtral": - return self.__class__.process_rows_pixtral( - examples, self.processor, self.chat_template, self.max_images - ) - return self.__class__.process_rows( examples, self.processor, @@ -218,6 +213,8 @@ class MultiModalChatDataCollator(DataCollatorMixin): for example in examples ] + if chat_template_type == "llava": + max_images = 1 images = __class__.process_images(examples, max_images=max_images) # Tokenize the texts and process the images @@ -238,51 +235,3 @@ class MultiModalChatDataCollator(DataCollatorMixin): "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] } return batch - - @staticmethod - def process_rows_pixtral( - examples, processor, chat_template, max_images, length_only=False - ): - # HINT: use `_torch_collate_batch` to stack and pad tensors - # see also DataCollatorWithFlattening and DefaultDataCollator - - # *** This is COPIED from the trl example sft_vlm.py code *** - # use this as a starting point - - # Get the texts and images, and apply the chat template - texts = [ - processor.apply_chat_template( - __class__.pixtral_chat_conversion(example["messages"]), - chat_template=chat_template, - tokenize=False, - ) - for example in examples - ] - images = [ - Image.open(example["images"]) - if isinstance(example["images"], str) - else example["images"] - for example in examples - ] - - if max_images > 0: - images = [img_batch[:max_images] for img_batch in images] - - # Tokenize the texts and process the images - batch = processor(text=texts, images=images, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 # - # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids( - processor.image_token - ) - labels[labels == image_token_id] = -100 - batch["labels"] = labels - - if length_only: - return { - "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] - } - return batch