diff --git a/examples/gemma3n/gemma-3n-e2b-qlora.yml b/examples/gemma3n/gemma-3n-e2b-qlora.yml index ffc1da736..09504e14c 100644 --- a/examples/gemma3n/gemma-3n-e2b-qlora.yml +++ b/examples/gemma3n/gemma-3n-e2b-qlora.yml @@ -1,7 +1,5 @@ base_model: google/gemma-3n-E2B-it -model_type: AutoModelForCausalLM -tokenizer_type: AutoTokenizer # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name @@ -18,8 +16,8 @@ load_in_4bit: true # - lm_head # - embed_tokens -# huggingface repo -# chat_template: gemma3 + +chat_template: gemma3n eot_tokens: - datasets: diff --git a/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml new file mode 100644 index 000000000..d915a60b6 --- /dev/null +++ b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml @@ -0,0 +1,76 @@ +base_model: google/gemma-3n-E2B-it +processor_type: AutoProcessor + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +cut_cross_entropy: true + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - model.language_model.* + # - lm_head + # - embed_tokens + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +chat_template: gemma3n +eot_tokens: + - +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +# flash_attention: true # Any attention impl does not work with gemma3n now + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py index c08518dd6..c340c414c 100644 --- a/src/axolotl/loaders/constants.py +++ b/src/axolotl/loaders/constants.py @@ -2,6 +2,7 @@ from transformers import ( Gemma3ForConditionalGeneration, + Gemma3nForConditionalGeneration, Llama4ForConditionalGeneration, LlavaForConditionalGeneration, Mistral3ForConditionalGeneration, @@ -18,4 +19,5 @@ MULTIMODAL_AUTO_MODEL_MAPPING = { "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, "mistral3": Mistral3ForConditionalGeneration, "gemma3": Gemma3ForConditionalGeneration, + "gemma3n": Gemma3nForConditionalGeneration, } diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index 080697400..f22e601d9 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -264,6 +264,26 @@ class Gemma3ProcessingStrategy(ProcessingStrategy): return labels +class Gemma3nProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Gemma3n""" + + def process_labels(self, input_ids): + labels = input_ids.clone() + + # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + if hasattr(self.processor.tokenizer, "image_token_id"): + labels[labels == self.processor.tokenizer.image_token_id] = -100 + if hasattr(self.processor.tokenizer, "audio_token_id"): + labels[labels == self.processor.tokenizer.audio_token_id] = -100 + if hasattr(self.processor.tokenizer, "boi_token_id"): + labels[labels == self.processor.tokenizer.boi_token_id] = -100 + if hasattr(self.processor.tokenizer, "eoi_token_id"): + labels[labels == self.processor.tokenizer.eoi_token_id] = -100 + + return labels + + def get_processing_strategy( processor: ProcessorMixin, chat_template, @@ -279,6 +299,10 @@ def get_processing_strategy( return Gemma3ProcessingStrategy( processor, chat_template, image_size, image_resize_algorithm ) + if chat_template_type == "gemma3n": + return Gemma3nProcessingStrategy( + processor, chat_template, image_size, image_resize_algorithm + ) if chat_template_type in [ "llama3_2_vision", "llama4", diff --git a/src/axolotl/utils/chat_templates/templates/gemma3n.jinja b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja new file mode 100644 index 000000000..a0405ea9c --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja @@ -0,0 +1,49 @@ +{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + ' + +' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + ' + +' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + ' +' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'audio' -%} + {{ '' }} + {%- elif item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ ' +' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model +'}} +{%- endif -%} diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 67fc7a8a7..3c8828396 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -62,6 +62,7 @@ class ChatTemplate(str, Enum): llava = "llava" qwen2_vl = "qwen2_vl" gemma3 = "gemma3" + gemma3n = "gemma3n" command_a = "command_a" command_a_tool_use = "command_a_tool_use" command_a_rag = "command_a_rag"