Various fixes for VLMs (#3063)
* fix to not use batch feature indexing
* more vlm fixes
* use AutoModelForImageTextToText
* add example yaml and need num2words for chat template
* improve handling of adding image tokens to conversation
* add lfm2-vl support
* update the lfm readme
* fix markdown and add rtol for loss checks
* feat: add smolvlm2 processing strat
* fix: check for causal-conv1d in lfm models
* feat: add docs for lfm2
* feat: add new models and tips to docs
* feat: add smolvlm2 docs and remove extra dep
* chore: update docs
* feat: add video instructions
* chore: cleanup
* chore: comments
* fix: typo
* feat: add usage stats
* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
@@ -13,10 +13,13 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)

## Usage
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal

chat_template: # see in next section
chat_template: # see in next section if specified

# example dataset
datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```

### Voxtral {#sec-voxtral}

::: {.callout-tip}
Please make sure to install the audio libraries via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::

```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```

### Gemma-3 {#sec-gemma-3}

::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```

### SmolVLM2 {#sec-smolvlm2}

::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::

```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```

### LFM2-VL {#sec-lfm2-vl}

::: {.callout-warning}
The `causal-conv1d` package is incompatible with LFM2 models; please uninstall it via `pip3 uninstall -y causal-conv1d`
:::

```yaml
base_model: LiquidAI/LFM2-VL-450M
```

## Dataset Format

For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
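As a minimal sketch of this Message format (the image path and all text here are illustrative placeholders, not from a real dataset), a single row might look like:

```python
# A single dataset row: a conversation whose user turn carries both an
# image reference and a text prompt, followed by the assistant reply.
example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "/path/to/image.jpg"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "A cat sitting on a windowsill."},
            ],
        },
    ]
}

# Each content entry declares its modality via the "type" key.
assert example["messages"][0]["content"][0]["type"] == "image"
```

Note that `content` is a list of typed parts rather than a plain string, which is what lets a single turn mix text with images or other media.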
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.

:::

### Video

::: {.callout-warning}

This is not well tested at the moment. We welcome contributors!

:::

For video loading, you can use the following keys within `content` alongside `"type": "video"`:

- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
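For instance, a user turn that references a video by path could be sketched as follows (the path and question are placeholders for illustration):

```python
# A user message mixing a video reference with a text question,
# following the multi-content Messages convention described above.
video_turn = {
    "role": "user",
    "content": [
        {"type": "video", "path": "/path/to/video.mp4"},
        {"type": "text", "text": "Describe what happens in this clip."},
    ],
}

# The "type" key selects the modality; "path" points at a local file.
assert video_turn["content"][0]["type"] == "video"
```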

### Example

Here is an example of a multi-modal dataset:
58 examples/LiquidAI/README.md Normal file
@@ -0,0 +1,58 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl

[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI has released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.

LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.

This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.

## Getting Started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of PyTorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Run one of the finetuning examples below.

**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```

**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```

### Tips

- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```

- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
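As a hedged sketch of one such multi-content row (the image URL and answer text are made up for illustration), serialized the way JSONL datasets typically store it:

```python
import json

# One training row in the multi-content Messages format; the image URL
# below is a placeholder, not a real dataset entry.
row = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": "https://example.com/cat.png"},
                {"type": "text", "text": "What animal is this?"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
    ]
}

# Datasets of this shape are commonly stored one JSON object per line (JSONL).
line = json.dumps(row)
assert json.loads(line)["messages"][1]["role"] == "assistant"
```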

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M

chunked_cross_entropy: true

chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:
58 examples/LiquidAI/lfm2-vl-lora.yaml Normal file
@@ -0,0 +1,58 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16:
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -1,7 +0,0 @@
# Liquid Foundation Models 2

LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.

```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```
49 examples/smolvlm2/README.md Normal file
@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl

[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) is a family of lightweight, open-source multimodal models from Hugging Face designed to analyze and understand video, image, and text content.

These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.

This guide shows how to fine-tune SmolVLM2 models with Axolotl.

## Getting Started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of PyTorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Install an extra dependency:

```bash
pip3 install num2words==0.5.14
```

3. Run the finetuning example:

```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```

## Tips

- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
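As a rough sketch of what a compatible video row might look like (the clip path and both text strings are placeholders, not real data):

```python
# A video-plus-text conversation row in the multi-content Messages
# format; the video path below is illustrative only.
row = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "video", "path": "/data/clips/demo.mp4"},
                {"type": "text", "text": "Summarize this video."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A short product demo."}],
        },
    ]
}

# The user turn carries the video; the assistant turn is the target text.
assert any(c["type"] == "video" for c in row["messages"][0]["content"])
```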

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
56 examples/smolvlm2/smolvlm2-2B-lora.yaml Normal file
@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16:
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -1,26 +1,13 @@
"""Shared constants for axolotl.loaders module"""

from transformers import (
    Gemma3ForConditionalGeneration,
    Gemma3nForConditionalGeneration,
    Llama4ForConditionalGeneration,
    LlavaForConditionalGeneration,
    Mistral3ForConditionalGeneration,
    MllamaForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
from transformers import AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)

MULTIMODAL_AUTO_MODEL_MAPPING = {
    "mllama": MllamaForConditionalGeneration,
    "llama4": Llama4ForConditionalGeneration,
    "llava": LlavaForConditionalGeneration,
    "qwen2_vl": Qwen2VLForConditionalGeneration,
    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
    "mistral3": Mistral3ForConditionalGeneration,
    "gemma3": Gemma3ForConditionalGeneration,
    "gemma3n": Gemma3nForConditionalGeneration,
}
MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)

MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText

try:
    from transformers import VoxtralForConditionalGeneration
@@ -25,6 +25,7 @@ from peft import (
from torch.distributed import DeviceMesh
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForVision2Seq,
    AwqConfig,
    BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
        self.model_kwargs["use_kernels"] = self.cfg.use_kernels
        self._set_quantization_config()
        self._set_attention_config()
        self._check_model_requirements()

    def _apply_post_model_load_setup(self):
        """Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
        self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
            self.model_config.model_type, AutoModelForVision2Seq
        )
        if isinstance(self.auto_model_loader, str):
            self.auto_model_loader = AutoModelForImageTextToText

    def _set_device_map_config(self):
        """Setup `device_map` according to config"""
@@ -628,6 +632,16 @@ class ModelLoader:
        if self.cfg.low_cpu_mem_usage:
            self.model_kwargs["low_cpu_mem_usage"] = True

    def _check_model_requirements(self):
        if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
            from transformers.utils.import_utils import is_causal_conv1d_available

            if is_causal_conv1d_available():
                raise ImportError(
                    "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
                    "Please uninstall it by running: `pip uninstall -y causal-conv1d`"
                )

    def _configure_zero3_memory_efficient_loading(
        self,
    ) -> HfTrainerDeepSpeedConfig | None:
@@ -6,7 +6,7 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, VoxtralProcessor
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image

from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
                image_key = key
                break

        # if the image key exists, add the image to the first message
        # if the image key exists, add the image to the first user message
        if image_key is not None and processed_example[image_key] is not None:
            # TODO: check if it's normal to be single image only for common datasets
            # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@

        # Look for any image type in the first message
        # some dataset have an {type: "image"} in the first message
        msg_ind_to_add = None
        ind_to_add = None
        first_user_idx = None

        for i, content in enumerate(
            processed_example["messages"][0]["content"]
        ):
        # Usually datasets created with image columns, don't have it in the messages itself
        if content["type"] == "image" and all(
            k not in content for k in ["image", "url", "path", "base64"]
        for msg_idx, msg_content in enumerate(processed_example["messages"]):
            if first_user_idx is None and msg_content["role"] == "user":
                first_user_idx = msg_idx
            for i, content in enumerate(
                processed_example["messages"][msg_idx]["content"]
            ):
                ind_to_add = i
                break
                # Usually datasets created with image columns, don't have it in the messages itself
                if content["type"] == "image" and all(
                    k not in content for k in ["image", "url", "path", "base64"]
                ):
                    msg_ind_to_add = msg_idx
                    ind_to_add = i
                    break

        # If an image type is found, add the image to that index
        if ind_to_add is not None:
            processed_example["messages"][0]["content"][ind_to_add][
                "image"
            ] = image_value
        if ind_to_add is not None and msg_ind_to_add is not None:
            processed_example["messages"][msg_ind_to_add]["content"][
                ind_to_add
            ]["image"] = image_value
        else:
            # if no image type is found, add it to end of the first message
            processed_example["messages"][0]["content"].append(
            # if no image type is found, add it to end of the first user message
            if first_user_idx is None:
                first_user_idx = 0
            processed_example["messages"][first_user_idx]["content"].append(
                {
                    "type": "image",
                    "image": image_value,
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
        return labels


class SmolVLM2ProcessingStrategy(ProcessingStrategy):
    """Processing Strategy class for SmolVLM2"""

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
        self.image_token = "<image>"  # nosec

        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index(self.image_token)
        ]


def get_processing_strategy(
    processor: ProcessorMixin,
    chat_template,
@@ -402,32 +428,43 @@ def get_processing_strategy(
    image_size: int | tuple[int, int] | None = None,
    image_resize_algorithm: Resampling | None = None,
):
    processing_kwargs = {
        "processor": processor,
        "chat_template": chat_template,
        "image_size": image_size,
        "image_resize_algorithm": image_resize_algorithm,
    }

    if chat_template_type in [None, "tokenizer_default"] and hasattr(
        processor.tokenizer, "chat_template"
    ):
        processing_kwargs["chat_template"] = processor.tokenizer.chat_template

    if chat_template_type == "qwen2_vl":
        return Qwen2VLProcessingStrategy(
            processor, chat_template, image_size, image_resize_algorithm
            **processing_kwargs,
        )
    if chat_template_type == "gemma3":
        return Gemma3ProcessingStrategy(
            processor, chat_template, image_size, image_resize_algorithm
            **processing_kwargs,
        )
    if chat_template_type == "gemma3n":
        return Gemma3nProcessingStrategy(
            processor, chat_template, image_size, image_resize_algorithm
        )
    if chat_template_type in [
        "llama3_2_vision",
        "llama4",
        "llava",
        "mistral_v7_tekken",
        "pixtral",
    ]:
        return ProcessingStrategy(
            processor, chat_template, image_size, image_resize_algorithm
            **processing_kwargs,
        )

    if isinstance(processor, VoxtralProcessor):
        return VoxtralProcessingStrategy(
            processor, chat_template, image_size, image_resize_algorithm
            **processing_kwargs,
        )

    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
    if isinstance(processor, SmolVLMProcessor):
        return SmolVLM2ProcessingStrategy(
            **processing_kwargs,
        )

    # llama3_2_vision, llama4, llava
    # mistral_v7_tekken, pixtral, lfm2vl
    return ProcessingStrategy(
        **processing_kwargs,
    )
@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
                images=images,
                return_tensors="pt",
            )
            if hasattr(batch, "to_dict"):
                batch = batch.to_dict()
            else:
                batch = dict(batch)

            # workaround since processor works in batches instead of single examples
            out = {}
            for k, val in batch.items():
                if k in ["pixel_values"]:
                    batch[k] = val.tolist()
                if hasattr(val, "tolist"):
                    out[k] = (
                        val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
                    )
                else:
                    batch[k] = val.squeeze().tolist()
            return batch
                    out[k] = val
            return out

        return self.tokenizer.apply_chat_template(
            conversation,
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            tokenized_prompt["attention_mask"] = [1] * len(input_ids)
        else:
            input_ids = tokenized_res["input_ids"]
            tokenized_prompt = tokenized_res
            tokenized_prompt = dict(tokenized_res)

        if not self.train_on_inputs:
            user_prompt_len = len(prompt_ids)
            if isinstance(prompt_ids, dict):
                user_prompt_len = len(prompt_ids["input_ids"])
            else:
                user_prompt_len = len(prompt_ids)
            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
        else:
            labels = input_ids
@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
        examples = self.processing_strategy(examples)

        # Initialize batch
        batch: dict[str, Any] = {}
        messages = [ex["messages"] for ex in examples]

        # Process each example
        for example in examples:
            # Apply chat template to process the example
            # This method requires transformers>=4.49.0
            result = self.processing_strategy.processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=True,
                return_tensors="pt",
                padding=True,
                return_dict=True,
                chat_template=self.processing_strategy.chat_template,
            )

            # TODO: Check if need handling for len(input_ids) > sequence_len

            # Add the processed tensors to our batch
            for key in result.keys():
                if key not in batch:
                    batch[key] = []

                batch[key].append(result[key].squeeze(0))

        # Pad sequences to the same length
        input_ids = torch.nn.utils.rnn.pad_sequence(
            batch["input_ids"],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        batch = self.processing_strategy.processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=True,
            return_tensors="pt",
            padding=True,
            return_dict=True,
            chat_template=self.processing_strategy.chat_template,
        )

        attention_mask = torch.nn.utils.rnn.pad_sequence(
            batch["attention_mask"], batch_first=True, padding_value=0
        )

        # Create the final batch
        final_batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        for key, val in batch.items():
            if key in ["input_ids", "attention_mask"]:
                continue

            if key in ["token_type_ids", "cross_attention_mask"]:
                final_batch[key] = torch.nn.utils.rnn.pad_sequence(
                    val, batch_first=True, padding_value=0
                )
            else:
                final_batch[key] = torch.stack(val)

        # Process the labels
        final_batch["labels"] = self.processing_strategy.process_labels(
            final_batch["input_ids"]
        )
        batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])

        return final_batch
        return batch
@@ -147,7 +147,11 @@ def require_hopper(test_case):


def check_tensorboard(
    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
    temp_run_dir: str,
    tag: str,
    lt_val: float,
    assertion_err: str,
    rtol: float = 0.02,
) -> None:
    """
    helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
    reader = SummaryReader(event_file)
    df = reader.scalars  # pylint: disable=invalid-name
    df = df[(df.tag == tag)]  # pylint: disable=invalid-name
    lt_val = (1 + rtol) * lt_val
    if "%s" in assertion_err:
        assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
    else: