Feat: add gemma3n support (#2852)

* feat: add gemma3n cce

* feat: add sample config

* feat: add gemma3n multimodal mode

* feat: add audio example

* feat: support audio and return pixel values in collator

* feat: support unmask only assistant region (gemma3n for now)

* feat(doc): add notes for audio loading

* feat: add audio support for gemma3n

* feat: update examples

* feat: add gemma3n to the docs

* fix: add link at top

* feat(doc): clarify additional requirements

* fix: mllama missing aspect ratio

* fix: mllama need attention fixes for fa2

* Partially Revert "fix: mllama need attention fixes for fa2"

This reverts commit a0bfdd1777.

* fix: disable FA2 for mllama in vision mode

* feat: update configs to use proper attention

* fix: support other vision features

* feat(doc): clarify requirements for gemma3n
This commit is contained in:
NanoCode012
2025-07-22 16:52:15 +07:00
committed by GitHub
parent d32058e149
commit dfba881e99
15 changed files with 473 additions and 18 deletions

View File

@@ -14,6 +14,7 @@ format:
- [Llava-1.5](#sec-llava-15) - [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31) - [Mistral-Small-3.1](#sec-mistral-small-31)
- [Gemma-3](#sec-gemma-3) - [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
@@ -110,6 +111,22 @@ base_model: google/gemma-3-4b-it
chat_template: gemma3 chat_template: gemma3
``` ```
### Gemma-3n {#sec-gemma-3n}
::: {.callout-warning}
The model's initial loss and grad norm will be very high. We suspect this to be due to the Conv in the vision layers.
:::
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.17`
:::
```yaml
base_model: google/gemma-3n-E2B-it
chat_template: gemma3n
```
### Qwen2-VL {#sec-qwen2-vl} ### Qwen2-VL {#sec-qwen2-vl}
```yaml ```yaml
@@ -132,7 +149,9 @@ For multi-modal datasets, we adopt an extended `chat_template` format similar to
- A message is a list of `role` and `content`. - A message is a list of `role` and `content`.
- `role` can be `system`, `user`, `assistant`, etc. - `role` can be `system`, `user`, `assistant`, etc.
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`). - `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`).
### Image
::: {.callout-note} ::: {.callout-note}
For backwards compatibility: For backwards compatibility:
@@ -141,15 +160,29 @@ For backwards compatibility:
- If `content` is a string, it will be converted to a list with `type` as `text`. - If `content` is a string, it will be converted to a list with `type` as `text`.
::: :::
::: {.callout-tip}
For image loading, you can use the following keys within `content` alongside `"type": "image"`: For image loading, you can use the following keys within `content` alongside `"type": "image"`:
- `"path": "/path/to/image.jpg"` - `"path": "/path/to/image.jpg"`
- `"url": "https://example.com/image.jpg"` - `"url": "https://example.com/image.jpg"`
- `"base64": "..."` - `"base64": "..."`
- `"image": PIL.Image` - `"image": PIL.Image`
### Audio
For audio loading, you can use the following keys within `content` alongside `"type": "audio"`:
- `"path": "/path/to/audio.mp3"`
- `"url": "https://example.com/audio.mp3"`
- `"audio": np.ndarray`
::: {.callout-tip}
You may need to install `librosa` via `pip3 install librosa==0.11.0`.
::: :::
### Example
Here is an example of a multi-modal dataset: Here is an example of a multi-modal dataset:
```json ```json
[ [
@@ -178,3 +211,9 @@ Here is an example of a multi-modal dataset:
} }
] ]
``` ```
## FAQ
1. `PIL.UnidentifiedImageError: cannot identify image file ...`
`PIL` could not retrieve the file at `url` using `requests`. Please check for typo. One alternative reason is that the request is blocked by the server.

View File

@@ -0,0 +1,19 @@
# Gemma-3n
## Requirements
In addition to Axolotl's requirements, Gemma-3n requires
```
pip3 install timm
```
If you will load audio datasets, please also install
```
pip3 install librosa
```
## Usage
See example configs and the [multimodal doc](https://docs.axolotl.ai/docs/multimodal.html).

View File

@@ -0,0 +1,74 @@
base_model: google/gemma-3n-E2B-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
load_in_8bit: false
load_in_4bit: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
split: train[:1%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# lora_target_linear: # Does not work with gemma3n currently
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: muon
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,80 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
# sample dataset below requires downloading audio/image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
datasets:
- path: Nanobit/text-vision-audio-2k-test
type: chat_template
data_files:
- dataset.jsonl
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: muon
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,75 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: muon
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -15,8 +15,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft - path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template type: chat_template
split: train[:1%] split: train[:1%]
field_messages: messages dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./outputs/out
@@ -40,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: muon
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -50,8 +49,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: true # flash_attention: true # use for text-only mode
eager_attention: sdp_attention: true
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -11,8 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft - path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template type: chat_template
split: train[:1%] split: train[:1%]
field_messages: messages dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./outputs/out
@@ -36,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: muon
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002

View File

@@ -48,8 +48,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. # flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention: sdp_attention: true
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -11,8 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft - path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template type: chat_template
split: train[:1%] split: train[:1%]
field_messages: messages dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./outputs/out
@@ -36,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: muon
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -46,8 +45,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet # flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention: sdp_attention: true
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -37,6 +37,8 @@ plugins:
- gemma2 - gemma2
- gemma3 - gemma3
- gemma3_text - gemma3_text
- gemma3n
- gemma3n_text
- glm - glm
- glm4 - glm4
- llama - llama

View File

@@ -2,6 +2,7 @@
from transformers import ( from transformers import (
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
Gemma3nForConditionalGeneration,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
@@ -18,4 +19,5 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration, "mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration, "gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
} }

View File

@@ -5,7 +5,7 @@ from typing import Optional
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Resampling from PIL.Image import Resampling
from torch import Tensor from torch import Tensor, zeros_like
from transformers import ProcessorMixin from transformers import ProcessorMixin
from transformers.image_utils import load_image from transformers.image_utils import load_image
@@ -208,9 +208,18 @@ class ProcessingStrategy:
return processed_examples return processed_examples
def _mask_non_assistant(self, labels: Tensor) -> Tensor:
"""
Mask non assistant regions to -100.
To be implemented per subclass.
"""
return labels
def process_labels(self, input_ids: Tensor) -> Tensor: def process_labels(self, input_ids: Tensor) -> Tensor:
labels = input_ids.clone() labels = input_ids.clone()
labels = self._mask_non_assistant(labels)
# The labels are the input_ids, and we mask the padding tokens in the loss computation # The labels are the input_ids, and we mask the padding tokens in the loss computation
labels[labels == self.processor.tokenizer.pad_token_id] = -100 labels[labels == self.processor.tokenizer.pad_token_id] = -100
@@ -264,6 +273,99 @@ class Gemma3ProcessingStrategy(ProcessingStrategy):
return labels return labels
class Gemma3nProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3n"""
def _mask_non_assistant(self, labels: Tensor) -> Tensor:
def _find_token_sequence(label, start_pos, token_sequence):
"""Check if token_sequence appears at start_pos in label"""
if start_pos + len(token_sequence) > len(label):
return False
if label[start_pos] != token_sequence[0]:
return False
return (
label[start_pos : start_pos + len(token_sequence)].tolist()
== token_sequence
)
def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i):
"""
Find the end of assistant response and update mask accordingly
Returns new position to continue from and whether the end seq is found
"""
k = start_pos
while k < len(label):
if not _find_token_sequence(label, k, assistant_end_tok):
mask[i][k] = 1
k += 1
continue
return k + len(assistant_end_tok), True
return k, False
mask = zeros_like(labels)
assistant_start_str = "<start_of_turn>model"
assistant_end_str = "<end_of_turn>"
include_assistant_start_tok = False
include_assistant_end_tok = True
# str to tokens
assistant_start_tok = self.processor.tokenizer.encode(
assistant_start_str, add_special_tokens=False
)
assistant_end_tok = self.processor.tokenizer.encode(
assistant_end_str, add_special_tokens=False
)
for i, label in enumerate(labels):
j = 0
# while loop through each tok index in labels[i]
while j < len(label):
# Check until match start seq
if not _find_token_sequence(label, j, assistant_start_tok):
j += 1
continue
if include_assistant_start_tok:
mask[i][j : j + len(assistant_start_tok)] = 1
# Find where the assistant response ends
start_of_content = j + len(assistant_start_tok)
end_pos, found_end_seq = _find_assistant_end(
label, start_of_content, assistant_end_tok, mask, i
)
# Include end token if requested
if include_assistant_end_tok and found_end_seq:
mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1
j = end_pos
labels[i][mask[i] == 0] = -100
return labels
def process_labels(self, input_ids):
labels = input_ids.clone()
labels = self._mask_non_assistant(labels)
# Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb
labels[labels == self.processor.tokenizer.pad_token_id] = -100
if hasattr(self.processor.tokenizer, "image_token_id"):
labels[labels == self.processor.tokenizer.image_token_id] = -100
if hasattr(self.processor.tokenizer, "audio_token_id"):
labels[labels == self.processor.tokenizer.audio_token_id] = -100
if hasattr(self.processor.tokenizer, "boi_token_id"):
labels[labels == self.processor.tokenizer.boi_token_id] = -100
if hasattr(self.processor.tokenizer, "eoi_token_id"):
labels[labels == self.processor.tokenizer.eoi_token_id] = -100
return labels
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -279,6 +381,10 @@ def get_processing_strategy(
return Gemma3ProcessingStrategy( return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm processor, chat_template, image_size, image_resize_algorithm
) )
if chat_template_type == "gemma3n":
return Gemma3nProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [ if chat_template_type in [
"llama3_2_vision", "llama3_2_vision",
"llama4", "llama4",

View File

@@ -0,0 +1,49 @@
{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{%- set first_user_prefix = messages[0]['content'] + '
' -%}
{%- else -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
' -%}
{%- endif -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
{{ '<audio_soft_token>' }}
{%- elif item['type'] == 'image' -%}
{{ '<image_soft_token>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}

View File

@@ -84,6 +84,17 @@ class MultiModalChatDataCollator(DataCollatorMixin):
"attention_mask": attention_mask, "attention_mask": attention_mask,
} }
for key, val in batch.items():
if key in ["input_ids", "attention_mask"]:
continue
if key in ["token_type_ids", "cross_attention_mask"]:
final_batch[key] = torch.nn.utils.rnn.pad_sequence(
val, batch_first=True, padding_value=0
)
else:
final_batch[key] = torch.stack(val)
# Process the labels # Process the labels
final_batch["labels"] = self.processing_strategy.process_labels( final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"] final_batch["input_ids"]

View File

@@ -62,6 +62,7 @@ class ChatTemplate(str, Enum):
llava = "llava" llava = "llava"
qwen2_vl = "qwen2_vl" qwen2_vl = "qwen2_vl"
gemma3 = "gemma3" gemma3 = "gemma3"
gemma3n = "gemma3n"
command_a = "command_a" command_a = "command_a"
command_a_tool_use = "command_a_tool_use" command_a_tool_use = "command_a_tool_use"
command_a_rag = "command_a_rag" command_a_rag = "command_a_rag"