fix(vlm): handle legacy conversation data format and check image in data (#2018) [skip ci]

* fix: handle legacy conversation data format and check image in data

* feat: add test for llama vision

* feat: add max_steps to test

* fix: incorrect indent and return preprocess

* feat: use smaller model and dataset

* chore: add extra config for sharegpt dataset
NanoCode012
2024-12-03 12:01:31 +07:00
committed by GitHub
parent d5f58b6509
commit 822c904092
5 changed files with 239 additions and 11 deletions
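
For context, a minimal sketch of the two record shapes this change is about (the field names match the collator code below; the sample values are illustrative only, not taken from any dataset):

    # Legacy / sharegpt-style record: "conversations" with from/value pairs
    legacy_example = {
        "conversations": [
            {"from": "human", "value": "What is in the image?"},
            {"from": "gpt", "value": "A cat on a sofa."},
        ],
        "images": "path/to/image.jpg",  # optional; a path string or a PIL image
    }

    # OpenAI-style record: "messages" with role/content pairs
    openai_example = {
        "messages": [
            {"role": "user", "content": "What is in the image?"},
            {"role": "assistant", "content": "A cat on a sofa."},
        ],
        "images": "path/to/image.jpg",
    }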

View File

@@ -1,8 +1,10 @@
"""
Collators for multi-modal chat messages and packing
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -30,8 +32,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
@@ -46,6 +48,120 @@ class MultiModalChatDataCollator(DataCollatorMixin):
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
list of dicts in OpenAI format, each with a 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
# Preprocess the examples
examples = _preprocess(examples)
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
@@ -53,15 +169,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
images = _process_images(examples, max_images=max_images)
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
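
A minimal sketch of what the two new helpers do with records shaped like the examples above. They are defined as local helpers inside the collator, so this is illustration rather than a public API, and the variable names are assumptions:

    # _preprocess: legacy "conversations" records are rewritten into "messages"
    # ("human" -> "user", "gpt" -> "assistant"); records with neither key raise ValueError.
    processed = _preprocess([legacy_example, openai_example])
    assert processed[0]["messages"][0] == {"role": "user", "content": "What is in the image?"}

    # _process_images: returns None when no example carries images, opens path
    # strings with PIL.Image.open, raises if only some examples have images, and
    # truncates per-example image lists to max_images when max_images > 0.
    assert _process_images([{"messages": []}], max_images=0) is None
    images = _process_images(processed, max_images=2)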

View File

@@ -0,0 +1,116 @@
"""
E2E tests for lora llama vision
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLlamaVision(unittest.TestCase):
"""
Test case for Llama Vision models
"""
@with_temp_dir
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
"type": "chat_template",
"split": "train",
"field_messages": "messages",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
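
The dataset entries in these two tests also show the extra config a sharegpt-style dataset needs (the legacy "conversations" data has to be mapped onto role/content fields), compared with the plain OpenAI-messages case; the values below are copied from the tests above:

    # sharegpt-style / legacy "conversations" data: map the message list and its fields
    legacy_dataset = {
        "path": "LDJnr/Puffin",
        "type": "chat_template",
        "field_messages": "conversations",
        "message_field_role": "from",
        "message_field_content": "value",
    }

    # OpenAI-style "messages" data: only the messages field needs to be named
    openai_dataset = {
        "path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
        "type": "chat_template",
        "split": "train",
        "field_messages": "messages",
    }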

View File

@@ -57,6 +57,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
normalize_config(cfg)

View File

@@ -56,6 +56,7 @@ class TestCustomOptimizers(unittest.TestCase):
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
}
)

View File

@@ -58,6 +58,7 @@ class TestReLoraLlama(unittest.TestCase):
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"max_steps": 5,
"lr_scheduler": "cosine",
}
)