From 312832e1fe7d1e238cb6cc9f89210e4e2ad8fe9e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 21 Jul 2025 17:19:03 +0700 Subject: [PATCH] feat: support audio and return pixel values in collator --- src/axolotl/utils/collators/mm_chat.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 75d72f8dc..1db06c2a1 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -50,7 +50,7 @@ class MultiModalChatDataCollator(DataCollatorMixin): # This method requires transformers>=4.49.0 result = self.processing_strategy.processor.apply_chat_template( example["messages"], - add_generation_prompt=True, + add_generation_prompt=False, tokenize=True, return_tensors="pt", padding=True, @@ -84,6 +84,17 @@ class MultiModalChatDataCollator(DataCollatorMixin): "attention_mask": attention_mask, } + if "token_type_ids" in final_batch: + final_batch["token_type_ids"] = torch.nn.utils.rnn.pad_sequence( + batch["token_type_ids"], batch_first=True, padding_value=0 + ) + + if "pixel_values" in final_batch: + final_batch["pixel_values"] = torch.stack(batch["pixel_values"]) + + if "audio_values" in final_batch: + final_batch["audio_values"] = torch.stack(batch["audio_values"]) + # Process the labels final_batch["labels"] = self.processing_strategy.process_labels( final_batch["input_ids"]