Compare commits
2 Commits
feat/pref_
...
e2e-fsdp-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39ab9626f1 | ||
|
|
26bd81cec0 |
7
.github/workflows/pypi.yml
vendored
7
.github/workflows/pypi.yml
vendored
@@ -13,13 +13,10 @@ jobs:
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Create release
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: gh release create "$GITHUB_REF_NAME" --generate-notes
|
||||
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
@@ -41,7 +38,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
|
||||
11
.github/workflows/tests-nightly.yml
vendored
11
.github/workflows/tests-nightly.yml
vendored
@@ -44,11 +44,6 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
@@ -65,15 +60,11 @@ jobs:
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
pip3 install -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
24
.github/workflows/tests.yml
vendored
24
.github/workflows/tests.yml
vendored
@@ -78,23 +78,19 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
pip3 install -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest tests/patched/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -124,7 +120,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -133,24 +129,20 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
python3 setup.py sdist
|
||||
pip3 install dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest tests/patched/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
recursive-include axolotl *.py
|
||||
|
||||
104
README.md
104
README.md
@@ -10,13 +10,9 @@
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
|
||||
<br/>
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
|
||||
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
|
||||
<br/>
|
||||
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
|
||||
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
|
||||
<br/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||
@@ -46,8 +42,7 @@ Features:
|
||||
- [Axolotl](#axolotl)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [Quickstart ⚡](#quickstart-)
|
||||
- [Edge Builds](#edge-builds-)
|
||||
- [Axolotl CLI Usage](#axolotl-cli-usage)
|
||||
- [Usage](#usage)
|
||||
- [Badge ❤🏷️](#badge-️)
|
||||
- [Contributing 🤝](#contributing-)
|
||||
- [Sponsors 🤝❤](#sponsors-)
|
||||
@@ -112,49 +107,58 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
||||
|
||||
```bash
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# download examples and optionally deepspeed configs to the local path
|
||||
axolotl fetch examples
|
||||
axolotl fetch deepspeed_configs # OPTIONAL
|
||||
|
||||
# finetune using lora
|
||||
axolotl train examples/llama-3/lora-1b.yml
|
||||
```
|
||||
|
||||
### Edge Builds 🏎️
|
||||
|
||||
If you're looking for the latest features and updates between releases, you'll need to install
|
||||
from source.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Axolotl CLI Usage
|
||||
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
|
||||
### Usage
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./outputs/lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./outputs/lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
### Axolotl CLI
|
||||
|
||||
If you've installed this package using `pip` from source, we now support a new, more
|
||||
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
|
||||
the above commands:
|
||||
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
|
||||
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
|
||||
|
||||
# finetune lora
|
||||
axolotl train examples/llama-3/lora-1b.yml
|
||||
axolotl train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
axolotl inference examples/llama-3/lora-1b.yml \
|
||||
axolotl inference examples/openllama-3b/lora.yml \
|
||||
--lora-model-dir="./outputs/lora-out"
|
||||
|
||||
# gradio
|
||||
axolotl inference examples/llama-3/lora-1b.yml \
|
||||
axolotl inference examples/openllama-3b/lora.yml \
|
||||
--lora-model-dir="./outputs/lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
|
||||
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
|
||||
@@ -171,36 +175,6 @@ axolotl fetch deepspeed_configs
|
||||
axolotl fetch examples --dest path/to/folder
|
||||
```
|
||||
|
||||
### Legacy Usage
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
|
||||
While the Axolotl CLI is the preferred method for interacting with axolotl, we
|
||||
still support the legacy `-m axolotl.cli.*` usage.
|
||||
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
|
||||
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
|
||||
--lora_model_dir="./outputs/lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
|
||||
--lora_model_dir="./outputs/lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||
@@ -320,7 +294,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
3. Install Axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
```bash
|
||||
@@ -399,7 +373,7 @@ Please use WSL or Docker!
|
||||
|
||||
Use the below instead of the install method in QuickStart.
|
||||
```
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
pip3 install -e '.'
|
||||
```
|
||||
More info: [mac.md](/docs/mac.qmd)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
|
||||
cd flash-attention
|
||||
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
|
||||
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
|
||||
pip install --no-build-isolation .
|
||||
pip install .
|
||||
```
|
||||
|
||||
### 6. Install Axolotl
|
||||
@@ -63,7 +63,7 @@ Clone and install Axolotl:
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||
cd axolotl
|
||||
pip install packaging ninja
|
||||
pip install --no-build-isolation -e .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### 7. Apply xformers Workaround
|
||||
|
||||
@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
#### Remote Hosts
|
||||
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Attach To Container
|
||||
|
||||
@@ -52,26 +52,6 @@ datasets:
|
||||
type: chat_template.argilla
|
||||
```
|
||||
|
||||
|
||||
#### KTO
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
rl_beta: 0.5
|
||||
kto_desirable_weight: 0.2
|
||||
|
||||
remove_unused_columns: false
|
||||
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
||||
type: llama3.ultra
|
||||
split: train
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
```yaml
|
||||
datasets:
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install --no-build-isolation axolotl[deepspeed]"
|
||||
"!pip install axolotl[deepspeed]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -1,75 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
rl: kto
|
||||
rl_beta: 0.5
|
||||
kto_desirable_weight: 0.2
|
||||
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
||||
type: llama3.ultra
|
||||
split: train
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
remove_unused_columns: false
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false # not supported with kto
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -1,4 +1,4 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
@@ -22,6 +22,7 @@ pad_to_sequence_len: true
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
|
||||
@@ -17,10 +17,3 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
|
||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.0
|
||||
triton>=2.3.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.0.post2
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.4.2
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
peft==0.14.0
|
||||
transformers>=4.46.3
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.45.0
|
||||
accelerate==1.2.0
|
||||
datasets==3.1.0
|
||||
deepspeed==0.16.1
|
||||
deepspeed==0.15.4
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
flash-attn==2.7.0.post2
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.23.post1
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
colorama
|
||||
@@ -39,6 +31,11 @@ art
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq==0.2.7.post3
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.4.2
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
# remote filesystems
|
||||
s3fs>=2024.5.0
|
||||
|
||||
@@ -13,5 +13,5 @@ cd /workspace
|
||||
rm -rf /workspace/axolotl
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip install --no-build-isolation --no-deps -e .
|
||||
pip install --no-deps -e .
|
||||
```
|
||||
|
||||
20
setup.py
20
setup.py
@@ -1,10 +1,7 @@
|
||||
"""setup.py for axolotl"""
|
||||
import ast
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
@@ -93,24 +90,9 @@ def parse_requirements():
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open(
|
||||
Path(os.path.dirname(os.path.abspath(__file__)))
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "__init__.py",
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
|
||||
version_ = ast.literal_eval(version_match.group(1))
|
||||
return version_
|
||||
|
||||
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
install_requires=install_requires,
|
||||
@@ -125,7 +107,7 @@ setup(
|
||||
"flash-attn==2.7.0.post2",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.1",
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""Axolotl - Train and fine-tune large language models"""
|
||||
|
||||
__version__ = "0.6.0"
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
__version__ = version("axolotl")
|
||||
except ImportError:
|
||||
__version__ = "unknown"
|
||||
|
||||
@@ -14,22 +14,17 @@ import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from liger_kernel.chunked_loss.fused_linear_preference import (
|
||||
LigerFusedLinearPreferenceBase,
|
||||
)
|
||||
from packaging import version
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import amp, nn
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
@@ -1082,15 +1077,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
|
||||
self.liger_loss = LigerFusedLinearDPOLoss(
|
||||
ignore_index=self.label_pad_token_id,
|
||||
beta=self.beta,
|
||||
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
|
||||
use_ref_model=not self.reference_free,
|
||||
)
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
@@ -1194,309 +1180,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
# transformers<=4.46
|
||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the DPO loss and other metrics using Liger kernel."""
|
||||
# return super().get_batch_loss_metrics(model, batch, train_eval)
|
||||
if not self.liger_loss:
|
||||
raise ValueError("Liger loss not initialized")
|
||||
|
||||
metrics = {}
|
||||
|
||||
model_output = self.concatenated_forward(model, batch)
|
||||
|
||||
# Get the lm_head weights and bias
|
||||
lin_weight = model.lm_head.weight
|
||||
lin_bias = getattr(model.lm_head, "bias", None)
|
||||
|
||||
hidden_states = model_output["hidden_states"]
|
||||
labels = model_output["labels"]
|
||||
|
||||
if not self.reference_free:
|
||||
# Adapted from DPO's compute_ref_log_probs
|
||||
compte_ref_context_manager = (
|
||||
amp.autocast("cuda")
|
||||
if self._peft_has_been_casted_to_bf16
|
||||
else nullcontext()
|
||||
)
|
||||
with torch.no_grad(), compte_ref_context_manager: # type: ignore
|
||||
if self.ref_model is None:
|
||||
with self.null_ref_context():
|
||||
ref_model_output = self.concatenated_forward(self.model, batch)
|
||||
ref_weight = self.model.lm_head.weight
|
||||
ref_bias = getattr(self.model.lm_head, "bias", None)
|
||||
|
||||
ref_hidden_states = ref_model_output["hidden_states"]
|
||||
|
||||
else:
|
||||
ref_model_output = self.concatenated_forward(self.ref_model, batch)
|
||||
ref_weight = self.ref_model.lm_head.weight
|
||||
ref_bias = getattr(self.ref_model.lm_head, "bias", None)
|
||||
|
||||
ref_hidden_states = ref_model_output["hidden_states"]
|
||||
(
|
||||
ref_chosen_logps,
|
||||
ref_rejected_logps,
|
||||
_ref_chosen_logits,
|
||||
_ref_rejected_logits,
|
||||
_ref_chosen_nll_loss,
|
||||
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
||||
input_chunk=ref_hidden_states,
|
||||
weight=ref_weight,
|
||||
target_chunk=labels,
|
||||
bias=ref_bias,
|
||||
# ignore_index=ignore_index,
|
||||
compute_nll_loss=False,
|
||||
)
|
||||
|
||||
else:
|
||||
ref_hidden_states = None
|
||||
ref_weight = None
|
||||
ref_bias = None
|
||||
|
||||
# Compute loss using Liger kernel
|
||||
loss, return_vars = self.liger_loss(
|
||||
lin_weight=lin_weight,
|
||||
_input=hidden_states,
|
||||
target=labels,
|
||||
bias=lin_bias, # TODO: check whether to pass bias as FCLE doesn't
|
||||
ref_input=ref_hidden_states,
|
||||
ref_weight=ref_weight,
|
||||
ref_bias=ref_bias,
|
||||
)
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = return_vars
|
||||
|
||||
# Calculate rewards
|
||||
if not self.reference_free:
|
||||
chosen_rewards = (
|
||||
self.beta * (policy_chosen_logps - (ref_chosen_logps)).detach()
|
||||
)
|
||||
rejected_rewards = (
|
||||
self.beta * (policy_rejected_logps - (ref_rejected_logps)).detach()
|
||||
)
|
||||
|
||||
else:
|
||||
chosen_rewards = self.beta * policy_chosen_logps
|
||||
rejected_rewards = self.beta * policy_rejected_logps
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics.update(
|
||||
{
|
||||
f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
|
||||
f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
|
||||
f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
|
||||
f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
|
||||
.mean()
|
||||
.cpu(),
|
||||
f"{prefix}logps/chosen": policy_chosen_logps.mean().cpu(),
|
||||
f"{prefix}logps/rejected": policy_rejected_logps.mean().cpu(),
|
||||
f"{prefix}logits/chosen": policy_chosen_logits_mean.cpu(),
|
||||
f"{prefix}logits/rejected": policy_rejected_logits_mean.cpu(),
|
||||
}
|
||||
)
|
||||
|
||||
if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
|
||||
metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu()
|
||||
|
||||
# TODO: Handle use_weighting, aux_loss_enabled as in upstream
|
||||
|
||||
return loss, metrics
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
):
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
|
||||
Overridden base function to return the hidden states and labels for the loss calculation.
|
||||
"""
|
||||
num_examples = batch["prompt_input_ids"].shape[0] # type: ignore
|
||||
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch, padding_value=self.padding_value
|
||||
)
|
||||
|
||||
model_kwargs = {}
|
||||
if self.aux_loss_enabled:
|
||||
model_kwargs["output_router_logits"] = True
|
||||
|
||||
# Add to get the hidden states for the loss
|
||||
model_kwargs["output_hidden_states"] = True
|
||||
|
||||
# Add the pixel values and attention masks for vision models
|
||||
if "pixel_values" in concatenated_batch:
|
||||
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
|
||||
if "pixel_attention_mask" in concatenated_batch:
|
||||
model_kwargs["pixel_attention_mask"] = concatenated_batch[
|
||||
"pixel_attention_mask"
|
||||
]
|
||||
if "image_sizes" in concatenated_batch:
|
||||
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
|
||||
|
||||
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
||||
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
||||
completion_input_ids = concatenated_batch["completion_input_ids"]
|
||||
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
||||
if self.is_encoder_decoder:
|
||||
labels = completion_input_ids
|
||||
labels[completion_attention_mask == 0] = self.label_pad_token_id
|
||||
outputs = model(
|
||||
input_ids=prompt_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
labels=labels, # we need the labels for the logits to be returned
|
||||
**model_kwargs,
|
||||
)
|
||||
logits = outputs.logits
|
||||
hidden_states = outputs.decoder_hidden_states[-1]
|
||||
loss_mask = completion_attention_mask.bool()
|
||||
else:
|
||||
# Concatenate the prompt and completion inputs
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat(
|
||||
(prompt_attention_mask, completion_attention_mask), dim=1
|
||||
)
|
||||
# Mask the prompt but not the completion for the loss
|
||||
loss_mask = torch.cat(
|
||||
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
for i in range(attention_mask.size(0)):
|
||||
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
||||
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) # type: ignore
|
||||
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) # type: ignore
|
||||
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) # type: ignore
|
||||
|
||||
# Get the first column idx that is all zeros and remove every column after that
|
||||
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
||||
first_empty_col = (
|
||||
torch.nonzero(empty_cols)[0].item()
|
||||
if empty_cols.any()
|
||||
else attention_mask.size(1)
|
||||
)
|
||||
input_ids = input_ids[:, :first_empty_col] # type: ignore
|
||||
attention_mask = attention_mask[:, :first_empty_col] # type: ignore
|
||||
loss_mask = loss_mask[:, :first_empty_col] # type: ignore
|
||||
|
||||
# Truncate right
|
||||
if self.args.max_length is not None:
|
||||
input_ids = input_ids[:, : self.args.max_length]
|
||||
attention_mask = attention_mask[:, : self.args.max_length]
|
||||
loss_mask = loss_mask[:, : self.args.max_length]
|
||||
|
||||
# if self.use_num_logits_to_keep:
|
||||
# # Compute num_logits_to_keep based on loss_mask pattern:
|
||||
# # [[0, 0, 0, x, x, x, x],
|
||||
# # [0, 0, 0, x, x, x, 0]]
|
||||
# # ^ start computing logits from here ([:, -(7-3+1):])
|
||||
# first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
|
||||
# num_logits_to_keep = loss_mask.shape[1] - first_compute_index
|
||||
# model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label
|
||||
|
||||
outputs = model(
|
||||
input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
|
||||
)
|
||||
|
||||
# Offset the logits by one to align with the labels
|
||||
logits = outputs.logits[:, :-1, :]
|
||||
hidden_states = outputs.hidden_states[-1][:, :-1, :]
|
||||
labels = input_ids[:, 1:].clone()
|
||||
loss_mask = loss_mask[:, 1:].bool()
|
||||
|
||||
# if self.use_num_logits_to_keep:
|
||||
# # Align labels with logits
|
||||
# # logits: -, -, [x2, x3, x4, x5, x6]
|
||||
# # ^ --------- ^ after logits[:, :-1, :]
|
||||
# # labels: [y0, y1, y2, y3, y4, y5, y6]
|
||||
# # ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
|
||||
# # loss_mask: [0, 0, 0, 1, 1, 1, 1]
|
||||
# labels = labels[:, -num_logits_to_keep:]
|
||||
# loss_mask = loss_mask[:, -num_logits_to_keep:]
|
||||
# hidden_states = hidden_states[:, -num_logits_to_keep:, :]
|
||||
|
||||
if logits.shape[:2] != labels.shape[:2]:
|
||||
# for llava, the returned logits include the image tokens (placed before the text tokens)
|
||||
seq_len = labels.shape[1]
|
||||
logits = logits[:, -seq_len:]
|
||||
hidden_states = hidden_states[:, -seq_len:]
|
||||
|
||||
# Compute the log probabilities of the labels
|
||||
labels[
|
||||
~loss_mask
|
||||
] = 0 # dummy token; we'll ignore the losses on these tokens later
|
||||
per_token_logps = torch.gather(
|
||||
logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
per_token_logps[~loss_mask] = 0
|
||||
all_logps = per_token_logps.sum(-1)
|
||||
|
||||
output = {}
|
||||
|
||||
if self.use_weighting:
|
||||
with torch.no_grad():
|
||||
# Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
|
||||
logprobs = F.log_softmax(logits, dim=-1)
|
||||
weights_adjustment_factor = torch.logsumexp(
|
||||
2 * logprobs, dim=-1
|
||||
) # same as sum(probs**2) in log space
|
||||
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
|
||||
all_weights = (per_token_logps_adjusted * loss_mask).sum(
|
||||
-1
|
||||
) / loss_mask.sum(-1)
|
||||
chosen_weights = all_weights[:num_examples]
|
||||
rejected_weights = all_weights[num_examples:]
|
||||
output["policy_weights"] = torch.clamp(
|
||||
torch.exp(chosen_weights + rejected_weights), max=1
|
||||
)
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
# Only use the chosen logits for the RPO loss
|
||||
chosen_logits = logits[:num_examples]
|
||||
chosen_labels = labels[:num_examples]
|
||||
|
||||
# Compute the log probabilities of the labels
|
||||
output["nll_loss"] = F.cross_entropy(
|
||||
torch.flatten(chosen_logits, end_dim=1),
|
||||
torch.flatten(chosen_labels, end_dim=1),
|
||||
ignore_index=0,
|
||||
)
|
||||
|
||||
if self.loss_type == "ipo":
|
||||
all_logps = all_logps / loss_mask.sum(-1)
|
||||
|
||||
output["chosen_logps"] = all_logps[:num_examples]
|
||||
output["rejected_logps"] = all_logps[num_examples:]
|
||||
output["mean_chosen_logits"] = logits[:num_examples][
|
||||
loss_mask[:num_examples]
|
||||
].mean()
|
||||
output["mean_rejected_logits"] = logits[num_examples:][
|
||||
loss_mask[num_examples:]
|
||||
].mean()
|
||||
output["hidden_states"] = hidden_states
|
||||
output["labels"] = labels
|
||||
|
||||
if self.aux_loss_enabled:
|
||||
output["aux_loss"] = outputs.aux_loss
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
@@ -1685,6 +1368,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from transformers.integrations.integration_utils import MLflowCallback
|
||||
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
@@ -1692,6 +1377,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.extend(
|
||||
[
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||
MLflowCallback,
|
||||
]
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
@@ -2480,14 +2166,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.wandb_name:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
|
||||
@@ -204,87 +204,3 @@ def patch_forward_for_ga():
|
||||
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
disable_deepspeed_no_sync = (
|
||||
self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
|
||||
)
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
with context():
|
||||
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||
"""
|
||||
|
||||
|
||||
def get_training_loop_code() -> str:
|
||||
training_loop = inspect.getsource(
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_training_loop_is_patchable() -> bool:
|
||||
training_loop = get_training_loop_code()
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
return ORIGINAL_TRAINER_CODE in training_loop
|
||||
|
||||
|
||||
def patch_training_loop_for_deepspeed_0_16_x():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for deepspeed GA
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/35157
|
||||
"""
|
||||
|
||||
try:
|
||||
training_loop = get_training_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||
training_loop
|
||||
)
|
||||
training_loop, _ = detab_code(training_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||
return
|
||||
|
||||
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||
training_loop = training_loop.replace(
|
||||
"def _inner_training_loop(",
|
||||
"def _fixed_inner_training_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in training_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
@@ -28,8 +28,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
:return:
|
||||
"""
|
||||
|
||||
max_length = self.prompter.max_length
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
@@ -41,16 +39,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||
)
|
||||
|
||||
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
|
||||
:max_length
|
||||
]
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
@@ -64,18 +52,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
)
|
||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||
)
|
||||
|
||||
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||
:max_length
|
||||
]
|
||||
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
|
||||
:max_length
|
||||
]
|
||||
|
||||
return {
|
||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||
@@ -104,9 +80,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": (
|
||||
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
|
||||
),
|
||||
"max_length": cfg.sequence_len + 1
|
||||
if not cfg.reward_model
|
||||
else cfg.sequence_len,
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
|
||||
@@ -42,7 +42,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
"gpt": "assistant",
|
||||
"system": "system",
|
||||
}
|
||||
|
||||
self.message_field_role = message_field_role
|
||||
self.message_field_content = message_field_content
|
||||
self.message_field_training = message_field_training
|
||||
@@ -54,9 +53,21 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
turns = [
|
||||
{
|
||||
"role": self.roles[t[self.message_field_role]],
|
||||
"content": t[self.message_field_content],
|
||||
"training": t.get(self.message_field_training, None),
|
||||
}
|
||||
for t in conversation
|
||||
]
|
||||
|
||||
if self.drop_system_message and turns[0]["role"] == "system":
|
||||
turns = turns[1:]
|
||||
|
||||
if self.processor:
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
turns,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
@@ -65,6 +76,8 @@ class ChatTemplatePrompter(Prompter):
|
||||
text=text,
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
# workaround since processor works in batches instead of single examples
|
||||
for k, val in batch.items():
|
||||
@@ -75,7 +88,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
return batch
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
turns,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
@@ -200,14 +215,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
train_on_eos=None,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
# map roles if exist in prompter.roles else use the role as is
|
||||
self.roles_to_train = [
|
||||
prompter.roles.get(role, role) for role in roles_to_train
|
||||
]
|
||||
|
||||
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
||||
self.train_on_eos = train_on_eos
|
||||
self.images = "images"
|
||||
|
||||
@@ -254,28 +262,30 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
turns = prompt[self.messages]
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
for index, turn in enumerate(turns):
|
||||
role = turn.get("role")
|
||||
content = turn.get("content")
|
||||
train_turn = turn.get("training")
|
||||
train_detail = turn.get("training_detail")
|
||||
role = turn.get(self.prompter.message_field_role)
|
||||
content = turn.get(self.prompter.message_field_content)
|
||||
train_turn = turn.get(self.prompter.message_field_training)
|
||||
train_detail = turn.get(self.prompter.message_field_training_detail)
|
||||
|
||||
LOG.debug(
|
||||
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
||||
)
|
||||
|
||||
should_train = None
|
||||
if train_turn is not None:
|
||||
should_train = train_turn
|
||||
elif train_detail is not None:
|
||||
should_train = bool(train_detail)
|
||||
else:
|
||||
should_train = self.train_on_inputs or role in self.roles_to_train
|
||||
should_train = (
|
||||
train_turn
|
||||
if train_turn is not None
|
||||
else (
|
||||
bool(train_detail is not None)
|
||||
if train_detail is not None
|
||||
else self.train_on_inputs or role in self.roles_to_train
|
||||
)
|
||||
)
|
||||
|
||||
LOG.debug(f"Should train: {should_train}")
|
||||
|
||||
@@ -283,9 +293,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
conversation_ids=input_ids, turn=index, turn_content=turn
|
||||
)
|
||||
|
||||
if turn_start_idx == -1 or turn_end_idx == -1:
|
||||
LOG.warning(f"Failed to find boundaries for turn {index}")
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||
@@ -306,9 +313,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
labels[turn_start_idx:turn_end_idx] = input_ids[
|
||||
turn_start_idx:turn_end_idx
|
||||
]
|
||||
LOG.debug(
|
||||
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
|
||||
)
|
||||
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
|
||||
|
||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||
|
||||
@@ -346,73 +351,52 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
|
||||
def find_turn(self, conversation_ids, turn, turn_content):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
|
||||
Args:
|
||||
conversation_ids (list[int]): Token IDs representing the conversation.
|
||||
turn (int): The turn number to locate (based on EOS tokens).
|
||||
turn_content (str): String containing the content of the turn.
|
||||
|
||||
Returns:
|
||||
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
|
||||
Returns (-1, -1) if the turn content is not found.
|
||||
"""
|
||||
content = turn_content.get("content")
|
||||
content = turn_content.get(self.prompter.message_field_content, "")
|
||||
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
|
||||
|
||||
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
eos_count = 0
|
||||
start_search_idx = 0
|
||||
|
||||
if not content_ids:
|
||||
LOG.warning(f"Empty content for turn {turn}")
|
||||
return -1, -1
|
||||
# Locate the starting index after the specified number of EOS tokens
|
||||
for i, token_id in enumerate(conversation_ids):
|
||||
if token_id == eos_token_id:
|
||||
eos_count += 1
|
||||
if eos_count == turn:
|
||||
start_search_idx = (
|
||||
i + 1
|
||||
) # Start searching after the specified turn's EOS token
|
||||
break
|
||||
|
||||
# For first turn, start from beginning
|
||||
if turn == 0:
|
||||
start_search_idx = 0
|
||||
# Find the start index of the content within the conversation
|
||||
start_idx = -1
|
||||
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
|
||||
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
||||
start_idx = i
|
||||
break
|
||||
|
||||
if start_idx != -1:
|
||||
end_idx = start_idx + len(content_ids)
|
||||
else:
|
||||
# For subsequent turns, find the previous EOS token
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
eos_count = 0
|
||||
start_search_idx = 0
|
||||
end_idx = -1
|
||||
|
||||
for i, token_id in enumerate(conversation_ids):
|
||||
if token_id == eos_token_id:
|
||||
eos_count += 1
|
||||
if eos_count == turn: # Find the nth EOS token where n = turn
|
||||
start_search_idx = i + 1
|
||||
break
|
||||
|
||||
# we can optimize this to only search for a few tokens from start_search_idx
|
||||
# but it would risk missing the content if it's not found within the first few tokens or
|
||||
# if start_search_idx cannot be found above.
|
||||
last_index = len(conversation_ids) - len(content_ids) + 1
|
||||
|
||||
if last_index < start_search_idx:
|
||||
LOG.warning(
|
||||
f"last_index to search is less than start_search_idx for turn {turn}"
|
||||
)
|
||||
return -1, -1
|
||||
|
||||
# Search for content starting from start_search_idx
|
||||
first_elem = content_ids[0]
|
||||
for i in range(start_search_idx, last_index):
|
||||
# Quick check of first element before doing full comparison
|
||||
if conversation_ids[i] == first_elem:
|
||||
# Check if the rest of the content matches
|
||||
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
||||
LOG.debug(f"Found turn {turn} content at position {i}")
|
||||
return i, i + len(content_ids)
|
||||
|
||||
return -1, -1
|
||||
return start_idx, end_idx
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = [
|
||||
{
|
||||
"role": self.prompter.roles[t[self.prompter.message_field_role]],
|
||||
"content": t[self.prompter.message_field_content],
|
||||
"training": t.get(self.prompter.message_field_training),
|
||||
"training_detail": t.get(self.prompter.message_field_training_detail),
|
||||
}
|
||||
for t in prompt[self.messages]
|
||||
]
|
||||
|
||||
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||
turns = turns[1:]
|
||||
|
||||
return turns
|
||||
return prompt[self.messages]
|
||||
|
||||
def get_images(self, prompt):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
@@ -259,7 +259,14 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
try:
|
||||
# Check to make sure the base model is from HuggingFace not a local directory
|
||||
hf_api = HfApi()
|
||||
hf_api.model_info(cfg.base_model)
|
||||
|
||||
model_card_kwarg = {
|
||||
"model_name": cfg.output_dir.lstrip("./")
|
||||
.encode("utf-8")
|
||||
@@ -267,22 +274,16 @@ def train(
|
||||
}
|
||||
if cfg.datasets is not None:
|
||||
if cfg.rl is not None or cfg.reward_model:
|
||||
dataset_tags = [
|
||||
model_card_kwarg["dataset_name"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_name"] = dataset_tags
|
||||
else:
|
||||
dataset_tags = [
|
||||
model_card_kwarg["dataset_tags"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
|
||||
@@ -66,7 +66,10 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
@@ -1475,27 +1475,6 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_kto_config(cls, data):
|
||||
if data.get("rl") == "kto":
|
||||
if data.get("sample_packing") or data.get("eval_sample_packing"):
|
||||
raise ValueError("sample_packing is not supported with kto")
|
||||
|
||||
if data.get("remove_unused_columns") is not False:
|
||||
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
||||
|
||||
if data.get("gradient_checkpointing") and not (
|
||||
data.get("gradient_checkpointing_kwargs")
|
||||
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
|
||||
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
|
||||
):
|
||||
raise ValueError(
|
||||
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -386,12 +386,6 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
patch_training_loop_for_fsdp()
|
||||
elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
|
||||
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||
patch_training_loop_for_deepspeed_0_16_x,
|
||||
)
|
||||
|
||||
patch_training_loop_for_deepspeed_0_16_x()
|
||||
|
||||
if self.cfg.gradient_checkpointing == "unsloth":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
dynamic requirements for axolotl
|
||||
"""
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools.command.build_py import build_py as _build_py
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = (
|
||||
"flash-attn" in line
|
||||
or "flash-attention" in line
|
||||
or "deepspeed" in line
|
||||
or "mamba-ssm" in line
|
||||
or "lion-pytorch" in line
|
||||
)
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.5.1"
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.23.post1")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
class BuildPyCommand(_build_py):
|
||||
"""
|
||||
custom build_py command to parse dynamic requirements
|
||||
"""
|
||||
|
||||
def finalize_options(self):
|
||||
super().finalize_options()
|
||||
install_requires, _ = parse_requirements()
|
||||
self.distribution.install_requires = install_requires
|
||||
@@ -1,10 +0,0 @@
|
||||
"""pytest tests for axolotl CLI --version"""
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_print_version(cli_runner):
|
||||
"""Test that version is printed when --version is used."""
|
||||
|
||||
result = cli_runner.invoke(cli, ["--version"])
|
||||
assert result.exit_code == 0
|
||||
assert "axolotl, version " in result.output
|
||||
@@ -120,15 +120,9 @@ def temp_dir():
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
@@ -137,8 +131,6 @@ def cleanup_monkeypatches():
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
@@ -146,25 +138,16 @@ def cleanup_monkeypatches():
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
("transformers.models.llama",),
|
||||
(
|
||||
"transformers.models.llama.modeling_llama",
|
||||
["LlamaFlashAttention2", "LlamaAttention"],
|
||||
),
|
||||
("transformers.trainer",),
|
||||
("transformers", ["Trainer"]),
|
||||
("transformers",),
|
||||
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||
("transformers.trainer", ["Trainer"]),
|
||||
("transformers.loss.loss_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
module_name = module_name_tuple[0]
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name, sys.modules[module_name].__file__
|
||||
)
|
||||
sys.modules[module_name] = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(sys.modules[module_name])
|
||||
|
||||
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
|
||||
module = importlib.import_module(module_name)
|
||||
sys.modules[module_name] = module
|
||||
importlib.reload(sys.modules[module_name])
|
||||
if len(module_name_tuple) > 1:
|
||||
module_globals = module_name_tuple[1]
|
||||
for module_global in module_globals:
|
||||
|
||||
@@ -71,11 +71,7 @@ class TestCutCrossEntropyIntegration:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention_type",
|
||||
[
|
||||
"flash_attention",
|
||||
"sdp_attention",
|
||||
# "xformers_attention",
|
||||
],
|
||||
["flash_attention", "sdp_attention", "xformers_attention"],
|
||||
)
|
||||
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -9,7 +9,6 @@ from pathlib import Path
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from e2e.utils import check_tensorboard
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
@@ -54,7 +53,7 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -62,7 +61,6 @@ class TestMultiGPULlama:
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -85,13 +83,9 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -118,15 +112,14 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -149,10 +142,6 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -191,7 +180,7 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -200,7 +189,6 @@ class TestMultiGPULlama:
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -223,10 +211,6 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -265,8 +249,8 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
@@ -274,7 +258,6 @@ class TestMultiGPULlama:
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -297,13 +280,9 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
[1, 4],
|
||||
)
|
||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -322,8 +301,8 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"max_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -344,7 +323,6 @@ class TestMultiGPULlama:
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -367,10 +345,6 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
@@ -394,7 +368,7 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -416,7 +390,6 @@ class TestMultiGPULlama:
|
||||
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -439,10 +412,6 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -475,7 +444,7 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -497,7 +466,6 @@ class TestMultiGPULlama:
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -520,41 +488,12 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
[1, 4],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"deepspeed",
|
||||
[
|
||||
"deepspeed_configs/zero3_bf16.json",
|
||||
"deepspeed_configs/zero3_bf16_cpuoffload_all.json",
|
||||
# "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"qlora",
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero3_packed(
|
||||
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
|
||||
):
|
||||
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -572,17 +511,15 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
|
||||
"use_tensorboard": True,
|
||||
**adapter,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -605,35 +542,19 @@ class TestMultiGPULlama:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"qlora",
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
@@ -647,17 +568,15 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||
"use_tensorboard": True,
|
||||
**adapter,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -679,82 +598,3 @@ class TestMultiGPULlama:
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"qlora",
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||
"use_tensorboard": True,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
from importlib import reload
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -21,6 +22,14 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reload_transformers():
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
yield
|
||||
reload(transformers.models.llama.modeling_llama)
|
||||
|
||||
|
||||
class TestFAXentropyLlama:
|
||||
"""
|
||||
Test case for Llama models using LoRA w multipack
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
@@ -22,7 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip("FIXME, mostly underused functionality")
|
||||
class TestFusedLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using Fused layers
|
||||
|
||||
Reference in New Issue
Block a user