fix: warn user to install mamba_ssm package (#1019)

This commit is contained in:
NanoCode012
2024-01-10 16:50:56 +09:00
committed by GitHub
parent 9e3f0cb5a7
commit d69ba2b0b7
4 changed files with 24 additions and 10 deletions

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[deepspeed,flash-attn]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -1,5 +1,5 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging packaging==23.2
peft==0.7.0 peft==0.7.0
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
tokenizers==0.15.0 tokenizers==0.15.0
@@ -34,6 +34,8 @@ fschat==0.2.34
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
mamba-ssm==1.1.1
# remote filesystems # remote filesystems
s3fs s3fs
gcsfs gcsfs

View File

@@ -11,17 +11,17 @@ def parse_requirements():
with open("./requirements.txt", encoding="utf-8") as requirements_file: with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()] lines = [r.strip() for r in requirements_file.readlines()]
for line in lines: for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
)
if line.startswith("--extra-index-url"): if line.startswith("--extra-index-url"):
# Handle custom index URLs # Handle custom index URLs
_, url = line.split() _, url = line.split()
_dependency_links.append(url) _dependency_links.append(url)
elif ( elif not is_extras and line and line[0] != "#":
"flash-attn" not in line
and "flash-attention" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
):
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)

View File

@@ -2,8 +2,20 @@
Modeling module for Mamba models Modeling module for Mamba models
""" """
import importlib
def check_mamba_ssm_installed():
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
if mamba_ssm_spec is None:
raise ImportError(
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
)
def fix_mamba_attn_for_loss(): def fix_mamba_attn_for_loss():
check_mamba_ssm_installed()
from mamba_ssm.models import mixer_seq_simple from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed