From a9c0f4320228730d196a7fce24384fb35e25afa7 Mon Sep 17 00:00:00 2001
From: mhenrhcsen
Date: Fri, 27 Jun 2025 19:27:36 +0200
Subject: [PATCH] fix: update attention class import logic for gemma3n model

---
 src/axolotl/monkeypatch/lora_kernels.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 586412dd7..299902dc5 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -156,6 +156,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
         model_cls_prefix = "".join(
             [part.capitalize() for part in model_type.split("_")]
         )
-        module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
-        attention_cls = getattr(module, f"{model_cls_prefix}Attention")
+        # gemma3n's attention class is Gemma3nTextAttention, not Gemma3nAttention.
+        if model_type == "gemma3n":
+            attention_cls_name = f"{model_cls_prefix}TextAttention"
+        else:
+            attention_cls_name = f"{model_cls_prefix}Attention"
+        module = __import__(module_path, fromlist=[attention_cls_name])
+        attention_cls = getattr(module, attention_cls_name)
         return attention_cls