diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e679f4101..130ac6e7b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -49,14 +49,18 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: upgrade pip + run: | + pip3 install --upgrade pip + pip3 install --upgrade packaging setuptools wheel + - name: Install PyTorch run: | - pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + pip3 install torch==${{ matrix.pytorch_version }} - name: Install dependencies run: | - pip3 install --upgrade pip - pip3 install --upgrade packaging + pip3 show torch pip3 install -U -e . pip3 install -r requirements-tests.txt diff --git a/requirements.txt b/requirements.txt index 8f9f55262..067be05cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.28.post1 +xformers>=0.0.23.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1153d6968..17347f063 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,8 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0] + autoawq_version = [req for req in _install_requires if "autoawq" in req][0] + if "Darwin" in platform.system(): # don't install xformers on MacOS _install_requires.pop(_install_requires.index(xformers_version)) @@ -52,10 +54,14 @@ def parse_requirements(): if (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.pop(_install_requires.index(autoawq_version)) elif (major, minor) >= (2, 4): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers==0.0.28.post1") elif (major, minor) >= (2, 3): _install_requires.pop(_install_requires.index(torchao_version)) if patch == 0: @@ -75,7 +81,6 @@ def parse_requirements(): except PackageNotFoundError: pass - return _install_requires, _dependency_links