Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60c98a4353 | ||
|
|
c760d2b815 | ||
|
|
2014f58181 | ||
|
|
b5f9dd44f2 | ||
|
|
b17b1aada7 | ||
|
|
85381b6b15 | ||
|
|
acde081321 | ||
|
|
e4c68a0cbc | ||
|
|
3855f5c3d3 | ||
|
|
5dd566dc63 | ||
|
|
42389c1f78 | ||
|
|
d009ead101 |
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging
|
pip3 install wheel packaging
|
||||||
pip3 install -e .
|
pip3 install --no-build-isolation -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Extract tag name
|
- name: Extract tag name
|
||||||
|
|||||||
6
.github/workflows/tests-nightly.yml
vendored
6
.github/workflows/tests-nightly.yml
vendored
@@ -60,11 +60,15 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging
|
pip3 install --upgrade packaging
|
||||||
pip3 install -U -e .
|
pip3 install --no-build-isolation -U -e .
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
|
run: |
|
||||||
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|||||||
16
.github/workflows/tests.yml
vendored
16
.github/workflows/tests.yml
vendored
@@ -78,11 +78,15 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 show torch
|
||||||
pip3 install -U -e .
|
pip3 install --no-build-isolation -U -e .
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
|
run: |
|
||||||
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
@@ -120,7 +124,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
@@ -129,12 +133,16 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 show torch
|
||||||
python3 setup.py sdist
|
python -m build --no-isolation --sdist
|
||||||
pip3 install dist/axolotl*.tar.gz
|
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
|
run: |
|
||||||
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
include requirements.txt
|
include requirements.txt
|
||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
|
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||||
recursive-include axolotl *.py
|
recursive-include axolotl *.py
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
|
|
||||||
# download examples and optionally deepspeed configs to the local path
|
# download examples and optionally deepspeed configs to the local path
|
||||||
axolotl fetch examples
|
axolotl fetch examples
|
||||||
@@ -131,7 +131,7 @@ from source.
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install packaging ninja
|
pip3 install packaging ninja
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Axolotl CLI Usage
|
### Axolotl CLI Usage
|
||||||
@@ -320,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
3. Install Axolotl along with python dependencies
|
3. Install Axolotl along with python dependencies
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||||
```bash
|
```bash
|
||||||
@@ -399,7 +399,7 @@ Please use WSL or Docker!
|
|||||||
|
|
||||||
Use the below instead of the install method in QuickStart.
|
Use the below instead of the install method in QuickStart.
|
||||||
```
|
```
|
||||||
pip3 install -e '.'
|
pip3 install --no-build-isolation -e '.'
|
||||||
```
|
```
|
||||||
More info: [mac.md](/docs/mac.qmd)
|
More info: [mac.md](/docs/mac.qmd)
|
||||||
|
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
|
|||||||
cd flash-attention
|
cd flash-attention
|
||||||
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
||||||
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
|
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
|
||||||
pip install .
|
pip install --no-build-isolation .
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. Install Axolotl
|
### 6. Install Axolotl
|
||||||
@@ -63,7 +63,7 @@ Clone and install Axolotl:
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip install packaging ninja
|
pip install packaging ninja
|
||||||
pip install -e .
|
pip install --no-build-isolation -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
### 7. Apply xformers Workaround
|
### 7. Apply xformers Workaround
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Remote Hosts
|
#### Remote Hosts
|
||||||
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Attach To Container
|
### Attach To Container
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install axolotl[deepspeed]"
|
"!pip install --no-build-isolation axolotl[deepspeed]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
58
examples/llama-3/fft-8b-tp.yml
Normal file
58
examples/llama-3/fft-8b-tp.yml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
tensor_parallel: 'auto'
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 2
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
73
examples/llama-3/lora-8b-tp.yml
Normal file
73
examples/llama-3/lora-8b-tp.yml
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
base_model: NousResearch/Meta-Llama-3.1-8B
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
tensor_parallel: 'auto'
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
@@ -17,3 +17,10 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
|
|||||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||||
|
|
||||||
[tool.setuptools_scm]
|
[tool.setuptools_scm]
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||||
|
include-package-data = true
|
||||||
|
|
||||||
|
[tool.setuptools.cmdclass]
|
||||||
|
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||||
|
|||||||
@@ -13,5 +13,5 @@ cd /workspace
|
|||||||
rm -rf /workspace/axolotl
|
rm -rf /workspace/axolotl
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip install --no-deps -e .
|
pip install --no-build-isolation --no-deps -e .
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
# self.model =
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_ref(self):
|
def model_ref(self):
|
||||||
return self._model_ref
|
return self._model_ref
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
|||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
|
|||||||
@@ -1187,9 +1187,15 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
self.post_loading_set_env()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
def post_loading_set_env(self):
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
104
src/setuptools_axolotl_dynamic_dependencies.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
dynamic requirements for axolotl
|
||||||
|
"""
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
|
from setuptools.command.build_py import build_py as _build_py
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def parse_requirements():
|
||||||
|
_install_requires = []
|
||||||
|
_dependency_links = []
|
||||||
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
|
for line in lines:
|
||||||
|
is_extras = (
|
||||||
|
"flash-attn" in line
|
||||||
|
or "flash-attention" in line
|
||||||
|
or "deepspeed" in line
|
||||||
|
or "mamba-ssm" in line
|
||||||
|
or "lion-pytorch" in line
|
||||||
|
)
|
||||||
|
if line.startswith("--extra-index-url"):
|
||||||
|
# Handle custom index URLs
|
||||||
|
_, url = line.split()
|
||||||
|
_dependency_links.append(url)
|
||||||
|
elif not is_extras and line and line[0] != "#":
|
||||||
|
# Handle standard packages
|
||||||
|
_install_requires.append(line)
|
||||||
|
|
||||||
|
try:
|
||||||
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
|
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||||
|
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||||
|
|
||||||
|
if "Darwin" in platform.system():
|
||||||
|
# don't install xformers on MacOS
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
else:
|
||||||
|
# detect the version of torch already installed
|
||||||
|
# and set it so dependencies don't clobber the torch version
|
||||||
|
try:
|
||||||
|
torch_version = version("torch")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
torch_version = "2.5.1"
|
||||||
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
if version_match:
|
||||||
|
major, minor, patch = version_match.groups()
|
||||||
|
major, minor = int(major), int(minor)
|
||||||
|
patch = (
|
||||||
|
int(patch) if patch is not None else 0
|
||||||
|
) # Default patch to 0 if not present
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
if (major, minor) >= (2, 5):
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.28.post3")
|
||||||
|
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||||
|
elif (major, minor) >= (2, 4):
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers==0.0.28.post1")
|
||||||
|
elif (major, minor) >= (2, 3):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.27")
|
||||||
|
elif (major, minor) >= (2, 2):
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|
||||||
|
class BuildPyCommand(_build_py):
|
||||||
|
"""
|
||||||
|
custom build_py command to parse dynamic requirements
|
||||||
|
"""
|
||||||
|
|
||||||
|
def finalize_options(self):
|
||||||
|
super().finalize_options()
|
||||||
|
install_requires, _ = parse_requirements()
|
||||||
|
self.distribution.install_requires = install_requires
|
||||||
Reference in New Issue
Block a user