fix: correct attention class retrieval for gemma3n model in lora_kernels.py
This commit is contained in:
@@ -158,9 +158,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
)
|
||||
if model_type == "gemma3n":
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
|
||||
else:
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
|
||||
Reference in New Issue
Block a user