Revert "feat: move to uv first" (#3544)
This reverts commit 1f1ebb8237.
This commit is contained in:
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -6,7 +6,7 @@ on:
|
|||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'pyproject.toml'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- "*.[q]md"
|
- "*.[q]md"
|
||||||
- "examples/**/*.y[a]?ml"
|
- "examples/**/*.y[a]?ml"
|
||||||
|
|||||||
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -4,7 +4,8 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/**.py'
|
- 'tests/e2e/multigpu/**.py'
|
||||||
- 'pyproject.toml'
|
- 'requirements.txt'
|
||||||
|
- 'setup.py'
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
- '.github/workflows/multi-gpu-e2e.yml'
|
- '.github/workflows/multi-gpu-e2e.yml'
|
||||||
- 'scripts/cutcrossentropy_install.py'
|
- 'scripts/cutcrossentropy_install.py'
|
||||||
|
|||||||
17
.github/workflows/tests-nightly.yml
vendored
17
.github/workflows/tests-nightly.yml
vendored
@@ -72,6 +72,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
|
||||||
|
- name: Update requirements.txt
|
||||||
|
run: |
|
||||||
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||||
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||||
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||||
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||||
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 show torch
|
||||||
@@ -80,15 +88,6 @@ jobs:
|
|||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Override with nightly HF packages
|
|
||||||
run: |
|
|
||||||
pip3 install --no-deps \
|
|
||||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
|
||||||
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
|
||||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
|
||||||
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
|
||||||
"datasets @ git+https://github.com/huggingface/datasets.git@main"
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'pyproject.toml'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- 'requirements-tests.txt'
|
- 'requirements-tests.txt'
|
||||||
- 'cicd/cicd.sh'
|
- 'cicd/cicd.sh'
|
||||||
@@ -16,7 +16,7 @@ on:
|
|||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'pyproject.toml'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- 'requirements-tests.txt'
|
- 'requirements-tests.txt'
|
||||||
- 'cicd/cicd.sh'
|
- 'cicd/cicd.sh'
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
include requirements.txt
|
||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include VERSION
|
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||||
|
recursive-include axolotl *.py
|
||||||
|
|||||||
178
pyproject.toml
178
pyproject.toml
@@ -1,143 +1,15 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "axolotl"
|
name = "axolotl"
|
||||||
dynamic = ["version"]
|
dynamic = ["version", "dependencies", "optional-dependencies"]
|
||||||
description = "LLM Trainer"
|
description = "LLM Trainer"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
# license = "Apache-2.0"
|
# license = "Apache-2.0"
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
# Core ML stack
|
|
||||||
"torch>=2.6.0",
|
|
||||||
"packaging==26.0",
|
|
||||||
"huggingface_hub>=1.1.7",
|
|
||||||
"peft>=0.18.1",
|
|
||||||
"tokenizers>=0.22.1",
|
|
||||||
"transformers==5.3.0",
|
|
||||||
"accelerate==1.13.0",
|
|
||||||
"datasets==4.5.0",
|
|
||||||
"trl==0.29.0",
|
|
||||||
"hf_xet==1.3.2",
|
|
||||||
"kernels==0.12.2",
|
|
||||||
"trackio>=0.16.1",
|
|
||||||
"typing-extensions>=4.15.0",
|
|
||||||
"optimum==1.16.2",
|
|
||||||
"hf_transfer",
|
|
||||||
"sentencepiece",
|
|
||||||
"gradio>=6.2.0,<7.0",
|
|
||||||
"modal==1.3.0.post1",
|
|
||||||
"pydantic>=2.10.6",
|
|
||||||
"addict",
|
|
||||||
"fire",
|
|
||||||
"PyYAML>=6.0",
|
|
||||||
"requests",
|
|
||||||
"wandb",
|
|
||||||
"einops",
|
|
||||||
"colorama",
|
|
||||||
"numba>=0.61.2",
|
|
||||||
"numpy>=2.2.6",
|
|
||||||
|
|
||||||
# Evaluation & metrics
|
|
||||||
"evaluate==0.4.1",
|
|
||||||
"scipy",
|
|
||||||
"nvidia-ml-py==12.560.30",
|
|
||||||
"art",
|
|
||||||
"tensorboard",
|
|
||||||
"python-dotenv==1.0.1",
|
|
||||||
|
|
||||||
# Remote filesystems
|
|
||||||
"s3fs>=2024.5.0",
|
|
||||||
"gcsfs>=2025.3.0",
|
|
||||||
"adlfs>=2024.5.0",
|
|
||||||
"ocifs==1.3.2",
|
|
||||||
|
|
||||||
"zstandard==0.22.0",
|
|
||||||
"fastcore",
|
|
||||||
|
|
||||||
# lm eval harness
|
|
||||||
"lm_eval==0.4.7",
|
|
||||||
"langdetect==1.0.9",
|
|
||||||
"immutabledict==4.2.0",
|
|
||||||
"antlr4-python3-runtime==4.13.2",
|
|
||||||
|
|
||||||
"schedulefree==1.4.1",
|
|
||||||
"openenv-core==0.1.0",
|
|
||||||
|
|
||||||
# Axolotl contribs
|
|
||||||
"axolotl-contribs-lgpl==0.0.7",
|
|
||||||
"axolotl-contribs-mit==0.0.6",
|
|
||||||
|
|
||||||
# Telemetry
|
|
||||||
"posthog==6.7.11",
|
|
||||||
|
|
||||||
"mistral-common==1.10.0",
|
|
||||||
|
|
||||||
# Platform-specific (Linux only)
|
|
||||||
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
|
||||||
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
|
||||||
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
|
|
||||||
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
|
||||||
"torchao==0.16.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
|
||||||
|
|
||||||
# Architecture-specific
|
|
||||||
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
|
|
||||||
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
flash-attn = ["flash-attn==2.8.3"]
|
|
||||||
ring-flash-attn = [
|
|
||||||
"flash-attn==2.8.3",
|
|
||||||
"ring-flash-attn>=0.1.7",
|
|
||||||
]
|
|
||||||
deepspeed = [
|
|
||||||
"deepspeed>=0.18.6,<0.19.0",
|
|
||||||
"deepspeed-kernels",
|
|
||||||
]
|
|
||||||
mamba-ssm = [
|
|
||||||
"mamba-ssm==1.2.0.post1",
|
|
||||||
"causal_conv1d",
|
|
||||||
]
|
|
||||||
auto-gptq = [
|
|
||||||
"auto-gptq==0.5.1",
|
|
||||||
]
|
|
||||||
mlflow = [
|
|
||||||
"mlflow",
|
|
||||||
]
|
|
||||||
galore = [
|
|
||||||
"galore_torch",
|
|
||||||
]
|
|
||||||
apollo = [
|
|
||||||
"apollo-torch",
|
|
||||||
]
|
|
||||||
optimizers = [
|
|
||||||
"galore_torch",
|
|
||||||
"apollo-torch",
|
|
||||||
"lomo-optim==0.1.1",
|
|
||||||
"torch-optimi==0.2.1",
|
|
||||||
"came_pytorch==0.1.3",
|
|
||||||
]
|
|
||||||
ray = [
|
|
||||||
"ray[train]>=2.52.1",
|
|
||||||
]
|
|
||||||
vllm = [
|
|
||||||
"vllm>=0.10.0",
|
|
||||||
]
|
|
||||||
llmcompressor = [
|
|
||||||
"llmcompressor>=0.10.0",
|
|
||||||
]
|
|
||||||
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
|
|
||||||
opentelemetry = [
|
|
||||||
"opentelemetry-api",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"opentelemetry-exporter-prometheus",
|
|
||||||
"prometheus-client",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
axolotl = "axolotl.cli.main:main"
|
axolotl = "axolotl.cli.main:main"
|
||||||
|
|
||||||
@@ -146,15 +18,18 @@ Homepage = "https://axolotl.ai/"
|
|||||||
Documentation = "https://docs.axolotl.ai/"
|
Documentation = "https://docs.axolotl.ai/"
|
||||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools_scm]
|
||||||
include-package-data = true
|
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools]
|
||||||
where = ["src"]
|
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||||
|
include-package-data = true
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
version = { file = "VERSION" }
|
version = { file = "VERSION" }
|
||||||
|
|
||||||
|
[tool.setuptools.cmdclass]
|
||||||
|
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
@@ -192,40 +67,5 @@ markers = [
|
|||||||
"slow: marks tests as slow",
|
"slow: marks tests as slow",
|
||||||
]
|
]
|
||||||
|
|
||||||
# UV specific configuration
|
|
||||||
[tool.uv]
|
|
||||||
prerelease = "allow"
|
|
||||||
conflicts = [
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "vllm" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "flash-attn" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "ring-flash-attn" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "mamba-ssm" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "auto-gptq" },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
{ package = "axolotl" },
|
|
||||||
{ extra = "fbgemm-gpu" },
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
axolotl = ["huggingface_hub"]
|
||||||
mamba-ssm = ["torch"]
|
|
||||||
causal-conv1d = ["torch"]
|
|
||||||
flash-attn = ["torch"]
|
|
||||||
deepspeed = ["torch"]
|
|
||||||
auto-gptq = ["torch"]
|
|
||||||
|
|||||||
78
requirements.txt
Normal file
78
requirements.txt
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
|
bitsandbytes==0.49.1
|
||||||
|
triton>=3.4.0
|
||||||
|
mamba-ssm==1.2.0.post1
|
||||||
|
xformers>=0.0.23.post1
|
||||||
|
liger-kernel==0.7.0
|
||||||
|
# END section
|
||||||
|
|
||||||
|
packaging==26.0
|
||||||
|
huggingface_hub>=1.1.7
|
||||||
|
peft>=0.18.1
|
||||||
|
tokenizers>=0.22.1
|
||||||
|
transformers==5.3.0
|
||||||
|
accelerate==1.13.0
|
||||||
|
datasets==4.5.0
|
||||||
|
deepspeed>=0.18.6,<0.19.0
|
||||||
|
trl==0.29.0
|
||||||
|
hf_xet==1.3.2
|
||||||
|
kernels==0.12.2
|
||||||
|
|
||||||
|
fla-core==0.4.1
|
||||||
|
flash-linear-attention==0.4.1
|
||||||
|
|
||||||
|
trackio>=0.16.1
|
||||||
|
typing-extensions>=4.15.0
|
||||||
|
|
||||||
|
optimum==1.16.2
|
||||||
|
hf_transfer
|
||||||
|
sentencepiece
|
||||||
|
gradio>=6.2.0,<7.0
|
||||||
|
|
||||||
|
modal==1.3.0.post1
|
||||||
|
pydantic>=2.10.6
|
||||||
|
addict
|
||||||
|
fire
|
||||||
|
PyYAML>=6.0
|
||||||
|
requests
|
||||||
|
wandb
|
||||||
|
einops
|
||||||
|
colorama
|
||||||
|
numba>=0.61.2
|
||||||
|
numpy>=2.2.6
|
||||||
|
|
||||||
|
# qlora things
|
||||||
|
evaluate==0.4.1
|
||||||
|
scipy
|
||||||
|
nvidia-ml-py==12.560.30
|
||||||
|
art
|
||||||
|
tensorboard
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
|
||||||
|
# remote filesystems
|
||||||
|
s3fs>=2024.5.0
|
||||||
|
gcsfs>=2025.3.0
|
||||||
|
adlfs>=2024.5.0
|
||||||
|
ocifs==1.3.2
|
||||||
|
|
||||||
|
zstandard==0.22.0
|
||||||
|
fastcore
|
||||||
|
|
||||||
|
# lm eval harness
|
||||||
|
lm_eval==0.4.7
|
||||||
|
langdetect==1.0.9
|
||||||
|
immutabledict==4.2.0
|
||||||
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
|
torchao==0.16.0
|
||||||
|
openenv-core==0.1.0
|
||||||
|
schedulefree==1.4.1
|
||||||
|
|
||||||
|
axolotl-contribs-lgpl==0.0.7
|
||||||
|
axolotl-contribs-mit==0.0.6
|
||||||
|
# telemetry
|
||||||
|
posthog==6.7.11
|
||||||
|
|
||||||
|
mistral-common==1.10.0
|
||||||
230
setup.py
Normal file
230
setup.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
def parse_requirements(extras_require_map):
|
||||||
|
_install_requires = []
|
||||||
|
_dependency_links = []
|
||||||
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
|
for line in lines:
|
||||||
|
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||||
|
if line.startswith("--extra-index-url"):
|
||||||
|
# Handle custom index URLs
|
||||||
|
_, url = line.split()
|
||||||
|
_dependency_links.append(url)
|
||||||
|
elif not is_extras and line and line[0] != "#":
|
||||||
|
# Handle standard packages
|
||||||
|
_install_requires.append(line)
|
||||||
|
try:
|
||||||
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
|
install_xformers = platform.machine() != "aarch64"
|
||||||
|
if platform.machine() == "aarch64":
|
||||||
|
# skip on ARM64
|
||||||
|
skip_packages = [
|
||||||
|
"torchao",
|
||||||
|
"fla-core",
|
||||||
|
"flash-linear-attention",
|
||||||
|
]
|
||||||
|
_install_requires = [
|
||||||
|
req
|
||||||
|
for req in _install_requires
|
||||||
|
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||||
|
]
|
||||||
|
if "Darwin" in platform.system():
|
||||||
|
# skip packages not compatible with OSX
|
||||||
|
skip_packages = [
|
||||||
|
"bitsandbytes",
|
||||||
|
"triton",
|
||||||
|
"mamba-ssm",
|
||||||
|
"xformers",
|
||||||
|
"liger-kernel",
|
||||||
|
]
|
||||||
|
_install_requires = [
|
||||||
|
req
|
||||||
|
for req in _install_requires
|
||||||
|
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||||
|
]
|
||||||
|
print(
|
||||||
|
_install_requires, [req in skip_packages for req in _install_requires]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# detect the version of torch already installed
|
||||||
|
# and set it so dependencies don't clobber the torch version
|
||||||
|
try:
|
||||||
|
torch_version = version("torch")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
torch_version = "2.8.0" # default to torch 2.8.0
|
||||||
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
if version_match:
|
||||||
|
major, minor, patch = version_match.groups()
|
||||||
|
major, minor = int(major), int(minor)
|
||||||
|
patch = (
|
||||||
|
int(patch) if patch is not None else 0
|
||||||
|
) # Default patch to 0 if not present
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
torch_parts = torch_version.split("+")
|
||||||
|
if len(torch_parts) == 2:
|
||||||
|
torch_cuda_version = torch_parts[1]
|
||||||
|
_dependency_links.append(
|
||||||
|
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (major, minor) >= (2, 10):
|
||||||
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
|
"fbgemm-gpu==1.5.0",
|
||||||
|
"fbgemm-gpu-genai==1.5.0",
|
||||||
|
]
|
||||||
|
if not install_xformers:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
extras_require_map["vllm"] = ["vllm>=0.17.1"]
|
||||||
|
elif (major, minor) >= (2, 9):
|
||||||
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
|
"fbgemm-gpu==1.4.0",
|
||||||
|
"fbgemm-gpu-genai==1.4.2",
|
||||||
|
]
|
||||||
|
if not install_xformers:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||||
|
else:
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.14.0"]
|
||||||
|
elif (major, minor) >= (2, 8):
|
||||||
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
|
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||||
|
if not install_xformers:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
elif (major, minor) >= (2, 7):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
if install_xformers:
|
||||||
|
_install_requires.append("xformers==0.0.30")
|
||||||
|
# vllm 0.9.x is incompatible with latest transformers
|
||||||
|
extras_require_map.pop("vllm")
|
||||||
|
else:
|
||||||
|
if install_xformers:
|
||||||
|
_install_requires.append("xformers==0.0.31")
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||||
|
elif (major, minor) >= (2, 6):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if install_xformers:
|
||||||
|
_install_requires.append("xformers==0.0.29.post3")
|
||||||
|
# since we only support 2.6.0+cu126
|
||||||
|
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||||
|
extras_require_map.pop("vllm")
|
||||||
|
elif (major, minor) >= (2, 5):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if install_xformers:
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers>=0.0.28.post3")
|
||||||
|
extras_require_map.pop("vllm")
|
||||||
|
elif (major, minor) >= (2, 4):
|
||||||
|
extras_require_map.pop("vllm")
|
||||||
|
if install_xformers:
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers==0.0.28.post1")
|
||||||
|
else:
|
||||||
|
raise ValueError("axolotl requires torch>=2.4")
|
||||||
|
|
||||||
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
return _install_requires, _dependency_links, extras_require_map
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_version():
|
||||||
|
with open(
|
||||||
|
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
||||||
|
"r",
|
||||||
|
encoding="utf-8",
|
||||||
|
) as fin:
|
||||||
|
version_ = fin.read().strip()
|
||||||
|
return version_
|
||||||
|
|
||||||
|
|
||||||
|
extras_require = {
|
||||||
|
"flash-attn": ["flash-attn==2.8.3"],
|
||||||
|
"ring-flash-attn": [
|
||||||
|
"flash-attn==2.8.3",
|
||||||
|
"ring-flash-attn>=0.1.7",
|
||||||
|
],
|
||||||
|
"deepspeed": [
|
||||||
|
"deepspeed==0.18.2",
|
||||||
|
"deepspeed-kernels",
|
||||||
|
],
|
||||||
|
"mamba-ssm": [
|
||||||
|
"mamba-ssm==1.2.0.post1",
|
||||||
|
"causal_conv1d",
|
||||||
|
],
|
||||||
|
"auto-gptq": [
|
||||||
|
"auto-gptq==0.5.1",
|
||||||
|
],
|
||||||
|
"mlflow": [
|
||||||
|
"mlflow",
|
||||||
|
],
|
||||||
|
"galore": [
|
||||||
|
"galore_torch",
|
||||||
|
],
|
||||||
|
"apollo": [
|
||||||
|
"apollo-torch",
|
||||||
|
],
|
||||||
|
"optimizers": [
|
||||||
|
"galore_torch",
|
||||||
|
"apollo-torch",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
"came_pytorch==0.1.3",
|
||||||
|
],
|
||||||
|
"ray": [
|
||||||
|
"ray[train]>=2.52.1",
|
||||||
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm==0.10.0",
|
||||||
|
],
|
||||||
|
"llmcompressor": [
|
||||||
|
"llmcompressor==0.5.1",
|
||||||
|
],
|
||||||
|
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
||||||
|
"opentelemetry": [
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"opentelemetry-exporter-prometheus",
|
||||||
|
"prometheus-client",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
|
extras_require
|
||||||
|
)
|
||||||
|
|
||||||
|
setup(
|
||||||
|
version=get_package_version(),
|
||||||
|
package_dir={"": "src"},
|
||||||
|
packages=find_packages("src"),
|
||||||
|
install_requires=install_requires,
|
||||||
|
dependency_links=dependency_links,
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"axolotl=axolotl.cli.main:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
extras_require=extras_require_build,
|
||||||
|
)
|
||||||
102
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
102
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
dynamic requirements for axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
|
from setuptools.command.build_py import build_py as _build_py
|
||||||
|
|
||||||
|
|
||||||
|
def parse_requirements():
|
||||||
|
_install_requires = []
|
||||||
|
_dependency_links = []
|
||||||
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
|
for line in lines:
|
||||||
|
is_extras = (
|
||||||
|
"flash-attn" in line
|
||||||
|
or "flash-attention" in line
|
||||||
|
or "deepspeed" in line
|
||||||
|
or "mamba-ssm" in line
|
||||||
|
or "lion-pytorch" in line
|
||||||
|
)
|
||||||
|
if line.startswith("--extra-index-url"):
|
||||||
|
# Handle custom index URLs
|
||||||
|
_, url = line.split()
|
||||||
|
_dependency_links.append(url)
|
||||||
|
elif not is_extras and line and line[0] != "#":
|
||||||
|
# Handle standard packages
|
||||||
|
_install_requires.append(line)
|
||||||
|
|
||||||
|
try:
|
||||||
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
|
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||||
|
|
||||||
|
if "Darwin" in platform.system():
|
||||||
|
# don't install xformers on MacOS
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
else:
|
||||||
|
# detect the version of torch already installed
|
||||||
|
# and set it so dependencies don't clobber the torch version
|
||||||
|
try:
|
||||||
|
torch_version = version("torch")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
torch_version = "2.5.1"
|
||||||
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
if version_match:
|
||||||
|
major, minor, patch = version_match.groups()
|
||||||
|
major, minor = int(major), int(minor)
|
||||||
|
patch = (
|
||||||
|
int(patch) if patch is not None else 0
|
||||||
|
) # Default patch to 0 if not present
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
if (major, minor) >= (2, 5):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.28.post3")
|
||||||
|
elif (major, minor) >= (2, 4):
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers==0.0.28.post1")
|
||||||
|
elif (major, minor) >= (2, 3):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
|
elif (major, minor) >= (2, 2):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|
||||||
|
class BuildPyCommand(_build_py):
|
||||||
|
"""
|
||||||
|
custom build_py command to parse dynamic requirements
|
||||||
|
"""
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
super().finalize_options()
|
||||||
|
install_requires, _ = parse_requirements()
|
||||||
|
self.distribution.install_requires = install_requires
|
||||||
Reference in New Issue
Block a user