Compare commits
8 Commits
v0.6.0
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ac9cbebb9 | ||
|
|
15f2fa4c8e | ||
|
|
43a2f9a155 | ||
|
|
8b79f1cbf6 | ||
|
|
3872d5eaed | ||
|
|
02629c7cdf | ||
|
|
78a4aa86d6 | ||
|
|
d009ead101 |
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install -e .
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
|
||||
11
.github/workflows/tests-nightly.yml
vendored
11
.github/workflows/tests-nightly.yml
vendored
@@ -44,6 +44,11 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
@@ -60,11 +65,15 @@ jobs:
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
16
.github/workflows/tests.yml
vendored
16
.github/workflows/tests.yml
vendored
@@ -78,11 +78,15 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install -U -e .
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
@@ -120,7 +124,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -129,12 +133,16 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python3 setup.py sdist
|
||||
pip3 install dist/axolotl*.tar.gz
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
recursive-include axolotl *.py
|
||||
|
||||
@@ -112,7 +112,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
||||
|
||||
```bash
|
||||
pip3 install axolotl[flash-attn,deepspeed]
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# download examples and optionally deepspeed configs to the local path
|
||||
axolotl fetch examples
|
||||
@@ -131,7 +131,7 @@ from source.
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install packaging ninja
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Axolotl CLI Usage
|
||||
@@ -320,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
3. Install Axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
```bash
|
||||
@@ -399,7 +399,7 @@ Please use WSL or Docker!
|
||||
|
||||
Use the below instead of the install method in QuickStart.
|
||||
```
|
||||
pip3 install -e '.'
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
```
|
||||
More info: [mac.md](/docs/mac.qmd)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
|
||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
|
||||
cd flash-attention
|
||||
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
||||
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
|
||||
pip install .
|
||||
pip install --no-build-isolation .
|
||||
```
|
||||
|
||||
### 6. Install Axolotl
|
||||
@@ -63,7 +63,7 @@ Clone and install Axolotl:
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||
cd axolotl
|
||||
pip install packaging ninja
|
||||
pip install -e .
|
||||
pip install --no-build-isolation -e .
|
||||
```
|
||||
|
||||
### 7. Apply xformers Workaround
|
||||
|
||||
@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
#### Remote Hosts
|
||||
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Attach To Container
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install axolotl[deepspeed]"
|
||||
"!pip install --no-build-isolation axolotl[deepspeed]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -17,3 +17,10 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
|
||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
@@ -13,5 +13,5 @@ cd /workspace
|
||||
rm -rf /workspace/axolotl
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip install --no-deps -e .
|
||||
pip install --no-build-isolation --no-deps -e .
|
||||
```
|
||||
|
||||
@@ -996,6 +996,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
def _evaluate(self, *args, **kwargs):
|
||||
metrics = super()._evaluate(*args, **kwargs)
|
||||
|
||||
# cleanup memory after evals
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
|
||||
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/__init__.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
170
src/axolotl/monkeypatch/models/llama/modeling_llama.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from torchtune.training import OffloadActivations
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
HF_MODEL_OUTPUTS = """
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_MODEL_OUTPUTS = """
|
||||
with self.act_offloading_ctx_manager:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
LCE_MODEL_OUTPUTS = """
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_LCE_OUTPUTS = """
|
||||
with self.act_offloading_ctx_manager:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
""".lstrip()
|
||||
|
||||
HF_GA_FORWARD_1 = """
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_GA_FORWARD_1 = """
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
""".lstrip()
|
||||
|
||||
HF_GA_FORWARD_2 = """
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
""".lstrip()
|
||||
|
||||
PATCHED_HF_GA_FORWARD_2 = """
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||
""".lstrip()
|
||||
|
||||
|
||||
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
act_offloading_ctx_manager = contextlib.nullcontext()
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@classmethod
|
||||
def set_forward(cls):
|
||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<forward>", "exec"), cls
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_act_offloading(cls):
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
||||
)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||
)
|
||||
cls.act_offloading_ctx_manager = OffloadActivations()
|
||||
|
||||
@classmethod
|
||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
|
||||
if enable_act_offloading:
|
||||
lce_source = inspect.getsource(llama_lce_forward)
|
||||
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
||||
cls,
|
||||
)
|
||||
else:
|
||||
cls.forward = types.methodType(llama_lce_forward, cls)
|
||||
|
||||
@classmethod
|
||||
def patch_hf_ga(cls):
|
||||
# bugfix patch for gradient accumulation
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
||||
)
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||
)
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||
)
|
||||
|
||||
|
||||
def replace_auto_model():
|
||||
from transformers import LlamaConfig
|
||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||
AxolotlLlamaForCausalLM.set_forward()
|
||||
|
||||
return AxolotlLlamaForCausalLM
|
||||
@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
@@ -679,6 +679,7 @@ class AxolotlInputConfig(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||
activation_offloading: Optional[bool] = None
|
||||
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
|
||||
|
||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||
|
||||
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||
|
||||
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||
if self.cfg.activation_offloading:
|
||||
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||
|
||||
if self.cfg.fsdp:
|
||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||
patch_training_loop_for_fsdp,
|
||||
@@ -1183,6 +1192,8 @@ class ModelLoader:
|
||||
|
||||
self.apply_lora_patch()
|
||||
|
||||
# self.apply_patches_to_model()
|
||||
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
dynamic requirements for axolotl
|
||||
"""
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools.command.build_py import build_py as _build_py
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = (
|
||||
"flash-attn" in line
|
||||
or "flash-attention" in line
|
||||
or "deepspeed" in line
|
||||
or "mamba-ssm" in line
|
||||
or "lion-pytorch" in line
|
||||
)
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.5.1"
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.23.post1")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
class BuildPyCommand(_build_py):
|
||||
"""
|
||||
custom build_py command to parse dynamic requirements
|
||||
"""
|
||||
|
||||
def finalize_options(self):
|
||||
super().finalize_options()
|
||||
install_requires, _ = parse_requirements()
|
||||
self.distribution.install_requires = install_requires
|
||||
Reference in New Issue
Block a user