Various fixes for VLMs (#3063)
* fix to not use batch feature indexing
* more vlm fixes
* use AutoModelForImageTextToText
* add example yaml and need num2words for chat template
* improve handling of adding image tokens to conversation
* add lfm2-vl support
* update the lfm readme
* fix markdown and add rtol for loss checks
* feat: add smolvlm2 processing strat
* fix: check for causal-conv1d in lfm models
* feat: add docs for lfm2
* feat: add new models and tips to docs
* feat: add smolvlm2 docs and remove extra dep
* chore: update docs
* feat: add video instructions
* chore: cleanup
* chore: comments
* fix: typo
* feat: add usage stats
* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
@@ -1,26 +1,13 @@
 """Shared constants for axolotl.loaders module"""

-from transformers import (
-    Gemma3ForConditionalGeneration,
-    Gemma3nForConditionalGeneration,
-    Llama4ForConditionalGeneration,
-    LlavaForConditionalGeneration,
-    Mistral3ForConditionalGeneration,
-    MllamaForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
-)
+from transformers import AutoModelForImageTextToText
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+)

-MULTIMODAL_AUTO_MODEL_MAPPING = {
-    "mllama": MllamaForConditionalGeneration,
-    "llama4": Llama4ForConditionalGeneration,
-    "llava": LlavaForConditionalGeneration,
-    "qwen2_vl": Qwen2VLForConditionalGeneration,
-    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
-    "mistral3": Mistral3ForConditionalGeneration,
-    "gemma3": Gemma3ForConditionalGeneration,
-    "gemma3n": Gemma3nForConditionalGeneration,
-}
+MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
+
+MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText

 try:
     from transformers import VoxtralForConditionalGeneration
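Note (not part of the diff): `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES` maps a `model_type` string to a *class-name string*, not a class object, so converting it with `dict()` leaves string values in the mapping. A minimal sketch of what the new constant holds:

```python
# Illustrative sketch only: values in the transformers mapping are class-name
# strings, e.g. {"llava": "LlavaForConditionalGeneration", ...}.
from transformers import AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)

mapping = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
mapping["lfm2-vl"] = AutoModelForImageTextToText

print(type(mapping.get("llava")))  # <class 'str'>, hence the str check later
```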
@@ -25,6 +25,7 @@ from peft import (
 from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForVision2Seq,
     AwqConfig,
     BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
         self.model_kwargs["use_kernels"] = self.cfg.use_kernels
         self._set_quantization_config()
         self._set_attention_config()
+        self._check_model_requirements()

     def _apply_post_model_load_setup(self):
         """Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
             self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
                 self.model_config.model_type, AutoModelForVision2Seq
             )
+            if isinstance(self.auto_model_loader, str):
+                self.auto_model_loader = AutoModelForImageTextToText

     def _set_device_map_config(self):
         """Setup `device_map` according to config"""
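Note (illustrative, not part of the diff): the resolution logic above reduces to a small function. `resolve_loader` is a hypothetical name; entries copied from the transformers mapping are class-name strings, so any `str` hit is routed to `AutoModelForImageTextToText`, and unknown model types fall back to `AutoModelForVision2Seq`:

```python
# Hypothetical sketch of the loader resolution in _set_auto_model_loader.
from transformers import AutoModelForImageTextToText, AutoModelForVision2Seq


def resolve_loader(model_type: str, mapping: dict):
    loader = mapping.get(model_type, AutoModelForVision2Seq)
    if isinstance(loader, str):
        # Mapping entries from transformers are class-name strings.
        loader = AutoModelForImageTextToText
    return loader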
@@ -628,6 +632,16 @@ class ModelLoader:
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True

+    def _check_model_requirements(self):
+        if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
+            from transformers.utils.import_utils import is_causal_conv1d_available
+
+            if is_causal_conv1d_available():
+                raise ImportError(
+                    "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
+                    "Please uninstall it by running: `pip uninstall -y causal-conv1d`"
+                )
+
     def _configure_zero3_memory_efficient_loading(
         self,
     ) -> HfTrainerDeepSpeedConfig | None:
@@ -6,7 +6,7 @@ from typing import Optional
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
 from torch import Tensor, zeros_like
-from transformers import ProcessorMixin, VoxtralProcessor
+from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
 from transformers.image_utils import load_image

 from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
                 image_key = key
                 break

-        # if the image key exists, add the image to the first message
+        # if the image key exists, add the image to the first user message
        if image_key is not None and processed_example[image_key] is not None:
            # TODO: check if it's normal to be single image only for common datasets
            # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:

-            # Look for any image type in the first message
+            # some dataset have an {type: "image"} in the first message
+            msg_ind_to_add = None
             ind_to_add = None
+            first_user_idx = None

-            for i, content in enumerate(
-                processed_example["messages"][0]["content"]
-            ):
-                # Usually datasets created with image columns, don't have it in the messages itself
-                if content["type"] == "image" and all(
-                    k not in content for k in ["image", "url", "path", "base64"]
-                ):
-                    ind_to_add = i
-                    break
+            for msg_idx, msg_content in enumerate(processed_example["messages"]):
+                if first_user_idx is None and msg_content["role"] == "user":
+                    first_user_idx = msg_idx
+                for i, content in enumerate(
+                    processed_example["messages"][msg_idx]["content"]
+                ):
+                    # Usually datasets created with image columns, don't have it in the messages itself
+                    if content["type"] == "image" and all(
+                        k not in content for k in ["image", "url", "path", "base64"]
+                    ):
+                        msg_ind_to_add = msg_idx
+                        ind_to_add = i
+                        break

             # If an image type is found, add the image to that index
-            if ind_to_add is not None:
-                processed_example["messages"][0]["content"][ind_to_add][
-                    "image"
-                ] = image_value
+            if ind_to_add is not None and msg_ind_to_add is not None:
+                processed_example["messages"][msg_ind_to_add]["content"][
+                    ind_to_add
+                ]["image"] = image_value
             else:
-                # if no image type is found, add it to end of the first message
-                processed_example["messages"][0]["content"].append(
+                # if no image type is found, add it to end of the first user message
+                if first_user_idx is None:
+                    first_user_idx = 0
+                processed_example["messages"][first_user_idx]["content"].append(
                     {
                         "type": "image",
                         "image": image_value,
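Note (illustrative, sample data invented): the new search walks every message rather than only `messages[0]`, remembers the first `"user"` turn, and fills the first empty `{"type": "image"}` placeholder it finds; if no placeholder exists, the image is appended to the first user message instead of whatever message happens to be first (e.g. a system prompt):

```python
# Toy conversation: the placeholder in the user turn has no
# "image"/"url"/"path"/"base64" key, so msg_ind_to_add=1, ind_to_add=0
# and the loaded image is attached there, not to the system message.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You describe images."}]},
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is this?"}]},
]
```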
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
         return labels


+class SmolVLM2ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for SmolVLM2"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<image>"  # nosec
+
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index(self.image_token)
+        ]
+
+
 def get_processing_strategy(
     processor: ProcessorMixin,
     chat_template,
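Note (illustrative, checkpoint name is an example): SmolVLM registers `<image>` as an additional special token, so its id can be recovered by position from the tokenizer instead of being hard-coded, which is what the `__init__` above does:

```python
# Rough standalone equivalent of the token-id lookup in
# SmolVLM2ProcessingStrategy.__init__.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
tok = processor.tokenizer
image_token_id = tok.additional_special_tokens_ids[
    tok.additional_special_tokens.index("<image>")
]
```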
@@ -402,32 +428,43 @@ def get_processing_strategy(
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
 ):
+    processing_kwargs = {
+        "processor": processor,
+        "chat_template": chat_template,
+        "image_size": image_size,
+        "image_resize_algorithm": image_resize_algorithm,
+    }
+
+    if chat_template_type in [None, "tokenizer_default"] and hasattr(
+        processor.tokenizer, "chat_template"
+    ):
+        processing_kwargs["chat_template"] = processor.tokenizer.chat_template

     if chat_template_type == "qwen2_vl":
         return Qwen2VLProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3n":
         return Gemma3nProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
-    if chat_template_type in [
-        "llama3_2_vision",
-        "llama4",
-        "llava",
-        "mistral_v7_tekken",
-        "pixtral",
-    ]:
-        return ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )

     if isinstance(processor, VoxtralProcessor):
         return VoxtralProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )

-    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
+    if isinstance(processor, SmolVLMProcessor):
+        return SmolVLM2ProcessingStrategy(
+            **processing_kwargs,
+        )
+
+    # llama3_2_vision, llama4, llava
+    # mistral_v7_tekken, pixtral, lfm2vl
+    return ProcessingStrategy(
+        **processing_kwargs,
+    )
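Note (hypothetical usage, values invented): since every strategy now shares the same constructor signature, dispatch reduces to unpacking one kwargs dict, and previously unmatched template types fall through to the base `ProcessingStrategy` instead of raising:

```python
# Sketch only; processor and chat_template come from your own setup.
strategy = get_processing_strategy(
    processor=processor,
    chat_template=None,
    chat_template_type="llava",  # no longer raises; hits the base fallback
    image_size=(512, 512),
    image_resize_algorithm=None,
)
```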
@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
                 images=images,
                 return_tensors="pt",
             )
+            if hasattr(batch, "to_dict"):
+                batch = batch.to_dict()
+            else:
+                batch = dict(batch)

             # workaround since processor works in batches instead of single examples
+            out = {}
             for k, val in batch.items():
-                if k in ["pixel_values"]:
-                    batch[k] = val.tolist()
+                if hasattr(val, "tolist"):
+                    out[k] = (
+                        val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
+                    )
                 else:
-                    batch[k] = val.squeeze().tolist()
-            return batch
+                    out[k] = val
+            return out

         return self.tokenizer.apply_chat_template(
             conversation,
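Note (toy illustration, data fabricated): the processor returns tensors with a leading batch dimension of 1 for a single example, so everything except `pixel_values` is unbatched with `squeeze(0)` before `tolist()`; non-tensor values pass through untouched:

```python
import torch

batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),  # (1, seq_len)
    "pixel_values": torch.rand(1, 3, 8, 8),  # kept batched
}
out = {
    k: (v.tolist() if k == "pixel_values" else v.squeeze(0).tolist())
    for k, v in batch.items()
}
print(len(out["input_ids"]))  # 3 -- the leading batch dim is gone
```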
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             tokenized_prompt["attention_mask"] = [1] * len(input_ids)
         else:
             input_ids = tokenized_res["input_ids"]
-            tokenized_prompt = tokenized_res
+            tokenized_prompt = dict(tokenized_res)

         if not self.train_on_inputs:
-            user_prompt_len = len(prompt_ids)
+            if isinstance(prompt_ids, dict):
+                user_prompt_len = len(prompt_ids["input_ids"])
+            else:
+                user_prompt_len = len(prompt_ids)
             labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
         else:
             labels = input_ids
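Note (minimal sketch with invented numbers): when `prompt_ids` is a processor dict rather than a flat id list, its length must be read from `input_ids`; either way, the first `user_prompt_len` labels are set to -100 so the loss ignores the prompt tokens:

```python
input_ids = [10, 11, 12, 13, 14]
user_prompt_len = 3  # length of the user prompt prefix
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
assert labels == [-100, -100, -100, 13, 14]
```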
@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
 from dataclasses import dataclass
 from typing import Any, Optional, Union

-import torch
 from torch import Tensor
 from transformers import PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         examples = self.processing_strategy(examples)

-        # Initialize batch
-        batch: dict[str, Any] = {}
+        messages = [ex["messages"] for ex in examples]

-        # Process each example
-        for example in examples:
-            # Apply chat template to process the example
-            # This method requires transformers>=4.49.0
-            result = self.processing_strategy.processor.apply_chat_template(
-                example["messages"],
-                add_generation_prompt=False,
-                tokenize=True,
-                return_tensors="pt",
-                padding=True,
-                return_dict=True,
-                chat_template=self.processing_strategy.chat_template,
-            )
-
-            # TODO: Check if need handling for len(input_ids) > sequence_len
-
-            # Add the processed tensors to our batch
-            for key in result.keys():
-                if key not in batch:
-                    batch[key] = []
-
-                batch[key].append(result[key].squeeze(0))
-
-        # Pad sequences to the same length
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            batch["input_ids"],
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
+        batch = self.processing_strategy.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=False,
+            tokenize=True,
+            return_tensors="pt",
+            padding=True,
+            return_dict=True,
+            chat_template=self.processing_strategy.chat_template,
         )

-        attention_mask = torch.nn.utils.rnn.pad_sequence(
-            batch["attention_mask"], batch_first=True, padding_value=0
-        )
-
-        # Create the final batch
-        final_batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-
-        for key, val in batch.items():
-            if key in ["input_ids", "attention_mask"]:
-                continue
-
-            if key in ["token_type_ids", "cross_attention_mask"]:
-                final_batch[key] = torch.nn.utils.rnn.pad_sequence(
-                    val, batch_first=True, padding_value=0
-                )
-            else:
-                final_batch[key] = torch.stack(val)
-
-        # Process the labels
-        final_batch["labels"] = self.processing_strategy.process_labels(
-            final_batch["input_ids"]
-        )
+        batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])

-        return final_batch
+        return batch
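Note (hedged sketch, checkpoint name is an example): `apply_chat_template` accepts a list of conversations, so a single call tokenizes, pads, and tensorizes the whole batch that the old code assembled example by example with `pad_sequence` and `torch.stack`:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
    [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
    [{"role": "user", "content": [{"type": "text", "text": "Hello there"}]}],
]
batch = processor.apply_chat_template(
    messages,
    add_generation_prompt=False,
    tokenize=True,
    return_tensors="pt",
    padding=True,
    return_dict=True,
)
print(batch["input_ids"].shape)  # (2, padded_seq_len)
```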