wip on multimodal packing support

This commit is contained in:
sunny
2024-10-04 15:08:36 -04:00
parent 08143c7b0d
commit cdd8be7097

View File

@@ -35,7 +35,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
# Handle dict or lists with proper padding and conversion to tensor. # Handle dict or lists with proper padding and conversion to tensor.
if self.packing: if self.packing:
return self.__class__.process_rows_packing( return self.__class__.process_rows_packing(
examples, self.processor, self.chat_template, self.max_images, self.sequence_length examples,
self.processor,
self.chat_template,
self.max_images,
self.sequence_length,
) )
return self.__class__.process_rows( return self.__class__.process_rows(
@@ -43,15 +47,25 @@ class MultiModalChatDataCollator(DataCollatorMixin):
) )
@staticmethod @staticmethod
def process_rows_packing(examples, processor, chat_template, max_images, sequence_length, length_only=False): def process_rows_packing(
examples,
processor,
chat_template,
max_images,
sequence_length,
length_only=False,
):
import torch import torch
# Perform sample packing within a batch # Perform sample packing within a batch
if processor.tokenizer.sep_token is None: if processor.tokenizer.sep_token is None:
sep_token = '[SEP]' sep_token = "[SEP]"
processor.tokenizer.add_tokens([sep_token]) processor.tokenizer.add_tokens([sep_token])
processor.tokenizer.sep_token = sep_token processor.tokenizer.sep_token = sep_token
sep_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.sep_token) sep_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.sep_token
)
pad_token_id = processor.tokenizer.pad_token_id pad_token_id = processor.tokenizer.pad_token_id
texts = [ texts = [
@@ -62,10 +76,13 @@ class MultiModalChatDataCollator(DataCollatorMixin):
] ]
images = [example["images"] for example in examples] images = [example["images"] for example in examples]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
batch = processor(text=texts, images=images, padding=False) batch = processor(text=texts, images=images, padding=False)
n_sequence = len(examples) n_sequence = len(examples)
n = 0 n_seq_in_batch = 0
pack_len = 0 pack_len = 0
features_pack = {} features_pack = {}
packed = {} packed = {}
@@ -75,25 +92,29 @@ class MultiModalChatDataCollator(DataCollatorMixin):
packed[feature] = [] packed[feature] = []
features.remove("input_ids") features.remove("input_ids")
for ii in range(n_sequence): for seq_in_batch_id in range(n_sequence):
next_seq_len = len(batch["input_ids"][ii]) next_seq_len = len(batch["input_ids"][seq_in_batch_id])
if not pack_len + next_seq_len + 1 < sequence_length: if not pack_len + next_seq_len + 1 < sequence_length:
n += 1 n_seq_in_batch += 1
pack_len += next_seq_len + 1 pack_len += next_seq_len + 1
features_pack["input_ids"] += batch["input_ids"][ii] + [sep_token_id] features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
sep_token_id
]
''' """
Do something with attention mask and cross-attention Do something with attention mask and cross-attention
''' """
for feature in features: for feature in features:
features_pack[feature] += batch[feature][ii] features_pack[feature] += batch[feature][seq_in_batch_id]
else: else:
for _ in range(sequence_length - pack_len): for _ in range(sequence_length - pack_len):
features_pack["input_ids"] += [pad_token_id] features_pack["input_ids"] += [pad_token_id]
packed["input_ids"].append(torch.tensor(features_pack["input_ids"].copy())) packed["input_ids"].append(
torch.tensor(features_pack["input_ids"].copy())
)
for feature in features: for feature in features:
packed[feature].append(torch.tensor(features_pack[feature].copy())) packed[feature].append(torch.tensor(features_pack[feature].copy()))
@@ -105,11 +126,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
processor.image_token processor.image_token
) )
labels = [pack.clone() for pack in packed["input_ids"]] labels = [pack.clone() for pack in packed["input_ids"]]
for ii , label in enumerate(labels): for label_id, label in enumerate(labels):
labels[ii][label == processor.tokenizer.pad_token_id] = -100 # labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific) # Ignore the image token index in the loss computation (model specific)
labels[ii][label == image_token_id] = -100 labels[label_id][label == image_token_id] = -100
packed["labels"] = labels packed["labels"] = labels
if length_only: if length_only: