Various fixes for VLMs (#3063)

* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-08-15 10:52:57 -04:00
committed by GitHub
parent d1de6f5f3d
commit 130ef7c51a
13 changed files with 391 additions and 121 deletions

View File

@@ -13,10 +13,13 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section
chat_template: # see in next section if specified
# example dataset
datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example
Here is an example of a multi-modal dataset:

View File

@@ -0,0 +1,58 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:

View File

@@ -0,0 +1,58 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,7 +0,0 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,26 +1,13 @@
"""Shared constants for axolotl.loaders module"""
from transformers import (
Gemma3ForConditionalGeneration,
Gemma3nForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
from transformers import AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
}
MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
try:
from transformers import VoxtralForConditionalGeneration

View File

@@ -25,6 +25,7 @@ from peft import (
from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()
def _apply_post_model_load_setup(self):
"""Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
def _set_device_map_config(self):
"""Setup `device_map` according to config"""
@@ -628,6 +632,16 @@ class ModelLoader:
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
def _check_model_requirements(self):
if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
from transformers.utils.import_utils import is_causal_conv1d_available
if is_causal_conv1d_available():
raise ImportError(
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
)
def _configure_zero3_memory_efficient_loading(
self,
) -> HfTrainerDeepSpeedConfig | None:

View File

@@ -6,7 +6,7 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, VoxtralProcessor
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
image_key = key
break
# if the image key exists, add the image to the first message
# if the image key exists, add the image to the first user message
if image_key is not None and processed_example[image_key] is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:
# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
msg_ind_to_add = None
ind_to_add = None
first_user_idx = None
for i, content in enumerate(
processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
):
ind_to_add = i
break
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][msg_ind_to_add]["content"][
ind_to_add
]["image"] = image_value
else:
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
# if no image type is found, add it to end of the first user message
if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
{
"type": "image",
"image": image_value,
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
return labels
class SmolVLM2ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for SmolVLM2"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<image>" # nosec
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index(self.image_token)
]
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -402,32 +428,43 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
"image_size": image_size,
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if chat_template_type == "gemma3n":
return Gemma3nProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llama4",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if isinstance(processor, VoxtralProcessor):
return VoxtralProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
if isinstance(processor, SmolVLMProcessor):
return SmolVLM2ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(
**processing_kwargs,
)

View File

@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
images=images,
return_tensors="pt",
)
if hasattr(batch, "to_dict"):
batch = batch.to_dict()
else:
batch = dict(batch)
# workaround since processor works in batches instead of single examples
out = {}
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
if hasattr(val, "tolist"):
out[k] = (
val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
)
else:
batch[k] = val.squeeze().tolist()
return batch
out[k] = val
return out
return self.tokenizer.apply_chat_template(
conversation,
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
tokenized_prompt = dict(tokenized_res)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
if isinstance(prompt_ids, dict):
user_prompt_len = len(prompt_ids["input_ids"])
else:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids

View File

@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
examples = self.processing_strategy(examples)
# Initialize batch
batch: dict[str, Any] = {}
messages = [ex["messages"] for ex in examples]
# Process each example
for example in examples:
# Apply chat template to process the example
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
# TODO: Check if need handling for len(input_ids) > sequence_len
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
batch = self.processing_strategy.processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
for key, val in batch.items():
if key in ["input_ids", "attention_mask"]:
continue
if key in ["token_type_ids", "cross_attention_mask"]:
final_batch[key] = torch.nn.utils.rnn.pad_sequence(
val, batch_first=True, padding_value=0
)
else:
final_batch[key] = torch.stack(val)
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"]
)
batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
return final_batch
return batch

View File

@@ -147,7 +147,11 @@ def require_hopper(test_case):
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
temp_run_dir: str,
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
lt_val = (1 + rtol) * lt_val
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else: