diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd
index ec51a8ec3..dbb365f73 100644
--- a/docs/multimodal.qmd
+++ b/docs/multimodal.qmd
@@ -14,6 +14,7 @@ format:
 - [Llava-1.5](#sec-llava-15)
 - [Mistral-Small-3.1](#sec-mistral-small-31)
 - [Gemma-3](#sec-gemma-3)
+- [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
@@ -110,6 +111,22 @@ base_model: google/gemma-3-4b-it
 chat_template: gemma3
 ```
 
+### Gemma-3n {#sec-gemma-3n}
+
+::: {.callout-warning}
+The model's initial loss and grad norm will be very high. We suspect this is due to the Conv in the vision layers.
+:::
+
+::: {.callout-tip}
+Please make sure to install `timm` via `pip3 install timm==1.0.17`.
+:::
+
+```yaml
+base_model: google/gemma-3n-E2B-it
+
+chat_template: gemma3n
+```
+
 ### Qwen2-VL {#sec-qwen2-vl}
 
 ```yaml
@@ -132,7 +149,9 @@ For multi-modal datasets, we adopt an extended `chat_template` format similar to
 
 - A message is a list of `role` and `content`.
 - `role` can be `system`, `user`, `assistant`, etc.
-- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
+- `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`).
+
+### Image
 
 ::: {.callout-note}
 For backwards compatibility:
@@ -141,15 +160,29 @@ For backwards compatibility:
 - If `content` is a string, it will be converted to a list with `type` as `text`.
 :::
 
-::: {.callout-tip}
 For image loading, you can use the following keys within `content` alongside `"type": "image"`:
 
 - `"path": "/path/to/image.jpg"`
 - `"url": "https://example.com/image.jpg"`
 - `"base64": "..."`
 - `"image": PIL.Image`
+
+### Audio
+
+For audio loading, you can use the following keys within `content` alongside `"type": "audio"`:
+
+- `"path": "/path/to/audio.mp3"`
+- `"url": "https://example.com/audio.mp3"`
+- `"audio": np.ndarray`
+
+::: {.callout-tip}
+
+You may need to install `librosa` via `pip3 install librosa==0.11.0`.
+
 :::
 
+### Example
+
 Here is an example of a multi-modal dataset:
 ```json
 [
@@ -178,3 +211,9 @@ Here is an example of a multi-modal dataset:
 }
 ]
 ```
+
+## FAQ
+
+1. `PIL.UnidentifiedImageError: cannot identify image file ...`
+
+`PIL` could not retrieve the file at `url` using `requests`. Please check the URL for typos. Alternatively, the request may have been blocked by the server.
diff --git a/examples/gemma3n/README.md b/examples/gemma3n/README.md
new file mode 100644
index 000000000..b3922d526
--- /dev/null
+++ b/examples/gemma3n/README.md
@@ -0,0 +1,46 @@
+# Gemma-3n
+
+## Requirements
+
+In addition to Axolotl's requirements, Gemma-3n requires:
+
+```
+pip3 install timm
+```
+
+If you plan to load audio datasets, please also install:
+
+```
+pip3 install librosa
+```
+
+## Usage
+
+See the example configs and the [multimodal doc](https://docs.axolotl.ai/docs/multimodal.html).
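+
+## Dataset format
+
+As a quick orientation, a single vision + audio training row in the extended `chat_template` format could look like the following minimal sketch. The file names match the sample assets downloaded for the vision-audio example config; the message text is illustrative.
+
+```json
+[
+  {
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          { "type": "image", "path": "African_elephant.jpg" },
+          { "type": "audio", "path": "En-us-African_elephant.oga" },
+          { "type": "text", "text": "Describe what you see and hear." }
+        ]
+      },
+      {
+        "role": "assistant",
+        "content": [
+          { "type": "text", "text": "An African elephant trumpeting." }
+        ]
+      }
+    ]
+  }
+]
+```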
diff --git a/examples/gemma3n/gemma-3n-e2b-qlora.yml b/examples/gemma3n/gemma-3n-e2b-qlora.yml
new file mode 100644
index 000000000..7868af59e
--- /dev/null
+++ b/examples/gemma3n/gemma-3n-e2b-qlora.yml
@@ -0,0 +1,74 @@
+base_model: google/gemma-3n-E2B-it
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+cut_cross_entropy: true
+
+load_in_8bit: false
+load_in_4bit: true
+
+# for use with fft to only train on language model layers
+# unfrozen_parameters:
+  # - model.language_model.*
+  # - lm_head
+  # - embed_tokens
+
+
+chat_template: gemma3n
+eot_tokens:
+  - <end_of_turn>
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    split: train[:1%]
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+# lora_target_linear:  # Does not work with gemma3n currently
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 4
+optimizer: muon
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+# flash_attention: true  # No attention implementation works with gemma3n yet
+
+warmup_ratio: 0.1
+evals_per_epoch:
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
diff --git a/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml b/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml
new file mode 100644
index 000000000..6cdf5573e
--- /dev/null
+++ b/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml
@@ -0,0 +1,80 @@
+base_model: google/gemma-3n-E2B-it
+processor_type: AutoProcessor
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+cut_cross_entropy: true
+
+# for use with fft to only train on language model layers
+# unfrozen_parameters:
+  # - model.language_model.*
+  # - lm_head
+  # - embed_tokens
+
+load_in_4bit: true
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+# gemma3 doesn't seem to play nice with ddp
+ddp_find_unused_parameters: true
+
+chat_template: gemma3n
+eot_tokens:
+  - <end_of_turn>
+
+# the sample dataset below requires downloading the audio/image files in advance:
+# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
+# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
+datasets:
+  - path: Nanobit/text-vision-audio-2k-test
+    type: chat_template
+    data_files:
+      - dataset.jsonl
+dataset_prepared_path:
+val_set_size: 0.01
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: muon
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+logging_steps: 1
+# flash_attention: true  # No attention implementation works with gemma3n yet
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
diff --git a/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml
new file mode 100644
index 000000000..519edecc7
--- /dev/null
+++ b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml
@@ -0,0 +1,75 @@
+base_model: google/gemma-3n-E2B-it
+processor_type: AutoProcessor
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+cut_cross_entropy: true
+
+# for use with fft to only train on language model layers
+# unfrozen_parameters:
+  # - model.language_model.*
+  # - lm_head
+  # - embed_tokens
+
+load_in_4bit: true
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+# gemma3 doesn't seem to play nice with ddp
+ddp_find_unused_parameters: true
+
+chat_template: gemma3n
+eot_tokens:
+  - <end_of_turn>
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+dataset_prepared_path:
+val_set_size: 0.01
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: muon
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+logging_steps: 1
+# flash_attention: true  # No attention implementation works with gemma3n yet
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 64d749b5a..da28ace3b 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -15,8 +15,7 @@ datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
     split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./outputs/out
 
@@ -40,7 +39,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_bnb_8bit
+optimizer: muon
 lr_scheduler: cosine
 learning_rate: 0.0002
 
@@ -50,8 +49,8 @@ tf32: true
 
 gradient_checkpointing: true
 logging_steps: 1
-flash_attention: true
-eager_attention:
+# flash_attention: true  # use for text-only mode
+sdp_attention: true
 
 warmup_ratio: 0.1
 evals_per_epoch: 1
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index a4bac8987..53ae97542 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -11,8 +11,7 @@ datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
     split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./outputs/out
 
@@ -36,7 +35,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_bnb_8bit
+optimizer: muon
 lr_scheduler: cosine
 learning_rate: 0.0002
 
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 4a492c595..3e477645e 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -48,8 +48,8 @@ tf32: true
 
 gradient_checkpointing: true
 logging_steps: 1
-flash_attention: false  # PixtralVisionModel does not support Flash Attention 2.0 yet.
-eager_attention:
+# flash_attention: false  # PixtralVisionModel does not support Flash Attention 2.0 yet.
+sdp_attention: true
 
 warmup_ratio: 0.1
 evals_per_epoch: 1
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index ea769d202..fc4c0667c 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -11,8 +11,7 @@ datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
     split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./outputs/out
 
@@ -36,7 +35,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_bnb_8bit
+optimizer: muon
 lr_scheduler: cosine
 learning_rate: 0.0002
 
@@ -46,8 +45,8 @@ tf32: true
 
 gradient_checkpointing: true
 logging_steps: 1
-flash_attention: false  # PixtralVisionModel does not support Flash Attention 2.0 yet
-eager_attention:
+# flash_attention:  # PixtralVisionModel does not support Flash Attention 2.0 yet
+sdp_attention: true
 
 warmup_ratio: 0.1
 evals_per_epoch: 1
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index 07c6de1f8..a97bac71c 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -37,6 +37,8 @@ plugins:
 - gemma2
 - gemma3
 - gemma3_text
+- gemma3n
+- gemma3n_text
 - glm
 - glm4
 - llama
diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py
index c08518dd6..c340c414c 100644
--- a/src/axolotl/loaders/constants.py
+++ b/src/axolotl/loaders/constants.py
@@ -2,6 +2,7 @@
 
 from transformers import (
     Gemma3ForConditionalGeneration,
+    Gemma3nForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
     Mistral3ForConditionalGeneration,
@@ -18,4 +19,5 @@ MULTIMODAL_AUTO_MODEL_MAPPING = {
     "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
     "mistral3": Mistral3ForConditionalGeneration,
     "gemma3": Gemma3ForConditionalGeneration,
+    "gemma3n": Gemma3nForConditionalGeneration,
 }
diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py
index 080697400..1cb297406 100644
--- a/src/axolotl/processing_strategies.py
+++ b/src/axolotl/processing_strategies.py
@@ -5,7 +5,7 @@ from typing import Optional
 
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
-from torch import Tensor
+from torch import Tensor, zeros_like
 from transformers import ProcessorMixin
 from transformers.image_utils import load_image
 
@@ -208,9 +208,18 @@ class ProcessingStrategy:
 
         return processed_examples
 
+    def _mask_non_assistant(self, labels: Tensor) -> Tensor:
+        """
+        Mask non-assistant regions to -100.
+        To be implemented per subclass.
+ """ + return labels + def process_labels(self, input_ids: Tensor) -> Tensor: labels = input_ids.clone() + labels = self._mask_non_assistant(labels) + # The labels are the input_ids, and we mask the padding tokens in the loss computation labels[labels == self.processor.tokenizer.pad_token_id] = -100 @@ -264,6 +273,99 @@ class Gemma3ProcessingStrategy(ProcessingStrategy): return labels +class Gemma3nProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Gemma3n""" + + def _mask_non_assistant(self, labels: Tensor) -> Tensor: + def _find_token_sequence(label, start_pos, token_sequence): + """Check if token_sequence appears at start_pos in label""" + if start_pos + len(token_sequence) > len(label): + return False + if label[start_pos] != token_sequence[0]: + return False + return ( + label[start_pos : start_pos + len(token_sequence)].tolist() + == token_sequence + ) + + def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i): + """ + Find the end of assistant response and update mask accordingly + + Returns new position to continue from and whether the end seq is found + """ + k = start_pos + while k < len(label): + if not _find_token_sequence(label, k, assistant_end_tok): + mask[i][k] = 1 + k += 1 + continue + + return k + len(assistant_end_tok), True + + return k, False + + mask = zeros_like(labels) + + assistant_start_str = "model" + assistant_end_str = "" + include_assistant_start_tok = False + include_assistant_end_tok = True + + # str to tokens + assistant_start_tok = self.processor.tokenizer.encode( + assistant_start_str, add_special_tokens=False + ) + assistant_end_tok = self.processor.tokenizer.encode( + assistant_end_str, add_special_tokens=False + ) + + for i, label in enumerate(labels): + j = 0 + # while loop through each tok index in labels[i] + while j < len(label): + # Check until match start seq + if not _find_token_sequence(label, j, assistant_start_tok): + j += 1 + continue + + if include_assistant_start_tok: + mask[i][j : j + len(assistant_start_tok)] = 1 + + # Find where the assistant response ends + start_of_content = j + len(assistant_start_tok) + end_pos, found_end_seq = _find_assistant_end( + label, start_of_content, assistant_end_tok, mask, i + ) + + # Include end token if requested + if include_assistant_end_tok and found_end_seq: + mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1 + + j = end_pos + + labels[i][mask[i] == 0] = -100 + + return labels + + def process_labels(self, input_ids): + labels = input_ids.clone() + labels = self._mask_non_assistant(labels) + + # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + if hasattr(self.processor.tokenizer, "image_token_id"): + labels[labels == self.processor.tokenizer.image_token_id] = -100 + if hasattr(self.processor.tokenizer, "audio_token_id"): + labels[labels == self.processor.tokenizer.audio_token_id] = -100 + if hasattr(self.processor.tokenizer, "boi_token_id"): + labels[labels == self.processor.tokenizer.boi_token_id] = -100 + if hasattr(self.processor.tokenizer, "eoi_token_id"): + labels[labels == self.processor.tokenizer.eoi_token_id] = -100 + + return labels + + def get_processing_strategy( processor: ProcessorMixin, chat_template, @@ -279,6 +381,10 @@ def get_processing_strategy( return Gemma3ProcessingStrategy( processor, chat_template, image_size, image_resize_algorithm ) + if chat_template_type == "gemma3n": + return 
+        return Gemma3nProcessingStrategy(
+            processor, chat_template, image_size, image_resize_algorithm
+        )
     if chat_template_type in [
         "llama3_2_vision",
         "llama4",
diff --git a/src/axolotl/utils/chat_templates/templates/gemma3n.jinja b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja
new file mode 100644
index 000000000..a0405ea9c
--- /dev/null
+++ b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja
@@ -0,0 +1,49 @@
+{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+    {%- if messages[0]['content'] is string -%}
+        {%- set first_user_prefix = messages[0]['content'] + '
+
+' -%}
+    {%- else -%}
+        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+' -%}
+    {%- endif -%}
+    {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+    {%- set first_user_prefix = "" -%}
+    {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+    {%- endif -%}
+    {%- if (message['role'] == 'assistant') -%}
+        {%- set role = "model" -%}
+    {%- else -%}
+        {%- set role = message['role'] -%}
+    {%- endif -%}
+    {{ '<start_of_turn>' + role + '
+' + (first_user_prefix if loop.first else "") }}
+    {%- if message['content'] is string -%}
+        {{ message['content'] | trim }}
+    {%- elif message['content'] is iterable -%}
+        {%- for item in message['content'] -%}
+            {%- if item['type'] == 'audio' -%}
+                {{ '<audio_soft_token>' }}
+            {%- elif item['type'] == 'image' -%}
+                {{ '<image_soft_token>' }}
+            {%- elif item['type'] == 'text' -%}
+                {{ item['text'] | trim }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{ raise_exception("Invalid content type") }}
+    {%- endif -%}
+    {{ '<end_of_turn>
+' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{'<start_of_turn>model
+'}}
+{%- endif -%}
diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
index 8b9d728d5..0075d4830 100644
--- a/src/axolotl/utils/collators/mm_chat.py
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -84,6 +84,17 @@ class MultiModalChatDataCollator(DataCollatorMixin):
             "attention_mask": attention_mask,
         }
 
+        for key, val in batch.items():
+            if key in ["input_ids", "attention_mask"]:
+                continue
+
+            if key in ["token_type_ids", "cross_attention_mask"]:
+                final_batch[key] = torch.nn.utils.rnn.pad_sequence(
+                    val, batch_first=True, padding_value=0
+                )
+            else:
+                final_batch[key] = torch.stack(val)
+
         # Process the labels
         final_batch["labels"] = self.processing_strategy.process_labels(
             final_batch["input_ids"]
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 67fc7a8a7..3c8828396 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -62,6 +62,7 @@ class ChatTemplate(str, Enum):
     llava = "llava"
     qwen2_vl = "qwen2_vl"
     gemma3 = "gemma3"
+    gemma3n = "gemma3n"
     command_a = "command_a"
     command_a_tool_use = "command_a_tool_use"
     command_a_rag = "command_a_rag"