fix: update attention class import logic for gemma3n model
This commit is contained in:
@@ -156,7 +156,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
model_cls_prefix = "".join(
|
model_cls_prefix = "".join(
|
||||||
[part.capitalize() for part in model_type.split("_")]
|
[part.capitalize() for part in model_type.split("_")]
|
||||||
)
|
)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
if model_type == "gemma3n":
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||||
|
else:
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||||
|
|
||||||
return attention_cls
|
return attention_cls
|
||||||
|
|||||||
Reference in New Issue
Block a user