diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 299902dc5..0eae26488 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -158,9 +158,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: ) if model_type == "gemma3n": module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"]) + attention_cls = getattr(module, f"{model_cls_prefix}TextAttention") else: module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) - attention_cls = getattr(module, f"{model_cls_prefix}Attention") + attention_cls = getattr(module, f"{model_cls_prefix}Attention") return attention_cls except (ImportError, AttributeError) as e: