wip on multimodal sample packing support
This commit is contained in:
@@ -20,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
return_tensors: str = "pt"
|
||||
chat_template: Optional[str] = None
|
||||
packing: bool = False
|
||||
sequence_length: Optional[int] = None
|
||||
max_images: int = -1
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
@@ -32,11 +33,91 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||
) -> Dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
if self.packing:
|
||||
return self.__class__.process_rows_packing(
|
||||
examples, self.processor, self.chat_template, self.max_images, self.sequence_length
|
||||
)
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_rows_packing(examples, processor, chat_template, max_images, sequence_length, length_only=False):
|
||||
import torch
|
||||
# Perform sample packing within a batch
|
||||
|
||||
if processor.tokenizer.sep_token is None:
|
||||
sep_token = '[SEP]'
|
||||
processor.tokenizer.add_tokens([sep_token])
|
||||
processor.tokenizer.sep_token = sep_token
|
||||
sep_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.sep_token)
|
||||
pad_token_id = processor.tokenizer.pad_token_id
|
||||
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [example["images"] for example in examples]
|
||||
|
||||
batch = processor(text=texts, images=images, padding=False)
|
||||
|
||||
n_sequence = len(examples)
|
||||
n = 0
|
||||
pack_len = 0
|
||||
features_pack = {}
|
||||
packed = {}
|
||||
features = list[batch.keys()]
|
||||
for feature in features:
|
||||
features_pack[feature] = []
|
||||
packed[feature] = []
|
||||
features.remove("input_ids")
|
||||
|
||||
for ii in range(n_sequence):
|
||||
next_seq_len = len(batch["input_ids"][ii])
|
||||
if not pack_len + next_seq_len + 1 < sequence_length:
|
||||
n += 1
|
||||
pack_len += next_seq_len + 1
|
||||
features_pack["input_ids"] += batch["input_ids"][ii] + [sep_token_id]
|
||||
|
||||
'''
|
||||
Do something with attention mask and cross-attention
|
||||
'''
|
||||
|
||||
for feature in features:
|
||||
features_pack[feature] += batch[feature][ii]
|
||||
|
||||
else:
|
||||
for _ in range(sequence_length - pack_len):
|
||||
features_pack["input_ids"] += [pad_token_id]
|
||||
|
||||
packed["input_ids"].append(torch.tensor(features_pack["input_ids"].copy()))
|
||||
|
||||
for feature in features:
|
||||
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
||||
features_pack[feature] = []
|
||||
|
||||
pack_len = 0
|
||||
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
labels = [pack.clone() for pack in packed["input_ids"]]
|
||||
for ii , label in enumerate(labels):
|
||||
labels[ii][label == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
|
||||
labels[ii][label == image_token_id] = -100
|
||||
packed["labels"] = labels
|
||||
|
||||
if length_only:
|
||||
return {
|
||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||
}
|
||||
return packed
|
||||
|
||||
@staticmethod
|
||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
|
||||
Reference in New Issue
Block a user