fix: correct attention class retrieval for gemma3n model in lora_kernels.py
This commit is contained in:
@@ -158,6 +158,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
)
|
)
|
||||||
if model_type == "gemma3n":
|
if model_type == "gemma3n":
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||||
|
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
|
||||||
else:
|
else:
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||||
|
|||||||
Reference in New Issue
Block a user