Compare commits

...

2 Commits

Author SHA1 Message Date
mhenrhcsen
8eba033dc4 fix: correct attention class retrieval for gemma3n model in lora_kernels.py 2025-06-27 19:30:09 +02:00
mhenrhcsen
a9c0f43202 fix: update attention class import logic for gemma3n model 2025-06-27 19:27:36 +02:00

View File

@@ -156,8 +156,12 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
if model_type == "gemma3n":
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e: