Various fixes for VLMs (#3063)

* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Commit: 130ef7c51a (parent d1de6f5f3d)
Author: Wing Lian
Date: 2025-08-15 10:52:57 -04:00
Committed by: GitHub
13 changed files with 391 additions and 121 deletions

View File

@@ -1,26 +1,13 @@
 """Shared constants for axolotl.loaders module"""
 
-from transformers import (
-    Gemma3ForConditionalGeneration,
-    Gemma3nForConditionalGeneration,
-    Llama4ForConditionalGeneration,
-    LlavaForConditionalGeneration,
-    Mistral3ForConditionalGeneration,
-    MllamaForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
+from transformers import AutoModelForImageTextToText
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
 )
 
-MULTIMODAL_AUTO_MODEL_MAPPING = {
-    "mllama": MllamaForConditionalGeneration,
-    "llama4": Llama4ForConditionalGeneration,
-    "llava": LlavaForConditionalGeneration,
-    "qwen2_vl": Qwen2VLForConditionalGeneration,
-    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
-    "mistral3": Mistral3ForConditionalGeneration,
-    "gemma3": Gemma3ForConditionalGeneration,
-    "gemma3n": Gemma3nForConditionalGeneration,
-}
+MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
+MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
 
 try:
     from transformers import VoxtralForConditionalGeneration
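
Note: MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES maps a model_type string to a class-name string, not to a class object, which is why the loader change below adds a string check. A quick sketch of what the new mapping contains:

# Sketch: values pulled from the transformers auto-mapping are class-name
# strings, so MULTIMODAL_AUTO_MODEL_MAPPING now mixes strings with the one
# class ("lfm2-vl" -> AutoModelForImageTextToText) patched in above.
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)

mapping = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
print(mapping["llava"])  # "LlavaForConditionalGeneration" -- a str, not a class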

View File

@@ -25,6 +25,7 @@ from peft import (
 from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForVision2Seq,
     AwqConfig,
     BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
         self.model_kwargs["use_kernels"] = self.cfg.use_kernels
 
         self._set_quantization_config()
         self._set_attention_config()
+        self._check_model_requirements()
 
     def _apply_post_model_load_setup(self):
         """Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@
             self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
                 self.model_config.model_type, AutoModelForVision2Seq
            )
+            if isinstance(self.auto_model_loader, str):
+                self.auto_model_loader = AutoModelForImageTextToText
 
     def _set_device_map_config(self):
         """Setup `device_map` according to config"""
@@ -628,6 +632,16 @@
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True
 
+    def _check_model_requirements(self):
+        if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
+            from transformers.utils.import_utils import is_causal_conv1d_available
+
+            if is_causal_conv1d_available():
+                raise ImportError(
+                    "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
+                    "Please uninstall it by running: `pip uninstall -y causal-conv1d`"
+                )
+
     def _configure_zero3_memory_efficient_loading(
         self,
     ) -> HfTrainerDeepSpeedConfig | None:
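
The guard uses a helper that transformers already ships, so you can run the same check in your own environment before training an LFM2 model:

# Mirrors the new requirement check: LFM2 models are incompatible with an
# installed causal-conv1d package, so bail out early if it is importable.
from transformers.utils.import_utils import is_causal_conv1d_available

if is_causal_conv1d_available():
    raise ImportError("Run `pip uninstall -y causal-conv1d` before training LFM2 models.")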

View File

@@ -6,7 +6,7 @@ from typing import Optional
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
 from torch import Tensor, zeros_like
-from transformers import ProcessorMixin, VoxtralProcessor
+from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
 from transformers.image_utils import load_image
 
 from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
                 image_key = key
                 break
 
-        # if the image key exists, add the image to the first message
+        # if the image key exists, add the image to the first user message
         if image_key is not None and processed_example[image_key] is not None:
             # TODO: check if it's normal to be single image only for common datasets
             # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@
                 # Look for any image type in the first message
                 # some dataset have an {type: "image"} in the first message
+                msg_ind_to_add = None
                 ind_to_add = None
+                first_user_idx = None
-                for i, content in enumerate(
-                    processed_example["messages"][0]["content"]
-                ):
-                    # Usually datasets created with image columns, don't have it in the messages itself
-                    if content["type"] == "image" and all(
-                        k not in content for k in ["image", "url", "path", "base64"]
+                for msg_idx, msg_content in enumerate(processed_example["messages"]):
+                    if first_user_idx is None and msg_content["role"] == "user":
+                        first_user_idx = msg_idx
+                    for i, content in enumerate(
+                        processed_example["messages"][msg_idx]["content"]
                     ):
-                        ind_to_add = i
-                        break
+                        # Usually datasets created with image columns, don't have it in the messages itself
+                        if content["type"] == "image" and all(
+                            k not in content for k in ["image", "url", "path", "base64"]
+                        ):
+                            msg_ind_to_add = msg_idx
+                            ind_to_add = i
+                            break
 
                 # If an image type is found, add the image to that index
-                if ind_to_add is not None:
-                    processed_example["messages"][0]["content"][ind_to_add][
-                        "image"
-                    ] = image_value
+                if ind_to_add is not None and msg_ind_to_add is not None:
+                    processed_example["messages"][msg_ind_to_add]["content"][
+                        ind_to_add
+                    ]["image"] = image_value
                 else:
-                    # if no image type is found, add it to end of the first message
-                    processed_example["messages"][0]["content"].append(
+                    # if no image type is found, add it to end of the first user message
+                    if first_user_idx is None:
+                        first_user_idx = 0
+                    processed_example["messages"][first_user_idx]["content"].append(
                         {
                             "type": "image",
                             "image": image_value,
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
         return labels
 
 
+class SmolVLM2ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for SmolVLM2"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<image>"  # nosec
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index(self.image_token)
+        ]
+
+
 def get_processing_strategy(
     processor: ProcessorMixin,
     chat_template,
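
The <image> token-id lookup above works because tokenizers keep additional_special_tokens and additional_special_tokens_ids as parallel lists. For example (repo id is an assumption for illustration):

# Assumed repo id, shown only to illustrate the parallel-list lookup.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
tok = processor.tokenizer
image_token_id = tok.additional_special_tokens_ids[
    tok.additional_special_tokens.index("<image>")
]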
@@ -402,32 +428,43 @@ def get_processing_strategy(
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
 ):
+    processing_kwargs = {
+        "processor": processor,
+        "chat_template": chat_template,
+        "image_size": image_size,
+        "image_resize_algorithm": image_resize_algorithm,
+    }
+
+    if chat_template_type in [None, "tokenizer_default"] and hasattr(
+        processor.tokenizer, "chat_template"
+    ):
+        processing_kwargs["chat_template"] = processor.tokenizer.chat_template
+
     if chat_template_type == "qwen2_vl":
         return Qwen2VLProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3n":
         return Gemma3nProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )
-    if chat_template_type in [
-        "llama3_2_vision",
-        "llama4",
-        "llava",
-        "mistral_v7_tekken",
-        "pixtral",
-    ]:
-        return ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if isinstance(processor, VoxtralProcessor):
         return VoxtralProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
        )
-    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
+    if isinstance(processor, SmolVLMProcessor):
+        return SmolVLM2ProcessingStrategy(
+            **processing_kwargs,
+        )
+
+    # llama3_2_vision, llama4, llava
+    # mistral_v7_tekken, pixtral, lfm2vl
+    return ProcessingStrategy(
+        **processing_kwargs,
+    )
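
Behavioral note on this refactor: unsupported template types no longer raise a ValueError; they fall through to the base ProcessingStrategy. The tokenizer-default fallback can also be read in isolation (hypothetical helper, same logic as the diff):

# Hypothetical helper restating the new default: with chat_template_type
# unset (or "tokenizer_default"), reuse the template bundled with the
# processor's tokenizer instead of requiring an explicit one.
def resolve_chat_template(chat_template_type, processor, chat_template=None):
    if chat_template_type in [None, "tokenizer_default"] and hasattr(
        processor.tokenizer, "chat_template"
    ):
        return processor.tokenizer.chat_template
    return chat_template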

View File

@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
                 images=images,
                 return_tensors="pt",
             )
+            if hasattr(batch, "to_dict"):
+                batch = batch.to_dict()
+            else:
+                batch = dict(batch)
+
+            # workaround since processor works in batches instead of single examples
+            out = {}
             for k, val in batch.items():
-                if k in ["pixel_values"]:
-                    batch[k] = val.tolist()
+                if hasattr(val, "tolist"):
+                    out[k] = (
+                        val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
+                    )
                 else:
-                    batch[k] = val.squeeze().tolist()
-            return batch
+                    out[k] = val
+            return out
 
         return self.tokenizer.apply_chat_template(
             conversation,
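
The switch from squeeze() to squeeze(0) is deliberate: squeeze() drops every size-1 dimension, so a length-1 sequence would collapse to a scalar, while squeeze(0) only removes the batch dimension:

import torch

row = torch.tensor([[42]])   # processor output for one example: shape [1, 1]
print(row.squeeze().shape)   # torch.Size([])  -- sequence dimension lost
print(row.squeeze(0).shape)  # torch.Size([1]) -- batch dim dropped, sequence kept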
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             tokenized_prompt["attention_mask"] = [1] * len(input_ids)
         else:
             input_ids = tokenized_res["input_ids"]
-            tokenized_prompt = tokenized_res
+            tokenized_prompt = dict(tokenized_res)
 
         if not self.train_on_inputs:
-            user_prompt_len = len(prompt_ids)
+            if isinstance(prompt_ids, dict):
+                user_prompt_len = len(prompt_ids["input_ids"])
+            else:
+                user_prompt_len = len(prompt_ids)
             labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
         else:
             labels = input_ids
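
A toy illustration of the masking fix: multimodal processors can return prompt_ids as a dict rather than a bare id list, and the label mask must measure the ids, not the dict's key count:

# Toy values; mirrors the branch added above.
prompt_ids = {"input_ids": [1, 2, 3, 4]}
input_ids = [1, 2, 3, 4, 9, 9]

if isinstance(prompt_ids, dict):
    user_prompt_len = len(prompt_ids["input_ids"])
else:
    user_prompt_len = len(prompt_ids)

labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
assert labels == [-100, -100, -100, -100, 9, 9]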

View File

@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
-import torch
 from torch import Tensor
 from transformers import PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         examples = self.processing_strategy(examples)
 
-        # Initialize batch
-        batch: dict[str, Any] = {}
+        messages = [ex["messages"] for ex in examples]
 
-        # Process each example
-        for example in examples:
-            # Apply chat template to process the example
-            # This method requires transformers>=4.49.0
-            result = self.processing_strategy.processor.apply_chat_template(
-                example["messages"],
-                add_generation_prompt=False,
-                tokenize=True,
-                return_tensors="pt",
-                padding=True,
-                return_dict=True,
-                chat_template=self.processing_strategy.chat_template,
-            )
-
-            # TODO: Check if need handling for len(input_ids) > sequence_len
-
-            # Add the processed tensors to our batch
-            for key in result.keys():
-                if key not in batch:
-                    batch[key] = []
-                batch[key].append(result[key].squeeze(0))
-
-        # Pad sequences to the same length
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            batch["input_ids"],
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
+        batch = self.processing_strategy.processor.apply_chat_template(
+            messages,
+            add_generation_prompt=False,
+            tokenize=True,
+            return_tensors="pt",
+            padding=True,
+            return_dict=True,
+            chat_template=self.processing_strategy.chat_template,
         )
-        attention_mask = torch.nn.utils.rnn.pad_sequence(
-            batch["attention_mask"], batch_first=True, padding_value=0
-        )
-
-        # Create the final batch
-        final_batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-        for key, val in batch.items():
-            if key in ["input_ids", "attention_mask"]:
-                continue
-            if key in ["token_type_ids", "cross_attention_mask"]:
-                final_batch[key] = torch.nn.utils.rnn.pad_sequence(
-                    val, batch_first=True, padding_value=0
-                )
-            else:
-                final_batch[key] = torch.stack(val)
-
-        # Process the labels
-        final_batch["labels"] = self.processing_strategy.process_labels(
-            final_batch["input_ids"]
-        )
+        batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
 
-        return final_batch
+        return batch
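
The collator now defers padding to the processor by passing the whole list of conversations at once (the batched apply_chat_template call requires transformers>=4.49.0, per the removed comment). A minimal sketch of the resulting collate path; the label-masking rule at the end is an assumption, not the exact axolotl process_labels implementation:

# Hedged sketch: one batched chat-template call replaces the per-example
# loop plus the manual pad_sequence/stack bookkeeping deleted above.
def collate(processor, examples, chat_template=None):
    messages = [ex["messages"] for ex in examples]
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_tensors="pt",
        padding=True,
        return_dict=True,
        chat_template=chat_template,
    )
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # assumed masking rule
    batch["labels"] = labels
    return batch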