From f84c3b37e7b0278aa550bb7b863509634c241431 Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Wed, 4 Dec 2024 11:59:45 -0500
Subject: [PATCH] lint

---
 src/axolotl/core/trainer_builder.py    |  1 +
 src/axolotl/utils/collators/mm_chat.py | 75 ++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 75219a274..9e906179b 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1817,6 +1817,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             collator = MultiModalChatDataCollator
             kwargs["processor"] = self.processor
             kwargs["chat_template"] = training_args.chat_template
+            kwargs["chat_template_type"] = self.cfg.chat_template
         else:
             collator = DataCollatorForSeq2Seq
 
diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
index b9b67f875..e1923ffe3 100644
--- a/src/axolotl/utils/collators/mm_chat.py
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -20,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
     processor: ProcessorMixin
     return_tensors: str = "pt"
     chat_template: Optional[str] = None
+    chat_template_type: Optional[str] = None
     packing: bool = False
     max_images: int = -1
     padding: Union[bool, str, PaddingStrategy] = True
@@ -33,11 +34,37 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         self, examples: List[Union[List[int], Any, Dict[str, Any]]]
     ) -> Dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
+        if self.chat_template_type == "pixtral":
+            return self.__class__.process_rows_pixtral(
+                examples, self.processor, self.chat_template, self.max_images
+            )
         return self.__class__.process_rows(
             examples, self.processor, self.chat_template, self.max_images
         )
 
+    @staticmethod
+    def pixtral_chat_conversion(messages):
+        is_single_message = not isinstance(messages, list)
+        if is_single_message:
+            messages = [messages]
+
+        for i, message in enumerate(messages):
+            if message["role"] == "user":
+                for j, content in enumerate(message["content"]):
+                    if "type" in content and content["type"] == "text":
+                        messages[i]["content"][j] = {
+                            "type": "text",
+                            "content": content["text"],
+                        }
+
+            if message["role"] == "assistant":
+                messages[i]["content"] = message["content"][0]["text"]
+
+        if is_single_message:
+            return messages[0]
+        return messages
+
     @staticmethod
     def process_rows(examples, processor, chat_template, max_images, length_only=False):
         # HINT: use `_torch_collate_batch` to stack and pad tensors
         # see also DataCollatorWithFlattening and DefaultDataCollator
@@ -81,3 +108,51 @@ class MultiModalChatDataCollator(DataCollatorMixin):
             "length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
         }
         return batch
+
+    @staticmethod
+    def process_rows_pixtral(
+        examples, processor, chat_template, max_images, length_only=False
+    ):
+        # Pixtral-specific collation: apply the chat template, batch the images,
+        # and mask padding / image tokens out of the labels.
+
+        # Adapted from the trl example sft_vlm.py collate_fn; mirrors
+        # process_rows above, but first converts messages to Pixtral's schema.
+
+        # Get the texts and images, and apply the chat template
+        texts = [
+            processor.apply_chat_template(
+                __class__.pixtral_chat_conversion(example["messages"]),
+                chat_template=chat_template,
+                tokenize=False,
+            )
+            for example in examples
+        ]
+        images = [
+            Image.open(example["images"])
+            if isinstance(example["images"], str)
+            else example["images"]
+            for example in examples
+        ]
+
+        if max_images > 0:
+            images = [img_batch[:max_images] for img_batch in images]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+        # Ignore the image token index in the loss computation (model specific)
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            processor.image_token
+        )
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        if length_only:
+            return {
+                "length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
+            }
+        return batch