fix setup.py to use extra index url

install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd
This commit is contained in:
Wing Lian
2023-08-23 22:08:11 -04:00
parent caa80e891d
commit 588cd65a64
4 changed files with 31 additions and 12 deletions

View File

@@ -24,6 +24,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e . pip install -e .
pip install -r requirements-tests.txt pip install -r requirements-tests.txt

View File

@@ -15,9 +15,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[flash-attn]; \ pip install -e .[flash-attn,gptq]; \
fi fi
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works

View File

@@ -1,4 +1,6 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/ --extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq auto-gptq
packaging packaging
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git

View File

@@ -2,15 +2,27 @@
from setuptools import find_packages, setup from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file: def parse_requirements():
# don't include peft yet until we check the int4 _install_requires = []
# need to manually install peft for now... _dependency_links = []
reqs = [r.strip() for r in requirements_file.readlines()] with open("./requirements.txt", encoding="utf-8") as requirements_file:
reqs = [r for r in reqs if "flash-attn" not in r] lines = [
reqs = [r for r in reqs if r and r[0] != "#"] r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
for r in reqs: ]
install_requires.append(r) for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
@@ -19,7 +31,11 @@ setup(
package_dir={"": "src"}, package_dir={"": "src"},
packages=find_packages(), packages=find_packages(),
install_requires=install_requires, install_requires=install_requires,
dependency_links=dependency_links,
extras_require={ extras_require={
"gptq": [
"auto-gptq",
],
"flash-attn": [ "flash-attn": [
"flash-attn==2.0.8", "flash-attn==2.0.8",
], ],