fix: warn user to install mamba_ssm package (#1019)

This commit is contained in:
NanoCode012
2024-01-10 16:50:56 +09:00
committed by GitHub
parent 9e3f0cb5a7
commit d69ba2b0b7
4 changed files with 24 additions and 10 deletions

View File

@@ -2,8 +2,20 @@
Modeling module for Mamba models
"""
import importlib
def check_mamba_ssm_installed():
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
if mamba_ssm_spec is None:
raise ImportError(
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
)
def fix_mamba_attn_for_loss():
check_mamba_ssm_installed()
from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed