diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index b62abb54d..eb07fa08d 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -35,7 +35,11 @@ class MultiModalChatDataCollator(DataCollatorMixin): # Handle dict or lists with proper padding and conversion to tensor. if self.packing: return self.__class__.process_rows_packing( - examples, self.processor, self.chat_template, self.max_images, self.sequence_length + examples, + self.processor, + self.chat_template, + self.max_images, + self.sequence_length, ) return self.__class__.process_rows( @@ -43,17 +47,27 @@ class MultiModalChatDataCollator(DataCollatorMixin): ) @staticmethod - def process_rows_packing(examples, processor, chat_template, max_images, sequence_length, length_only=False): + def process_rows_packing( + examples, + processor, + chat_template, + max_images, + sequence_length, + length_only=False, + ): import torch + # Perform sample packing within a batch if processor.tokenizer.sep_token is None: - sep_token = '[SEP]' + sep_token = "[SEP]" processor.tokenizer.add_tokens([sep_token]) processor.tokenizer.sep_token = sep_token - sep_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.sep_token) + sep_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.tokenizer.sep_token + ) pad_token_id = processor.tokenizer.pad_token_id - + texts = [ processor.apply_chat_template( example["messages"], chat_template=chat_template, tokenize=False @@ -62,10 +76,13 @@ class MultiModalChatDataCollator(DataCollatorMixin): ] images = [example["images"] for example in examples] + if max_images > 0: + images = [img_batch[:max_images] for img_batch in images] + batch = processor(text=texts, images=images, padding=False) - + n_sequence = len(examples) - n = 0 + n_seq_in_batch = 0 pack_len = 0 features_pack = {} packed = {} @@ -75,25 +92,29 @@ class MultiModalChatDataCollator(DataCollatorMixin): packed[feature] = [] features.remove("input_ids") - for ii in range(n_sequence): - next_seq_len = len(batch["input_ids"][ii]) + for seq_in_batch_id in range(n_sequence): + next_seq_len = len(batch["input_ids"][seq_in_batch_id]) if not pack_len + next_seq_len + 1 < sequence_length: - n += 1 + n_seq_in_batch += 1 pack_len += next_seq_len + 1 - features_pack["input_ids"] += batch["input_ids"][ii] + [sep_token_id] + features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [ + sep_token_id + ] - ''' + """ Do something with attention mask and cross-attention - ''' + """ for feature in features: - features_pack[feature] += batch[feature][ii] + features_pack[feature] += batch[feature][seq_in_batch_id] else: for _ in range(sequence_length - pack_len): features_pack["input_ids"] += [pad_token_id] - packed["input_ids"].append(torch.tensor(features_pack["input_ids"].copy())) + packed["input_ids"].append( + torch.tensor(features_pack["input_ids"].copy()) + ) for feature in features: packed[feature].append(torch.tensor(features_pack[feature].copy())) @@ -105,11 +126,11 @@ class MultiModalChatDataCollator(DataCollatorMixin): processor.image_token ) labels = [pack.clone() for pack in packed["input_ids"]] - for ii , label in enumerate(labels): - labels[ii][label == processor.tokenizer.pad_token_id] = -100 # - # Ignore the image token index in the loss computation (model specific) - - labels[ii][label == image_token_id] = -100 + for label_id, label in enumerate(labels): + labels[label_id][label == processor.tokenizer.pad_token_id] = -100 # + # Ignore the image token index in the loss computation (model specific) + + labels[label_id][label == image_token_id] = -100 packed["labels"] = labels if length_only: @@ -117,7 +138,7 @@ class MultiModalChatDataCollator(DataCollatorMixin): "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] } return packed - + @staticmethod def process_rows(examples, processor, chat_template, max_images, length_only=False): # HINT: use `_torch_collate_batch` to stack and pad tensors