diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index ac1a59e44..7e9223baa 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -42,9 +42,124 @@ class MultiModalChatDataCollator(DataCollatorMixin): ) return self.__class__.process_rows( - examples, self.processor, self.chat_template, self.max_images + examples, + self.processor, + self.chat_template, + self.max_images, + chat_template_type=self.chat_template_type, ) + @staticmethod + def preprocess(examples: list[dict]) -> list[dict]: + """ + Preprocess conversation examples to ensure consistent format. + Converts different conversation formats to OpenAI format with 'messages'. + Supports two formats: + 1. OpenAI format with 'messages' + 2. Legacy format with 'conversations' + + Args: + examples: list of conversation dictionaries + Returns: + dict in OpenAI format with 'messages' key + + Raises: + ValueError: If the conversation format is not supported + """ + role_mapping = { + "human": "user", + "gpt": "assistant", + } + + def normalize_role(role: str) -> str: + """Normalize role names to OpenAI format. Default to original role if not found.""" + return role_mapping.get(role, role) + + def convert_legacy_format(example: dict) -> dict: + """Convert legacy 'conversations' format to OpenAI 'messages' format.""" + messages = [ + { + "role": normalize_role(convo["from"]), + "content": convo["value"], + } + for convo in example["conversations"] + ] + + # Create new dict without 'conversations' key + result = deepcopy(example) + result.pop("conversations") + return {"messages": messages, **result} + + processed_examples = [] + for example in examples: + # OpenAI format + if "messages" in example: + processed_examples.append(example) + + # Legacy format + elif "conversations" in example: + processed_examples.append(convert_legacy_format(example)) + + else: + raise ValueError( + "Only `messages` and `conversations` message keys are currently supported." + ) + + return processed_examples + + @staticmethod + def process_images(examples, max_images): + """ + Process images from examples, ensuring consistency in image presence and applying max_images limit. + + Args: + examples: List of dictionaries that may contain 'images' key + max_images: Maximum number of images to keep per example (0 means no limit) + + Returns: + Either None (if no images) or List[Image objects] (if all examples have images) + + Raises: + ValueError: If there's a mix of None and non-None images + """ + + def get_image(example): + if "images" not in example: + return None + images = example["images"] + if isinstance(images, str): + return Image.open(images) + return images + + images = [get_image(example) for example in examples] + + # Count None and non-None images + none_count = sum(1 for img in images if img is None) + + # All images are None + if none_count == len(images): + return None + + # Mix of None and non-None images + if none_count > 0: + raise ValueError( + "All images should be either None or not None. " + "Please provide images for all examples or None." + ) + + # Apply max_images limit if specified + if max_images > 0: + images = [ + ( + img_batch[:max_images] + if isinstance(img_batch, (list, tuple)) + else img_batch + ) + for img_batch in images + ] + + return images + @staticmethod def pixtral_chat_conversion(messages): is_single_message = not isinstance(messages, list) @@ -68,136 +183,42 @@ class MultiModalChatDataCollator(DataCollatorMixin): return messages @staticmethod - def process_rows(examples, processor, chat_template, max_images, length_only=False): + def process_rows( + examples, + processor, + chat_template, + max_images, + length_only=False, + chat_template_type=None, + ): # HINT: use `_torch_collate_batch` to stack and pad tensors # see also DataCollatorWithFlattening and DefaultDataCollator # *** This is COPIED from the trl example sft_vlm.py code *** # use this as a starting point - def _preprocess(examples: list[dict]) -> list[dict]: - """ - Preprocess conversation examples to ensure consistent format. - - Converts different conversation formats to OpenAI format with 'messages'. - Supports two formats: - 1. OpenAI format with 'messages' - 2. Legacy format with 'conversations' - - Args: - examples: list of conversation dictionaries - - Returns: - dict in OpenAI format with 'messages' key - - Raises: - ValueError: If the conversation format is not supported - """ - role_mapping = { - "human": "user", - "gpt": "assistant", - } - - def normalize_role(role: str) -> str: - """Normalize role names to OpenAI format. Default to original role if not found.""" - return role_mapping.get(role, role) - - def convert_legacy_format(example: dict) -> dict: - """Convert legacy 'conversations' format to OpenAI 'messages' format.""" - messages = [ - { - "role": normalize_role(convo["from"]), - "content": convo["value"], - } - for convo in example["conversations"] - ] - - # Create new dict without 'conversations' key - result = deepcopy(example) - result.pop("conversations") - return {"messages": messages, **result} - - processed_examples = [] - for example in examples: - # OpenAI format - if "messages" in example: - processed_examples.append(example) - - # Legacy format - elif "conversations" in example: - processed_examples.append(convert_legacy_format(example)) - - else: - raise ValueError( - "Only `messages` and `conversations` message keys are currently supported." - ) - - return processed_examples - - def _process_images(examples, max_images): - """ - Process images from examples, ensuring consistency in image presence and applying max_images limit. - - Args: - examples: List of dictionaries that may contain 'images' key - max_images: Maximum number of images to keep per example (0 means no limit) - - Returns: - Either None (if no images) or List[Image objects] (if all examples have images) - - Raises: - ValueError: If there's a mix of None and non-None images - """ - - def get_image(example): - if "images" not in example: - return None - images = example["images"] - if isinstance(images, str): - return Image.open(images) - return images - - images = [get_image(example) for example in examples] - - # Count None and non-None images - none_count = sum(1 for img in images if img is None) - - # All images are None - if none_count == len(images): - return None - - # Mix of None and non-None images - if none_count > 0: - raise ValueError( - "All images should be either None or not None. " - "Please provide images for all examples or None." - ) - - # Apply max_images limit if specified - if max_images > 0: - images = [ - ( - img_batch[:max_images] - if isinstance(img_batch, (list, tuple)) - else img_batch - ) - for img_batch in images - ] - - return images - # Preprocess the examples - examples = _preprocess(examples) + examples = __class__.preprocess(examples) # Get the texts and images, and apply the chat template - texts = [ - processor.apply_chat_template( - example["messages"], chat_template=chat_template, tokenize=False - ) - for example in examples - ] + if chat_template_type == "pixtral": + texts = [ + processor.apply_chat_template( + __class__.pixtral_chat_conversion(example["messages"]), + chat_template=chat_template, + tokenize=False, + ) + for example in examples + ] + else: + texts = [ + processor.apply_chat_template( + example["messages"], chat_template=chat_template, tokenize=False + ) + for example in examples + ] - images = _process_images(examples, max_images=max_images) + images = __class__.process_images(examples, max_images=max_images) # Tokenize the texts and process the images batch = processor(text=texts, images=images, return_tensors="pt", padding=True)