From d69ba2b0b76fad112acecd5a1fbb339e6244ff7b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 10 Jan 2024 16:50:56 +0900 Subject: [PATCH] fix: warn user to install mamba_ssm package (#1019) --- docker/Dockerfile | 4 ++-- requirements.txt | 4 +++- setup.py | 14 +++++++------- src/axolotl/models/mamba/__init__.py | 12 ++++++++++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f8e052856..efc40ab06 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ + pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \ else \ - pip install -e .[deepspeed,flash-attn]; \ + pip install -e .[deepspeed,flash-attn,mamba-ssm]; \ fi # So we can test the Docker image diff --git a/requirements.txt b/requirements.txt index 33ddd395d..b2595de50 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ -packaging +packaging==23.2 peft==0.7.0 transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 tokenizers==0.15.0 @@ -34,6 +34,8 @@ fschat==0.2.34 gradio==3.50.2 tensorboard +mamba-ssm==1.1.1 + # remote filesystems s3fs gcsfs diff --git a/setup.py b/setup.py index 874f12608..235018dcc 100644 --- a/setup.py +++ b/setup.py @@ -11,17 +11,17 @@ def parse_requirements(): with open("./requirements.txt", encoding="utf-8") as requirements_file: lines = [r.strip() for r in requirements_file.readlines()] for line in lines: + is_extras = ( + "flash-attn" in line + or "flash-attention" in line + or "deepspeed" in line + or "mamba-ssm" in line + ) if line.startswith("--extra-index-url"): # Handle custom index URLs _, url = line.split() _dependency_links.append(url) - elif ( - "flash-attn" not in 
line - and "flash-attention" not in line - and "deepspeed" not in line - and line - and line[0] != "#" - ): + elif not is_extras and line and line[0] != "#": # Handle standard packages _install_requires.append(line) diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py index 247c1d184..fee88e3a4 100644 --- a/src/axolotl/models/mamba/__init__.py +++ b/src/axolotl/models/mamba/__init__.py @@ -2,8 +2,20 @@ Modeling module for Mamba models """ +import importlib.util + + +def check_mamba_ssm_installed(): + mamba_ssm_spec = importlib.util.find_spec("mamba_ssm") + if mamba_ssm_spec is None: + raise ImportError( + "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`" + ) + def fix_mamba_attn_for_loss(): + check_mamba_ssm_installed() + from mamba_ssm.models import mixer_seq_simple from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed