From c9aa8348aa9818137959d73dbfa43662f8860e8a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 21 Jul 2025 19:21:43 +0700 Subject: [PATCH] feat: add audio support for gemma3n --- src/axolotl/utils/collators/mm_chat.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 1db06c2a1..dba27f7d8 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -84,16 +84,21 @@ class MultiModalChatDataCollator(DataCollatorMixin): "attention_mask": attention_mask, } - if "token_type_ids" in final_batch: + if "token_type_ids" in batch: final_batch["token_type_ids"] = torch.nn.utils.rnn.pad_sequence( batch["token_type_ids"], batch_first=True, padding_value=0 ) - if "pixel_values" in final_batch: + if "pixel_values" in batch: final_batch["pixel_values"] = torch.stack(batch["pixel_values"]) - if "audio_values" in final_batch: - final_batch["audio_values"] = torch.stack(batch["audio_values"]) + if "input_features" in batch: + final_batch["input_features"] = torch.stack(batch["input_features"]) + + if "input_features_mask" in batch: + final_batch["input_features_mask"] = torch.stack( + batch["input_features_mask"] + ) # Process the labels final_batch["labels"] = self.processing_strategy.process_labels(