diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d5184def6..8617ca33a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,6 +24,7 @@ jobs: - name: Install dependencies run: | + pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118 pip install -e . pip install -r requirements-tests.txt diff --git a/docker/Dockerfile b/docker/Dockerfile index 8608e2348..683ca75ff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -15,9 +15,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git # If AXOLOTL_EXTRAS is set, append it in brackets RUN cd axolotl && \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ + pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \ else \ - pip install -e .[flash-attn]; \ + pip install -e .[flash-attn,gptq]; \ fi # fix so that git fetch/pull from remote works diff --git a/requirements.txt b/requirements.txt index 6e91c1428..7bd5cea9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ ---extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/ +--extra-index-url https://download.pytorch.org/whl/cu118 +--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ +torch==2.0.1 auto-gptq packaging peft @ git+https://github.com/huggingface/peft.git diff --git a/setup.py b/setup.py index 100de39b7..973d656cd 100644 --- a/setup.py +++ b/setup.py @@ -2,15 +2,27 @@ from setuptools import find_packages, setup -install_requires = [] -with open("./requirements.txt", encoding="utf-8") as requirements_file: - # don't include peft yet until we check the int4 - # need to manually install peft for now... - reqs = [r.strip() for r in requirements_file.readlines()] - reqs = [r for r in reqs if "flash-attn" not in r] - reqs = [r for r in reqs if r and r[0] != "#"] - for r in reqs: - install_requires.append(r) + +def parse_requirements(): + _install_requires = [] + _dependency_links = [] + with open("./requirements.txt", encoding="utf-8") as requirements_file: + lines = [ + r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r + ] + for line in lines: + if line.startswith("--extra-index-url"): + # Handle custom index URLs + _, url = line.split() + _dependency_links.append(url) + elif "flash-attn" not in line and line and line[0] != "#": + # Handle standard packages + _install_requires.append(line) + return _install_requires, _dependency_links + + +install_requires, dependency_links = parse_requirements() + setup( name="axolotl", @@ -19,7 +31,11 @@ setup( package_dir={"": "src"}, packages=find_packages(), install_requires=install_requires, + dependency_links=dependency_links, extras_require={ + "gptq": [ + "auto-gptq", + ], "flash-attn": [ "flash-attn==2.0.8", ],