diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index f49e97f37..b9b67f875 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union +from PIL import Image from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy @@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin): ) for example in examples ] - images = [example["images"] for example in examples] + images = [ + Image.open(example["images"]) + if isinstance(example["images"], str) + else example["images"] + for example in examples + ] if max_images > 0: images = [img_batch[:max_images] for img_batch in images]