wip on multimodal packing support
This commit is contained in:
@@ -35,7 +35,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if self.packing:
|
if self.packing:
|
||||||
return self.__class__.process_rows_packing(
|
return self.__class__.process_rows_packing(
|
||||||
examples, self.processor, self.chat_template, self.max_images, self.sequence_length
|
examples,
|
||||||
|
self.processor,
|
||||||
|
self.chat_template,
|
||||||
|
self.max_images,
|
||||||
|
self.sequence_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
return self.__class__.process_rows(
|
||||||
@@ -43,15 +47,25 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_rows_packing(examples, processor, chat_template, max_images, sequence_length, length_only=False):
|
def process_rows_packing(
|
||||||
|
examples,
|
||||||
|
processor,
|
||||||
|
chat_template,
|
||||||
|
max_images,
|
||||||
|
sequence_length,
|
||||||
|
length_only=False,
|
||||||
|
):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Perform sample packing within a batch
|
# Perform sample packing within a batch
|
||||||
|
|
||||||
if processor.tokenizer.sep_token is None:
|
if processor.tokenizer.sep_token is None:
|
||||||
sep_token = '[SEP]'
|
sep_token = "[SEP]"
|
||||||
processor.tokenizer.add_tokens([sep_token])
|
processor.tokenizer.add_tokens([sep_token])
|
||||||
processor.tokenizer.sep_token = sep_token
|
processor.tokenizer.sep_token = sep_token
|
||||||
sep_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.sep_token)
|
sep_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.tokenizer.sep_token
|
||||||
|
)
|
||||||
pad_token_id = processor.tokenizer.pad_token_id
|
pad_token_id = processor.tokenizer.pad_token_id
|
||||||
|
|
||||||
texts = [
|
texts = [
|
||||||
@@ -62,10 +76,13 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
]
|
]
|
||||||
images = [example["images"] for example in examples]
|
images = [example["images"] for example in examples]
|
||||||
|
|
||||||
|
if max_images > 0:
|
||||||
|
images = [img_batch[:max_images] for img_batch in images]
|
||||||
|
|
||||||
batch = processor(text=texts, images=images, padding=False)
|
batch = processor(text=texts, images=images, padding=False)
|
||||||
|
|
||||||
n_sequence = len(examples)
|
n_sequence = len(examples)
|
||||||
n = 0
|
n_seq_in_batch = 0
|
||||||
pack_len = 0
|
pack_len = 0
|
||||||
features_pack = {}
|
features_pack = {}
|
||||||
packed = {}
|
packed = {}
|
||||||
@@ -75,25 +92,29 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
packed[feature] = []
|
packed[feature] = []
|
||||||
features.remove("input_ids")
|
features.remove("input_ids")
|
||||||
|
|
||||||
for ii in range(n_sequence):
|
for seq_in_batch_id in range(n_sequence):
|
||||||
next_seq_len = len(batch["input_ids"][ii])
|
next_seq_len = len(batch["input_ids"][seq_in_batch_id])
|
||||||
if not pack_len + next_seq_len + 1 < sequence_length:
|
if not pack_len + next_seq_len + 1 < sequence_length:
|
||||||
n += 1
|
n_seq_in_batch += 1
|
||||||
pack_len += next_seq_len + 1
|
pack_len += next_seq_len + 1
|
||||||
features_pack["input_ids"] += batch["input_ids"][ii] + [sep_token_id]
|
features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
|
||||||
|
sep_token_id
|
||||||
|
]
|
||||||
|
|
||||||
'''
|
"""
|
||||||
Do something with attention mask and cross-attention
|
Do something with attention mask and cross-attention
|
||||||
'''
|
"""
|
||||||
|
|
||||||
for feature in features:
|
for feature in features:
|
||||||
features_pack[feature] += batch[feature][ii]
|
features_pack[feature] += batch[feature][seq_in_batch_id]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for _ in range(sequence_length - pack_len):
|
for _ in range(sequence_length - pack_len):
|
||||||
features_pack["input_ids"] += [pad_token_id]
|
features_pack["input_ids"] += [pad_token_id]
|
||||||
|
|
||||||
packed["input_ids"].append(torch.tensor(features_pack["input_ids"].copy()))
|
packed["input_ids"].append(
|
||||||
|
torch.tensor(features_pack["input_ids"].copy())
|
||||||
|
)
|
||||||
|
|
||||||
for feature in features:
|
for feature in features:
|
||||||
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
||||||
@@ -105,11 +126,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
processor.image_token
|
processor.image_token
|
||||||
)
|
)
|
||||||
labels = [pack.clone() for pack in packed["input_ids"]]
|
labels = [pack.clone() for pack in packed["input_ids"]]
|
||||||
for ii , label in enumerate(labels):
|
for label_id, label in enumerate(labels):
|
||||||
labels[ii][label == processor.tokenizer.pad_token_id] = -100 #
|
labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
|
||||||
# Ignore the image token index in the loss computation (model specific)
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
|
||||||
labels[ii][label == image_token_id] = -100
|
labels[label_id][label == image_token_id] = -100
|
||||||
packed["labels"] = labels
|
packed["labels"] = labels
|
||||||
|
|
||||||
if length_only:
|
if length_only:
|
||||||
|
|||||||
Reference in New Issue
Block a user