Various fixes for VLMs (#3063)

* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-08-15 10:52:57 -04:00
committed by GitHub
parent d1de6f5f3d
commit 130ef7c51a
13 changed files with 391 additions and 121 deletions

View File

@@ -13,10 +13,13 @@ format:
- [Pixtral](#sec-pixtral) - [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15) - [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31) - [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3) - [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n) - [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage ## Usage
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal sample_packing: false # not yet supported with multimodal
chat_template: # see in next section chat_template: # see in next section if specified
# example dataset # example dataset
datasets: datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken chat_template: mistral_v7_tekken
``` ```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3} ### Gemma-3 {#sec-gemma-3}
::: {.callout-tip} ::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format ## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
::: :::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example ### Example
Here is an example of a multi-modal dataset: Here is an example of a multi-modal dataset:

View File

@@ -0,0 +1,58 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens: eot_tokens:
- "<|im_end|>" - "<|im_end|>"
datasets: datasets:

View File

@@ -0,0 +1,58 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,7 +0,0 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,26 +1,13 @@
"""Shared constants for axolotl.loaders module""" """Shared constants for axolotl.loaders module"""
from transformers import ( from transformers import AutoModelForImageTextToText
Gemma3ForConditionalGeneration, from transformers.models.auto.modeling_auto import (
Gemma3nForConditionalGeneration, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
) )
MULTIMODAL_AUTO_MODEL_MAPPING = { MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration, MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
}
try: try:
from transformers import VoxtralForConditionalGeneration from transformers import VoxtralForConditionalGeneration

View File

@@ -25,6 +25,7 @@ from peft import (
from torch.distributed import DeviceMesh from torch.distributed import DeviceMesh
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AwqConfig, AwqConfig,
BitsAndBytesConfig, BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config() self._set_quantization_config()
self._set_attention_config() self._set_attention_config()
self._check_model_requirements()
def _apply_post_model_load_setup(self): def _apply_post_model_load_setup(self):
"""Configure the model after it has been loaded.""" """Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq self.model_config.model_type, AutoModelForVision2Seq
) )
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
def _set_device_map_config(self): def _set_device_map_config(self):
"""Setup `device_map` according to config""" """Setup `device_map` according to config"""
@@ -628,6 +632,16 @@ class ModelLoader:
if self.cfg.low_cpu_mem_usage: if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True self.model_kwargs["low_cpu_mem_usage"] = True
def _check_model_requirements(self):
if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
from transformers.utils.import_utils import is_causal_conv1d_available
if is_causal_conv1d_available():
raise ImportError(
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
)
def _configure_zero3_memory_efficient_loading( def _configure_zero3_memory_efficient_loading(
self, self,
) -> HfTrainerDeepSpeedConfig | None: ) -> HfTrainerDeepSpeedConfig | None:

View File

@@ -6,7 +6,7 @@ from typing import Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Resampling from PIL.Image import Resampling
from torch import Tensor, zeros_like from torch import Tensor, zeros_like
from transformers import ProcessorMixin, VoxtralProcessor from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
image_key = key image_key = key
break break
# if the image key exists, add the image to the first message # if the image key exists, add the image to the first user message
if image_key is not None and processed_example[image_key] is not None: if image_key is not None and processed_example[image_key] is not None:
# TODO: check if it's normal to be single image only for common datasets # TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:
# Look for any image type in the first message # Look for any image type in the first message
# some dataset have an {type: "image"} in the first message # some dataset have an {type: "image"} in the first message
msg_ind_to_add = None
ind_to_add = None ind_to_add = None
first_user_idx = None
for i, content in enumerate( for msg_idx, msg_content in enumerate(processed_example["messages"]):
processed_example["messages"][0]["content"] if first_user_idx is None and msg_content["role"] == "user":
): first_user_idx = msg_idx
# Usually datasets created with image columns, don't have it in the messages itself for i, content in enumerate(
if content["type"] == "image" and all( processed_example["messages"][msg_idx]["content"]
k not in content for k in ["image", "url", "path", "base64"]
): ):
ind_to_add = i # Usually datasets created with image columns, don't have it in the messages itself
break if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break
# If an image type is found, add the image to that index # If an image type is found, add the image to that index
if ind_to_add is not None: if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][ processed_example["messages"][msg_ind_to_add]["content"][
"image" ind_to_add
] = image_value ]["image"] = image_value
else: else:
# if no image type is found, add it to end of the first message # if no image type is found, add it to end of the first user message
processed_example["messages"][0]["content"].append( if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
{ {
"type": "image", "type": "image",
"image": image_value, "image": image_value,
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
return labels return labels
class SmolVLM2ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for SmolVLM2"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<image>" # nosec
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index(self.image_token)
]
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -402,32 +428,43 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None, image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None, image_resize_algorithm: Resampling | None = None,
): ):
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
"image_size": image_size,
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl": if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy( return Qwen2VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm **processing_kwargs,
) )
if chat_template_type == "gemma3": if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy( return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm **processing_kwargs,
) )
if chat_template_type == "gemma3n": if chat_template_type == "gemma3n":
return Gemma3nProcessingStrategy( return Gemma3nProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm **processing_kwargs,
)
if chat_template_type in [
"llama3_2_vision",
"llama4",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
) )
if isinstance(processor, VoxtralProcessor): if isinstance(processor, VoxtralProcessor):
return VoxtralProcessingStrategy( return VoxtralProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm **processing_kwargs,
) )
raise ValueError(f"Unsupported chat template type: {chat_template_type}") if isinstance(processor, SmolVLMProcessor):
return SmolVLM2ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(
**processing_kwargs,
)

View File

@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
images=images, images=images,
return_tensors="pt", return_tensors="pt",
) )
if hasattr(batch, "to_dict"):
batch = batch.to_dict()
else:
batch = dict(batch)
# workaround since processor works in batches instead of single examples # workaround since processor works in batches instead of single examples
out = {}
for k, val in batch.items(): for k, val in batch.items():
if k in ["pixel_values"]: if hasattr(val, "tolist"):
batch[k] = val.tolist() out[k] = (
val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
)
else: else:
batch[k] = val.squeeze().tolist() out[k] = val
return batch return out
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
conversation, conversation,
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenized_prompt["attention_mask"] = [1] * len(input_ids) tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else: else:
input_ids = tokenized_res["input_ids"] input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res tokenized_prompt = dict(tokenized_res)
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt_len = len(prompt_ids) if isinstance(prompt_ids, dict):
user_prompt_len = len(prompt_ids["input_ids"])
else:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else: else:
labels = input_ids labels = input_ids

View File

@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch
from torch import Tensor from torch import Tensor
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
examples = self.processing_strategy(examples) examples = self.processing_strategy(examples)
# Initialize batch # Initialize batch
batch: dict[str, Any] = {} messages = [ex["messages"] for ex in examples]
# Process each example batch = self.processing_strategy.processor.apply_chat_template(
for example in examples: messages,
# Apply chat template to process the example add_generation_prompt=False,
# This method requires transformers>=4.49.0 tokenize=True,
result = self.processing_strategy.processor.apply_chat_template( return_tensors="pt",
example["messages"], padding=True,
add_generation_prompt=False, return_dict=True,
tokenize=True, chat_template=self.processing_strategy.chat_template,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
# TODO: Check if need handling for len(input_ids) > sequence_len
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) )
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
for key, val in batch.items():
if key in ["input_ids", "attention_mask"]:
continue
if key in ["token_type_ids", "cross_attention_mask"]:
final_batch[key] = torch.nn.utils.rnn.pad_sequence(
val, batch_first=True, padding_value=0
)
else:
final_batch[key] = torch.stack(val)
# Process the labels # Process the labels
final_batch["labels"] = self.processing_strategy.process_labels( batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
final_batch["input_ids"]
)
return final_batch return batch

View File

@@ -147,7 +147,11 @@ def require_hopper(test_case):
def check_tensorboard( def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str temp_run_dir: str,
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
) -> None: ) -> None:
""" """
helper function to parse and check tensorboard logs helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
reader = SummaryReader(event_file) reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name df = df[(df.tag == tag)] # pylint: disable=invalid-name
lt_val = (1 + rtol) * lt_val
if "%s" in assertion_err: if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1] assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else: else: