fix: warn user to install mamba_ssm package (#1019)
This commit is contained in:
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging
|
packaging==23.2
|
||||||
peft==0.7.0
|
peft==0.7.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
@@ -34,6 +34,8 @@ fschat==0.2.34
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
|
mamba-ssm==1.1.1
|
||||||
|
|
||||||
# remote filesystems
|
# remote filesystems
|
||||||
s3fs
|
s3fs
|
||||||
gcsfs
|
gcsfs
|
||||||
|
|||||||
14
setup.py
14
setup.py
@@ -11,17 +11,17 @@ def parse_requirements():
|
|||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
is_extras = (
|
||||||
|
"flash-attn" in line
|
||||||
|
or "flash-attention" in line
|
||||||
|
or "deepspeed" in line
|
||||||
|
or "mamba-ssm" in line
|
||||||
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
_, url = line.split()
|
_, url = line.split()
|
||||||
_dependency_links.append(url)
|
_dependency_links.append(url)
|
||||||
elif (
|
elif not is_extras and line and line[0] != "#":
|
||||||
"flash-attn" not in line
|
|
||||||
and "flash-attention" not in line
|
|
||||||
and "deepspeed" not in line
|
|
||||||
and line
|
|
||||||
and line[0] != "#"
|
|
||||||
):
|
|
||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,20 @@
|
|||||||
Modeling module for Mamba models
|
Modeling module for Mamba models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
|
def check_mamba_ssm_installed() -> None:
    """Fail early with an actionable message when ``mamba_ssm`` is missing.

    Looks up the import spec of the top-level ``mamba_ssm`` package without
    importing it.

    Raises:
        ImportError: if no spec for ``mamba_ssm`` can be found.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` is bound as an attribute, so
    # ``importlib.util.find_spec`` could otherwise fail with AttributeError.
    import importlib.util

    if importlib.util.find_spec("mamba_ssm") is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
        )
|
||||||
|
|
||||||
|
|
||||||
def fix_mamba_attn_for_loss():
|
def fix_mamba_attn_for_loss():
|
||||||
|
check_mamba_ssm_installed()
|
||||||
|
|
||||||
from mamba_ssm.models import mixer_seq_simple
|
from mamba_ssm.models import mixer_seq_simple
|
||||||
|
|
||||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||||
|
|||||||
Reference in New Issue
Block a user