fix: update attention class import logic for gemma3n model

This commit is contained in:
mhenrhcsen
2025-06-27 19:27:36 +02:00
parent a1a740608d
commit a9c0f43202

View File

@@ -156,7 +156,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
if model_type == "gemma3n":
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
else:
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls