diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 586412dd7..299902dc5 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -156,7 +156,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: model_cls_prefix = "".join( [part.capitalize() for part in model_type.split("_")] ) - module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) + if model_type == "gemma3n": + module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"]) + else: + module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) attention_cls = getattr(module, f"{model_cls_prefix}Attention") return attention_cls