feat: support audio and return pixel values in collator
This commit is contained in:
@@ -50,7 +50,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
# This method requires transformers>=4.49.0
|
# This method requires transformers>=4.49.0
|
||||||
result = self.processing_strategy.processor.apply_chat_template(
|
result = self.processing_strategy.processor.apply_chat_template(
|
||||||
example["messages"],
|
example["messages"],
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=False,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
@@ -84,6 +84,17 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "token_type_ids" in final_batch:
|
||||||
|
final_batch["token_type_ids"] = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
batch["token_type_ids"], batch_first=True, padding_value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
if "pixel_values" in final_batch:
|
||||||
|
final_batch["pixel_values"] = torch.stack(batch["pixel_values"])
|
||||||
|
|
||||||
|
if "audio_values" in final_batch:
|
||||||
|
final_batch["audio_values"] = torch.stack(batch["audio_values"])
|
||||||
|
|
||||||
# Process the labels
|
# Process the labels
|
||||||
final_batch["labels"] = self.processing_strategy.process_labels(
|
final_batch["labels"] = self.processing_strategy.process_labels(
|
||||||
final_batch["input_ids"]
|
final_batch["input_ids"]
|
||||||
|
|||||||
Reference in New Issue
Block a user