fix: correct attention class retrieval for gemma3n model in lora_kernels.py
This commit is contained in:
@@ -158,9 +158,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
)
|
)
|
||||||
if model_type == "gemma3n":
|
if model_type == "gemma3n":
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}TextAttention"])
|
||||||
|
attention_cls = getattr(module, f"{model_cls_prefix}TextAttention")
|
||||||
else:
|
else:
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||||
|
|
||||||
return attention_cls
|
return attention_cls
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user