diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 7c4e75796..3f8116b21 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -64,6 +64,7 @@ class PatchManager:
         self._patch_llama_derived_model()
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
+        self._apply_gemma3_conditional_generation_forward_patch()
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -221,6 +222,15 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
+    def _apply_gemma3_conditional_generation_forward_patch(self):
+        """Apply gemma3 conditional generation forward patch."""
+        if self.model_config.model_type in ["gemma3", "gemma3_text"]:
+            from axolotl.monkeypatch.models.gemma3.modeling import (
+                patch_gemma3_conditional_generation_forward,
+            )
+
+            patch_gemma3_conditional_generation_forward()
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
diff --git a/src/axolotl/monkeypatch/models/gemma3/__init__.py b/src/axolotl/monkeypatch/models/gemma3/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/models/gemma3/modeling.py b/src/axolotl/monkeypatch/models/gemma3/modeling.py
new file mode 100644
index 000000000..3b608c347
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/gemma3/modeling.py
@@ -0,0 +1,32 @@
+"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""
+
+# Sentinel: distinguishes "attribute absent from the class dict" from any real value.
+_UNSET = object()
+
+
+def patch_gemma3_conditional_generation_forward():
+    """Force ``accepts_loss_kwargs = False`` on ``Gemma3ForConditionalGeneration``.
+
+    Remove when https://github.com/huggingface/transformers/pull/37208 is merged.
+
+    Returns:
+        A zero-argument ``unpatch`` callable that restores the pre-patch state.
+    """
+    from transformers.models.gemma3.modeling_gemma3 import (
+        Gemma3ForConditionalGeneration,
+    )
+
+    # Record what the class itself defined (if anything) so unpatch can restore it.
+    original = Gemma3ForConditionalGeneration.__dict__.get("accepts_loss_kwargs", _UNSET)
+
+    Gemma3ForConditionalGeneration.accepts_loss_kwargs = False
+
+    def unpatch():
+        # Restore the previous value instead of unconditionally deleting: a plain
+        # delattr would drop a pre-existing class-level attribute for good.
+        if original is _UNSET:
+            delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")
+        else:
+            Gemma3ForConditionalGeneration.accepts_loss_kwargs = original
+
+    return unpatch