fix: correct attention class retrieval for gemma3n model in lora_kernels.py

This commit is contained in:
mhenrhcsen
2025-06-27 19:30:09 +02:00
parent a9c0f43202
commit 8eba033dc4

View File

@@ -158,9 +158,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
)
if model_type == "gemma3n":
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e: