diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd
index a7555a0a3..56d56e9fc 100644
--- a/docs/lora_optims.qmd
+++ b/docs/lora_optims.qmd
@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
 - `qwen2`
 - `gemma`
 - `gemma2`
+- `gemma3`
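With `gemma3` on the supported list, the kernel flags this docs page already documents should now apply to gemma3 configs as well. A minimal config sketch, assuming the flag names from the same page; the model id is illustrative:

    base_model: google/gemma-3-4b-it   # illustrative gemma3 checkpoint
    adapter: lora
    lora_target_linear: true

    # the custom LoRA kernel optimizations documented on this page
    lora_mlp_kernel: true
    lora_qkv_kernel: true
    lora_o_kernel: true
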
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index d59fd22c9..96cfb1b69 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
         LOG.setLevel(logging.INFO)
 
     # Choose activation based on model type
-    activation = model.config.hidden_act
+    activation = None
+    text_config = (
+        model.config.get_text_config()
+        if hasattr(model.config, "get_text_config")
+        else model.config
+    )
+    if hasattr(text_config, "hidden_act"):
+        activation = text_config.hidden_act
+    elif hasattr(text_config, "hidden_activation"):
+        activation = text_config.hidden_activation
+
+    # map activation to a supported activation
+    if activation and "gelu" in activation:
+        # gemma3 uses gelu_pytorch_tanh
+        activation = "gelu"
+
     if activation not in SUPPORTED_ACTIVATIONS:
         raise NotImplementedError(f"Activation {activation} is not supported")
 
+    layers = []
+    # check for multimodal models first
+    if hasattr(model, "language_model"):
+        layers = model.language_model.model.layers
+    elif hasattr(model, "model"):
+        layers = model.model.model.layers
+    else:
+        raise NotImplementedError(
+            f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
+        )
+
     # Patch each layer
-    for layer in model.model.model.layers:
+    for layer in layers:
         # Add QKV, O fallback implementations to start
         # These will be overwritten later (if some conditions apply)
         layer.self_attn.apply_qkv = types.MethodType(
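To illustrate the config-resolution path above, a minimal standalone sketch, assuming a transformers version that ships `Gemma3Config` and `PretrainedConfig.get_text_config()` (roughly 4.50+):

    from transformers import Gemma3Config

    # Multimodal configs nest the language-model settings in a sub-config;
    # get_text_config() returns it, older configs fall back to themselves.
    config = Gemma3Config()
    text_config = (
        config.get_text_config()
        if hasattr(config, "get_text_config")
        else config
    )

    # Same attribute probing as the patch: gemma-family configs use
    # hidden_activation instead of hidden_act.
    activation = None
    if hasattr(text_config, "hidden_act"):
        activation = text_config.hidden_act
    elif hasattr(text_config, "hidden_activation"):
        activation = text_config.hidden_activation

    # gemma3 reports "gelu_pytorch_tanh"; the kernels only distinguish the
    # gelu family, so it is folded into plain "gelu"
    if activation and "gelu" in activation:
        activation = "gelu"

    print(activation)  # expected: "gelu"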
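The layer lookup can likewise be exercised without loading real weights; a toy sketch with stand-in classes (the class names are hypothetical):

    class _Stack:
        """Decoder stack, e.g. what LlamaModel holds."""
        layers = ["layer0", "layer1"]

    class _CausalLM:
        """e.g. LlamaForCausalLM: wraps the stack as .model"""
        model = _Stack()

    class _PeftText:
        """PeftModel around a text-only model: .model.model.layers"""
        model = _CausalLM()

    class _PeftMultimodal:
        """PeftModel around a multimodal model: .language_model.model.layers"""
        class _LM:
            model = _Stack()
        language_model = _LM()

    def find_layers(model):
        # multimodal first: PEFT-wrapped multimodal models also expose .model,
        # so the .language_model check has to win
        if hasattr(model, "language_model"):
            return model.language_model.model.layers
        if hasattr(model, "model"):
            return model.model.model.layers
        raise NotImplementedError("unsupported model structure")

    print(find_layers(_PeftMultimodal()))  # ['layer0', 'layer1']
    print(find_layers(_PeftText()))        # ['layer0', 'layer1']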