feat: add audio support for gemma3n

This commit is contained in:
NanoCode012
2025-07-21 19:21:43 +07:00
parent 4890c81c12
commit c9aa8348aa

View File

@@ -84,16 +84,21 @@ class MultiModalChatDataCollator(DataCollatorMixin):
"attention_mask": attention_mask,
}
if "token_type_ids" in final_batch:
if "token_type_ids" in batch:
final_batch["token_type_ids"] = torch.nn.utils.rnn.pad_sequence(
batch["token_type_ids"], batch_first=True, padding_value=0
)
if "pixel_values" in final_batch:
if "pixel_values" in batch:
final_batch["pixel_values"] = torch.stack(batch["pixel_values"])
if "audio_values" in final_batch:
final_batch["audio_values"] = torch.stack(batch["audio_values"])
if "input_features" in batch:
final_batch["input_features"] = torch.stack(batch["input_features"])
if "input_features_mask" in batch:
final_batch["input_features_mask"] = torch.stack(
batch["input_features_mask"]
)
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(