Feat: Rework multimodal support (mllama, llava, pixtral, qwen2, qwen25, gemma3, mistral3) (#2435)
This commit is contained in:
@@ -133,6 +133,7 @@ quartodoc:
|
|||||||
- utils.schemas.datasets
|
- utils.schemas.datasets
|
||||||
- utils.schemas.peft
|
- utils.schemas.peft
|
||||||
- utils.schemas.trl
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.multimodal
|
||||||
- utils.schemas.integrations
|
- utils.schemas.integrations
|
||||||
- utils.schemas.enums
|
- utils.schemas.enums
|
||||||
- utils.schemas.utils
|
- utils.schemas.utils
|
||||||
|
|||||||
@@ -586,6 +586,14 @@ resume_from_checkpoint:
|
|||||||
# Be careful with this being turned on between different models.
|
# Be careful with this being turned on between different models.
|
||||||
auto_resume_from_checkpoints: false
|
auto_resume_from_checkpoints: false
|
||||||
|
|
||||||
|
## Multimodal section
|
||||||
|
# int | tuple[int, int] | None. Size to resize images to, width x height.
|
||||||
|
# Will read from model/processor config if not set.
|
||||||
|
image_size:
|
||||||
|
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||||
|
image_resize_algorithm: 'bilinear'
|
||||||
|
## End of multimodal section
|
||||||
|
|
||||||
# Don't mess with this, it's here for accelerate and torchrun
|
# Don't mess with this, it's here for accelerate and torchrun
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,171 @@
|
|||||||
# MultiModal / Vision Language Models (BETA)
|
---
|
||||||
|
title: MultiModal / Vision Language Models (BETA)
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 3
|
||||||
|
---
|
||||||
|
|
||||||
### Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- Mllama, i.e. llama with vision models
|
- [Mllama](#sec-mllama)
|
||||||
|
- [Pixtral](#sec-pixtral)
|
||||||
|
- [Llava-1.5](#sec-llava-15)
|
||||||
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
|
- [Gemma-3](#sec-gemma-3)
|
||||||
|
- [Qwen2-VL](#sec-qwen2-vl)
|
||||||
|
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||||
|
|
||||||
### Usage
|
## Usage
|
||||||
|
|
||||||
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
Multimodal support is limited and doesn't have full feature parity.
|
||||||
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
|
||||||
|
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
skip_prepare_dataset: true
|
|
||||||
|
|
||||||
chat_template: llama3_2_vision
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||||
|
sample_packing: false # not yet supported with multimodal
|
||||||
|
|
||||||
|
chat_template: # see in next section
|
||||||
|
|
||||||
|
# example dataset
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
# only finetune the Language model, leave the vision model and vision tower frozen
|
# (optional) if doing lora, only finetune the Language model,
|
||||||
|
# leave the vision model and vision tower frozen
|
||||||
|
# load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# (optional) if you want to resize images to a set size
|
||||||
|
image_size: 512
|
||||||
|
image_resize_algorithm: bilinear
|
||||||
|
```
|
||||||
|
|
||||||
|
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||||
|
:::
|
||||||
|
|
||||||
|
### Mllama {#sec-mllama}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
|
||||||
|
chat_template: llama3_2_vision
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pixtral {#sec-pixtral}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Pixtral-12B-2409
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llava-1.5 {#sec-llava-15}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
The Gemma3-1B model is a text-only model, so please train it as a regular text model.
|
||||||
|
:::
|
||||||
|
|
||||||
|
For multi-modal 4B/12B/27B models, use the following config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2-VL {#sec-qwen2-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2.5-VL {#sec-qwen25-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl # same as qwen2-vl
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||||
|
|
||||||
|
- A message is a dict with `role` and `content` keys.
|
||||||
|
- `role` can be `system`, `user`, `assistant`, etc.
|
||||||
|
- `content` is a list of dicts, each with a `type` key and one of `text`, `image`, `path`, `url`, or `base64`.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
For backwards compatibility:
|
||||||
|
|
||||||
|
- If the dataset has an `images` or `image` column of `list[Image]`, the first image will be appended to the first message's `content` list as `{"type": "image", "image": ...}`. However, if that `content` already contains a `{"type": "image"}` entry without an `image` key, the image will be set on that entry instead.
|
||||||
|
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
||||||
|
|
||||||
|
- `"path": "/path/to/image.jpg"`
|
||||||
|
- `"url": "https://example.com/image.jpg"`
|
||||||
|
- `"base64": "..."`
|
||||||
|
- `"image": PIL.Image`
|
||||||
|
:::
|
||||||
|
|
||||||
|
Here is an example of a multi-modal dataset:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are a helpful assistant."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image in detail."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The image is a bee."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
```
|
```
|
||||||
|
|||||||
63
examples/gemma3/gemma-3-4b-lora.yml
Normal file
63
examples/gemma3/gemma-3-4b-lora.yml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
63
examples/llava/lora-7b.yaml
Normal file
63
examples/llava/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
65
examples/pixtral/lora-12b.yml
Normal file
65
examples/pixtral/lora-12b.yml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: mistral-community/pixtral-12b
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <pad>
|
||||||
63
examples/qwen2-vl/lora-7b.yaml
Normal file
63
examples/qwen2-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
|
from axolotl.processing_strategies import get_processing_strategy
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -747,6 +748,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.accelerator_config
|
self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.image_size:
|
||||||
|
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||||
|
if self.cfg.image_resize_algorithm:
|
||||||
|
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||||
|
self.cfg.image_resize_algorithm
|
||||||
|
)
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
@@ -890,8 +897,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processor"] = self.processor
|
kwargs["processing_strategy"] = get_processing_strategy(
|
||||||
kwargs["chat_template"] = training_args.chat_template
|
self.processor,
|
||||||
|
training_args.chat_template,
|
||||||
|
self.cfg.chat_template,
|
||||||
|
image_size=training_args.image_size,
|
||||||
|
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||||
|
)
|
||||||
elif self.cfg.batch_flattening:
|
elif self.cfg.batch_flattening:
|
||||||
collator = DataCollatorWithFlattening
|
collator = DataCollatorWithFlattening
|
||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ extra axolotl specific training args
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
@@ -212,6 +213,20 @@ class AxolotlTrainingMixins:
|
|||||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# multi-modal section
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The size of the image to resize to"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_resize_algorithm: Resampling | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The algorithm to use for image resizing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# end of multi-modal section
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
|||||||
278
src/axolotl/processing_strategies.py
Normal file
278
src/axolotl/processing_strategies.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""Module containing the ProcessingStrategy class and its derivatives for different multimodal model types"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
from torch import Tensor
|
||||||
|
from transformers import ProcessorMixin
|
||||||
|
from transformers.image_utils import load_image
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingStrategy:
    """Base processing strategy for multimodal chat examples.

    Normalizes conversation examples into the OpenAI-style ``messages`` format
    with typed content parts, optionally attaches (and resizes) an image taken
    from an ``images``/``image`` column, and builds training labels from
    ``input_ids``.
    """

    # Legacy role names mapped to their OpenAI-format equivalents.
    _ROLE_ALIASES = {
        "human": "user",
        "gpt": "assistant",
    }

    # Columns that may carry the example's image(s), in lookup order.
    _IMAGE_COLUMNS = ("images", "image")

    # Keys that mark an image content part as already carrying its image.
    _IMAGE_SOURCE_KEYS = ("image", "url", "path", "base64")

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        """
        Args:
            processor: processor whose ``tokenizer`` (and optional
                ``image_token``) attributes are used by this strategy
            chat_template: chat template to store on the strategy
            image_size: target size; a pair is passed to ``Image.resize``
                as-is, an int means a padded square of that side length
            image_resize_algorithm: resampling filter; defaults to bilinear
        """
        self.processor = processor
        self.chat_template = chat_template
        self.image_token = None
        self.image_token_id = None

        self.image_size = image_size
        # Compare against None explicitly: Resampling.NEAREST has the value 0,
        # so a truthiness test would silently replace it with BILINEAR.
        if image_resize_algorithm is None:
            image_resize_algorithm = Image.Resampling.BILINEAR
        self.image_resize_algorithm = image_resize_algorithm

        if hasattr(processor, "image_token"):
            self.image_token = processor.image_token
            self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
                self.image_token
            )

    def __call__(self, examples: list[dict]) -> list[dict]:
        """
        Preprocess conversation examples to ensure consistent format.
        Converts different conversation formats to OpenAI format with 'messages'.
        Supports two formats:
        1. OpenAI format with 'messages'
        2. Legacy format with 'conversations'

        If an example has a non-empty ``images`` or ``image`` column, its first
        image is loaded, optionally resized, and attached to the first message.
        The input examples are not modified.

        Args:
            examples: list of conversation dictionaries

        Returns:
            list of dicts in OpenAI format with 'messages' key

        Raises:
            ValueError: If the conversation format is not supported
        """
        processed_examples = []
        for example in examples:
            processed_example = self._to_openai_format(example)

            image_source = self._first_image(processed_example)
            if image_source is not None:
                # Handle image loading (Image, url, path, base64)
                image_value = self._resize_image(load_image(image_source))
                self._attach_image(processed_example["messages"], image_value)

            processed_examples.append(processed_example)

        return processed_examples

    def _to_openai_format(self, example: dict) -> dict:
        """Return a shallow copy of ``example`` with normalized 'messages'."""
        if "messages" in example:  # OpenAI format
            raw_messages = example["messages"]
            result = dict(example)
        elif "conversations" in example:  # Legacy format
            raw_messages = [
                {
                    "role": self._ROLE_ALIASES.get(turn["from"], turn["from"]),
                    "content": turn["value"],
                }
                for turn in example["conversations"]
            ]
            result = {
                key: value
                for key, value in example.items()
                if key != "conversations"
            }
        else:
            raise ValueError(
                "Only `messages` and `conversations` message keys are currently supported."
            )

        result["messages"] = self._normalize_messages(raw_messages)
        return result

    @staticmethod
    def _normalize_messages(messages: list[dict]) -> list[dict]:
        """Convert messages to the typed-content form used by apply_chat_template."""
        normalized = []
        for message in messages:
            content = message["content"]
            if isinstance(content, str):
                parts = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                # Copy each part so attaching an image later never mutates
                # the caller's dataset row.
                parts = [
                    dict(part) if isinstance(part, dict) else part
                    for part in content
                ]
            else:
                # NOTE(review): messages whose content is neither str nor list
                # are dropped, as before — confirm this is intended.
                continue
            normalized.append({"role": message["role"], "content": parts})
        return normalized

    @classmethod
    def _first_image(cls, example: dict):
        """Return the first usable image of the example, or None if there is none."""
        # TODO: check if it's normal to be single image only for common datasets
        # From observation, it's usually a list of single image but some datasets
        # may have several columns for images.
        # Temporary solution: take the first image and suggest people convert
        # their datasets to use multi-content Messages.
        for column in cls._IMAGE_COLUMNS:
            value = example.get(column)
            if value is None:
                continue
            if isinstance(value, (list, tuple)):
                if value:
                    return value[0]
                continue
            # A single image / path / url stored directly in the column;
            # indexing it would pick the first character of a string.
            return value
        return None

    def _resize_image(self, image):
        """Resize ``image`` according to ``self.image_size`` (no-op when unset)."""
        if self.image_size is None:
            return image

        if not hasattr(image, "resize"):
            raise TypeError("Image does not have a resize method")

        if isinstance(self.image_size, (tuple, list)):
            # A list is what a YAML `[w, h]` value yields; treat it like a tuple.
            return image.resize(tuple(self.image_size), self.image_resize_algorithm)

        # When image_size is an int (square target), preserve aspect ratio then
        # pad with black (0, 0, 0) to prevent distortion when resizing to square.
        return ImageOps.pad(
            image,
            (self.image_size, self.image_size),
            method=self.image_resize_algorithm,
            color=(0, 0, 0),
        )

    @classmethod
    def _attach_image(cls, messages: list[dict], image_value) -> None:
        """Put ``image_value`` into the first message's content list."""
        first_content = messages[0]["content"]

        # Some datasets already have a bare {"type": "image"} placeholder in the
        # first message; fill the first one that carries no image source yet.
        for part in first_content:
            if part.get("type") == "image" and all(
                key not in part for key in cls._IMAGE_SOURCE_KEYS
            ):
                part["image"] = image_value
                return

        # Otherwise add the image to the end of the first message.
        first_content.append({"type": "image", "image": image_value})

    def process_labels(self, input_ids: Tensor) -> Tensor:
        """Build labels from ``input_ids`` with padding and image tokens masked.

        Args:
            input_ids: token ids to derive the labels from

        Returns:
            a clone of ``input_ids`` where pad and image token ids are -100
        """
        labels = input_ids.clone()

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        pad_token_id = self.processor.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100

        # Ignore the image token index in the loss computation (model specific)
        if self.image_token_id is not None:
            labels[labels == self.image_token_id] = -100

        return labels
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLProcessingStrategy(ProcessingStrategy):
    """Processing strategy for Qwen2-VL models.

    Behaves like the base strategy, but always uses ``<|image_pad|>`` as the
    image token, whatever the base initializer picked up from the processor.
    """

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(
            processor,
            chat_template,
            image_size,
            image_resize_algorithm,
        )

        # Override the image token and look up its id in the tokenizer vocabulary.
        tokenizer = processor.tokenizer
        self.image_token = "<|image_pad|>"  # nosec
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ProcessingStrategy(ProcessingStrategy):
    """Processing Strategy class for Gemma3

    Uses the tokenizer's ``boi_token`` special token as the image token and
    additionally masks the ``<image_soft_token>`` id in the labels.
    """

    # Token id that corresponds to <image_soft_token>.
    # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
    IMAGE_SOFT_TOKEN_ID = 262144

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
        self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
            self.image_token
        )

    def process_labels(self, input_ids):
        """Build labels with pad, image and image-soft tokens masked.

        Args:
            input_ids: token ids to derive the labels from

        Returns:
            a clone of ``input_ids`` with the masked positions set to -100
        """
        # The base implementation clones input_ids and masks the pad token and
        # self.image_token_id; only the soft image token is Gemma3-specific.
        labels = super().process_labels(input_ids)
        labels[labels == self.IMAGE_SOFT_TOKEN_ID] = -100

        return labels
|
||||||
|
|
||||||
|
|
||||||
|
def get_processing_strategy(
    processor: ProcessorMixin,
    chat_template,
    chat_template_type,
    image_size: int | tuple[int, int] | None = None,
    image_resize_algorithm: Resampling | None = None,
):
    """
    Instantiate the processing strategy matching `chat_template_type`.

    `qwen2_vl` and `gemma3` get their dedicated strategy classes; the
    `llama3_2_vision`, `llava`, `mistral_v7_tekken` and `pixtral` template
    types use the base `ProcessingStrategy`.

    Args:
        processor: processor forwarded to the strategy constructor.
        chat_template: chat template forwarded to the strategy constructor.
        chat_template_type: name used to select the strategy class.
        image_size: forwarded to the strategy constructor.
        image_resize_algorithm: forwarded to the strategy constructor.

    Returns:
        The constructed strategy instance.

    Raises:
        ValueError: if `chat_template_type` is not one of the supported names.
    """
    base_strategy_types = (
        "llama3_2_vision",
        "llava",
        "mistral_v7_tekken",
        "pixtral",
    )

    # Pick the class first, then construct once with the shared arguments.
    if chat_template_type == "qwen2_vl":
        strategy_cls = Qwen2VLProcessingStrategy
    elif chat_template_type == "gemma3":
        strategy_cls = Gemma3ProcessingStrategy
    elif chat_template_type in base_strategy_types:
        strategy_cls = ProcessingStrategy
    else:
        raise ValueError(f"Unsupported chat template type: {chat_template_type}")

    return strategy_cls(processor, chat_template, image_size, image_resize_algorithm)
|
||||||
File diff suppressed because one or more lines are too long
@@ -2,15 +2,17 @@
|
|||||||
Collators for multi-modal chat messages and packing
|
Collators for multi-modal chat messages and packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from copy import deepcopy
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from PIL import Image
|
import torch
|
||||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
from torch import Tensor
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.data.data_collator import DataCollatorMixin
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.processing_strategies import ProcessingStrategy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiModalChatDataCollator(DataCollatorMixin):
|
class MultiModalChatDataCollator(DataCollatorMixin):
|
||||||
@@ -19,11 +21,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
processor: ProcessorMixin
|
processing_strategy: ProcessingStrategy
|
||||||
return_tensors: str = "pt"
|
|
||||||
chat_template: Optional[str] = None
|
|
||||||
packing: bool = False
|
packing: bool = False
|
||||||
max_images: int = -1
|
return_tensors: str = "pt"
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
@@ -31,162 +31,62 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
if self.packing:
|
if self.packing:
|
||||||
raise ValueError("Packing is currently not supported.")
|
raise ValueError("Packing is currently not supported.")
|
||||||
|
|
||||||
def torch_call(
|
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
|
||||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
return self.process_rows(examples)
|
||||||
) -> dict[str, Any]:
|
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
|
||||||
examples, self.processor, self.chat_template, self.max_images
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
|
||||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
|
||||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
|
||||||
|
|
||||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
|
||||||
# use this as a starting point
|
|
||||||
|
|
||||||
def _preprocess(examples: list[dict]) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Preprocess conversation examples to ensure consistent format.
|
|
||||||
|
|
||||||
Converts different conversation formats to OpenAI format with 'messages'.
|
|
||||||
Supports two formats:
|
|
||||||
1. OpenAI format with 'messages'
|
|
||||||
2. Legacy format with 'conversations'
|
|
||||||
|
|
||||||
Args:
|
|
||||||
examples: list of conversation dictionaries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict in OpenAI format with 'messages' key
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the conversation format is not supported
|
|
||||||
"""
|
|
||||||
role_mapping = {
|
|
||||||
"human": "user",
|
|
||||||
"gpt": "assistant",
|
|
||||||
}
|
|
||||||
|
|
||||||
def normalize_role(role: str) -> str:
|
|
||||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
|
||||||
return role_mapping.get(role, role)
|
|
||||||
|
|
||||||
def convert_legacy_format(example: dict) -> dict:
|
|
||||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": normalize_role(convo["from"]),
|
|
||||||
"content": convo["value"],
|
|
||||||
}
|
|
||||||
for convo in example["conversations"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create new dict without 'conversations' key
|
|
||||||
result = deepcopy(example)
|
|
||||||
result.pop("conversations")
|
|
||||||
return {"messages": messages, **result}
|
|
||||||
|
|
||||||
processed_examples = []
|
|
||||||
for example in examples:
|
|
||||||
# OpenAI format
|
|
||||||
if "messages" in example:
|
|
||||||
processed_examples.append(example)
|
|
||||||
|
|
||||||
# Legacy format
|
|
||||||
elif "conversations" in example:
|
|
||||||
processed_examples.append(convert_legacy_format(example))
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Only `messages` and `conversations` message keys are currently supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
return processed_examples
|
|
||||||
|
|
||||||
def _process_images(examples, max_images):
|
|
||||||
"""
|
|
||||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
examples: List of dictionaries that may contain 'images' key
|
|
||||||
max_images: Maximum number of images to keep per example (0 means no limit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Either None (if no images) or List[Image objects] (if all examples have images)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there's a mix of None and non-None images
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_image(example):
|
|
||||||
if "images" not in example:
|
|
||||||
return None
|
|
||||||
images = example["images"]
|
|
||||||
if isinstance(images, str):
|
|
||||||
return Image.open(images)
|
|
||||||
return images
|
|
||||||
|
|
||||||
images = [get_image(example) for example in examples]
|
|
||||||
|
|
||||||
# Count None and non-None images
|
|
||||||
none_count = sum(1 for img in images if img is None)
|
|
||||||
|
|
||||||
# All images are None
|
|
||||||
if none_count == len(images):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Mix of None and non-None images
|
|
||||||
if none_count > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"All images should be either None or not None. "
|
|
||||||
"Please provide images for all examples or None."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_images limit if specified
|
|
||||||
if max_images > 0:
|
|
||||||
images = [
|
|
||||||
(
|
|
||||||
img_batch[:max_images]
|
|
||||||
if isinstance(img_batch, (list, tuple))
|
|
||||||
else img_batch
|
|
||||||
)
|
|
||||||
for img_batch in images
|
|
||||||
]
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
def process_rows(
|
||||||
|
self,
|
||||||
|
examples: list[dict],
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
# Preprocess the examples
|
# Preprocess the examples
|
||||||
examples = _preprocess(examples)
|
examples = self.processing_strategy(examples)
|
||||||
|
|
||||||
# Get the texts and images, and apply the chat template
|
# Initialize batch
|
||||||
texts = [
|
batch: dict[str, Any] = {}
|
||||||
processor.apply_chat_template(
|
|
||||||
example["messages"], chat_template=chat_template, tokenize=False
|
# Process each example
|
||||||
|
for example in examples:
|
||||||
|
# Apply chat template to process the example
|
||||||
|
# This method requires transformers>=4.49.0
|
||||||
|
result = self.processing_strategy.processor.apply_chat_template(
|
||||||
|
example["messages"],
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
return_dict=True,
|
||||||
|
chat_template=self.processing_strategy.chat_template,
|
||||||
)
|
)
|
||||||
for example in examples
|
|
||||||
]
|
|
||||||
|
|
||||||
images = _process_images(examples, max_images=max_images)
|
# TODO: Check if need handling for len(input_ids) > sequence_len
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Add the processed tensors to our batch
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
for key in result.keys():
|
||||||
|
if key not in batch:
|
||||||
|
batch[key] = []
|
||||||
|
|
||||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
batch[key].append(result[key].squeeze(0))
|
||||||
labels = batch["input_ids"].clone()
|
|
||||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
# Pad sequences to the same length
|
||||||
# Ignore the image token index in the loss computation (model specific)
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
batch["input_ids"],
|
||||||
processor.image_token
|
batch_first=True,
|
||||||
|
padding_value=self.tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
labels[labels == image_token_id] = -100
|
|
||||||
batch["labels"] = labels
|
|
||||||
|
|
||||||
if length_only:
|
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||||
return {
|
batch["attention_mask"], batch_first=True, padding_value=0
|
||||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
)
|
||||||
}
|
|
||||||
return batch
|
# Create the final batch
|
||||||
|
final_batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process the labels
|
||||||
|
final_batch["labels"] = self.processing_strategy.process_labels(
|
||||||
|
final_batch["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_batch
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.integrations.base import PluginManager
|
|||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
@@ -158,7 +158,7 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
cfg.is_multimodal = (
|
cfg.is_multimodal = (
|
||||||
hasattr(model_config, "model_type")
|
hasattr(model_config, "model_type")
|
||||||
and model_config.model_type in ["llava", "mllama"]
|
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
|
||||||
or any(
|
or any(
|
||||||
multimodal_name in cfg.base_model.lower()
|
multimodal_name in cfg.base_model.lower()
|
||||||
for multimodal_name in [
|
for multimodal_name in [
|
||||||
@@ -171,7 +171,6 @@ def normalize_config(cfg):
|
|||||||
cfg.processor_config = (
|
cfg.processor_config = (
|
||||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||||
)
|
)
|
||||||
model_config = model_config.text_config
|
|
||||||
|
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
|
|||||||
@@ -34,12 +34,16 @@ from transformers import ( # noqa: F401
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
|
Mistral3ForConditionalGeneration,
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
Qwen2VLForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import (
|
from transformers.integrations.deepspeed import (
|
||||||
HfTrainerDeepSpeedConfig,
|
HfTrainerDeepSpeedConfig,
|
||||||
@@ -69,9 +73,13 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
MULTIMODEL_AUTO_MODEL_MAPPING = {
|
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||||
"llava": LlavaForConditionalGeneration,
|
|
||||||
"mllama": MllamaForConditionalGeneration,
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
"llava": LlavaForConditionalGeneration,
|
||||||
|
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
||||||
|
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
||||||
|
"mistral3": Mistral3ForConditionalGeneration,
|
||||||
|
"gemma3": Gemma3ForConditionalGeneration,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -101,7 +109,21 @@ def get_module_class_from_name(module, name):
|
|||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
model_config = model_config.text_config
|
if hasattr(model_config, "text_config"):
|
||||||
|
model_config = model_config.text_config
|
||||||
|
model_config.use_cache = False
|
||||||
|
elif hasattr(model_config, "get_text_config"):
|
||||||
|
model_config = model_config.get_text_config()
|
||||||
|
model_config.use_cache = False
|
||||||
|
|
||||||
|
# check if image_size is not set and load image size from model config if available
|
||||||
|
if (
|
||||||
|
cfg.image_size is None
|
||||||
|
and hasattr(model_config, "vision_config")
|
||||||
|
and hasattr(model_config.vision_config, "image_size")
|
||||||
|
):
|
||||||
|
cfg.image_size = model_config.vision_config.image_size
|
||||||
|
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -440,6 +462,31 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Attempt to load image size from processor if available
|
||||||
|
if (
|
||||||
|
cfg.image_size is None
|
||||||
|
and hasattr(processor, "size")
|
||||||
|
and any(dim in processor.size for dim in ["width", "height"])
|
||||||
|
):
|
||||||
|
im_width = None
|
||||||
|
im_height = None
|
||||||
|
if "width" in processor.size:
|
||||||
|
im_width = processor.size["width"]
|
||||||
|
if "height" in processor.size:
|
||||||
|
im_height = processor.size["height"]
|
||||||
|
|
||||||
|
# If both width and height are set, use a tuple
|
||||||
|
if im_width is not None and im_height is not None:
|
||||||
|
cfg.image_size = (im_width, im_height)
|
||||||
|
# If only width is set, use as integer
|
||||||
|
elif im_width is not None:
|
||||||
|
cfg.image_size = im_width
|
||||||
|
# If only height is set, use as integer
|
||||||
|
elif im_height is not None:
|
||||||
|
cfg.image_size = im_height
|
||||||
|
|
||||||
|
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
|
||||||
|
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
@@ -477,7 +524,11 @@ class ModelLoader:
|
|||||||
# init model config
|
# init model config
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
self.text_model_config = self.model_config.text_config
|
if hasattr(self.model_config, "text_config"):
|
||||||
|
self.text_model_config = self.model_config.text_config
|
||||||
|
else:
|
||||||
|
# for qwen2_vl
|
||||||
|
self.text_model_config = self.model_config.get_text_config()
|
||||||
else:
|
else:
|
||||||
self.text_model_config = self.model_config
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
@@ -673,7 +724,7 @@ class ModelLoader:
|
|||||||
should be set according to the type of the model.
|
should be set according to the type of the model.
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
|
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
||||||
self.model_config.model_type, AutoModelForVision2Seq
|
self.model_config.model_type, AutoModelForVision2Seq
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1194,7 +1245,9 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None and not (
|
||||||
|
self.model_config.model_type == "llava"
|
||||||
|
):
|
||||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from axolotl.utils.schemas.model import (
|
|||||||
ModelOutputConfig,
|
ModelOutputConfig,
|
||||||
SpecialTokensConfig,
|
SpecialTokensConfig,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
@@ -64,6 +65,7 @@ class AxolotlInputConfig(
|
|||||||
LISAConfig,
|
LISAConfig,
|
||||||
GradioConfig,
|
GradioConfig,
|
||||||
RayConfig,
|
RayConfig,
|
||||||
|
MultiModalConfig,
|
||||||
RemappedParameters,
|
RemappedParameters,
|
||||||
DeprecatedParameters,
|
DeprecatedParameters,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ class ChatTemplate(str, Enum):
|
|||||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||||
|
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
gemma3_text = "gemma3_text" # pylint: disable=invalid-name
|
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
@@ -37,6 +37,10 @@ class ChatTemplate(str, Enum):
|
|||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
exaone = "exaone" # pylint: disable=invalid-name
|
||||||
metharme = "metharme" # pylint: disable=invalid-name
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||||
|
llava = "llava" # pylint: disable=invalid-name
|
||||||
|
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||||
|
gemma3 = "gemma3" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class CustomSupportedOptimizers(str, Enum):
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
|
|||||||
48
src/axolotl/utils/schemas/multimodal.py
Normal file
48
src/axolotl/utils/schemas/multimodal.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Pydantic models for multimodal-related configuration"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalConfig(BaseModel):
    """Multi-modal configuration subset"""

    # Target size for image resizing; an int or a (width, height) tuple.
    image_size: int | tuple[int, int] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                # The trailing space keeps the two concatenated sentences
                # from running together in the rendered description.
                "The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height). "
                "If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
            )
        },
    )
    # Resampling filter; string names are converted by the validator below.
    image_resize_algorithm: (
        Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
        },
    )

    @field_validator("image_resize_algorithm", mode="before")
    @classmethod
    def convert_image_resize_algorithm(cls, image_resize_algorithm):
        """
        Convert the image resize algorithm to a PIL.Image.Resampling enum.

        String values are matched case-insensitively against "bilinear",
        "bicubic" and "lanczos"; any non-string value is returned unchanged.

        Raises:
            ValueError: if a string value is not one of the names above.
        """
        if not isinstance(image_resize_algorithm, str):
            return image_resize_algorithm

        algorithms = {
            "bilinear": Resampling.BILINEAR,
            "bicubic": Resampling.BICUBIC,
            "lanczos": Resampling.LANCZOS,
        }
        image_resize_algorithm = image_resize_algorithm.lower()
        if image_resize_algorithm not in algorithms:
            raise ValueError(
                f"Invalid image resize algorithm: {image_resize_algorithm}"
            )
        return algorithms[image_resize_algorithm]
|
||||||
@@ -58,7 +58,7 @@ class TestGemma3Text:
|
|||||||
"bos_token": "<bos>",
|
"bos_token": "<bos>",
|
||||||
"eos_token": "<eos>",
|
"eos_token": "<eos>",
|
||||||
},
|
},
|
||||||
"chat_template": "gemma3_text",
|
"chat_template": "gemma3",
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
@@ -105,7 +105,7 @@ class TestGemma3Text:
|
|||||||
"split": "train[:1%]",
|
"split": "train[:1%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"chat_template": "gemma3_text",
|
"chat_template": "gemma3",
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"bos_token": "<bos>",
|
"bos_token": "<bos>",
|
||||||
"eos_token": "<eos>",
|
"eos_token": "<eos>",
|
||||||
|
|||||||
Reference in New Issue
Block a user