fix: update attention class import logic for gemma3n model
This commit is contained in:
@@ -156,7 +156,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
model_cls_prefix = "".join(
|
||||
[part.capitalize() for part in model_type.split("_")]
|
||||
)
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||
if model_type == "gemma3n":
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||
else:
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||
|
||||
return attention_cls
|
||||
|
||||
Reference in New Issue
Block a user