Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
39ab9626f1 add transformers module to cleanup 2024-12-08 14:52:54 -05:00
Wing Lian
26bd81cec0 re-enable tests w change in patching 2024-12-08 14:52:09 -05:00
37 changed files with 217 additions and 1215 deletions

View File

@@ -13,13 +13,10 @@ jobs:
permissions:
contents: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create "$GITHUB_REF_NAME" --generate-notes
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
@@ -41,7 +38,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install --no-build-isolation -e .
pip3 install -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name

View File

@@ -44,11 +44,6 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -65,15 +60,11 @@ jobs:
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install --no-build-isolation -U -e .
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help

View File

@@ -78,23 +78,19 @@ jobs:
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
@@ -124,7 +120,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
@@ -133,24 +129,20 @@ jobs:
- name: Install dependencies
run: |
pip3 show torch
python -m build --no-isolation --sdist
pip3 install --no-build-isolation dist/axolotl*.tar.gz
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |

View File

@@ -1,5 +1,4 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
recursive-include axolotl *.py

104
README.md
View File

@@ -10,13 +10,9 @@
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
</p>
<p align="center">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
@@ -46,8 +42,7 @@ Features:
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-)
- [Edge Builds](#edge-builds-)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Usage](#usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
@@ -112,49 +107,58 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# download examples and optionally deepspeed configs to the local path
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
### Edge Builds 🏎️
If you're looking for the latest features and updates between releases, you'll need to install
from source.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install -e '.[flash-attn,deepspeed]'
```
### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
### Usage
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
### Axolotl CLI
If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
# finetune lora
axolotl train examples/llama-3/lora-1b.yml
axolotl train examples/openllama-3b/lora.yml
# inference
axolotl inference examples/llama-3/lora-1b.yml \
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out"
# gradio
axolotl inference examples/llama-3/lora-1b.yml \
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
@@ -171,36 +175,6 @@ axolotl fetch deepspeed_configs
axolotl fetch examples --dest path/to/folder
```
### Legacy Usage
<details>
<summary>Click to Expand</summary>
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
</details>
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
@@ -320,7 +294,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
@@ -399,7 +373,7 @@ Please use WSL or Docker!
Use the below instead of the install method in QuickStart.
```
pip3 install --no-build-isolation -e '.'
pip3 install -e '.'
```
More info: [mac.md](/docs/mac.qmd)

View File

@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -1,10 +1,7 @@
#!/bin/bash
set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
cd flash-attention
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
pip install --no-build-isolation .
pip install .
```
### 6. Install Axolotl
@@ -63,7 +63,7 @@ Clone and install Axolotl:
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip install packaging ninja
pip install --no-build-isolation -e .
pip install -e .
```
### 7. Apply xformers Workaround

View File

@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install -e '.[flash-attn,deepspeed]'
```
#### Remote Hosts
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install -e '.[flash-attn,deepspeed]'
```
### Attach To Container

View File

@@ -52,26 +52,6 @@ datasets:
type: chat_template.argilla
```
#### KTO
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
remove_unused_columns: false
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
```
#### Using local dataset files
```yaml
datasets:

View File

@@ -24,7 +24,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install --no-build-isolation axolotl[deepspeed]"
"!pip install axolotl[deepspeed]"
]
},
{

View File

@@ -1,74 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,75 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/qlora-out
remove_unused_columns: false
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: false # not supported with kto
eval_sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,4 +1,4 @@
base_model: NousResearch/Llama-3.2-1B
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
@@ -22,6 +22,7 @@ pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj

View File

@@ -17,10 +17,3 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm]
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"

View File

@@ -1,30 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.4.2
# END section
packaging==23.2
peft==0.14.0
transformers>=4.46.3
tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.2.0
datasets==3.1.0
deepspeed==0.16.1
deepspeed==0.15.4
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
flash-attn==2.7.0.post2
sentencepiece
wandb
einops
xformers>=0.0.23.post1
optimum==1.16.2
hf_transfer
colorama
@@ -39,6 +31,11 @@ art
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post3
triton>=2.3.0
liger-kernel==0.4.2
mamba-ssm==1.2.0.post1
# remote filesystems
s3fs>=2024.5.0

View File

@@ -13,5 +13,5 @@ cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip install --no-build-isolation --no-deps -e .
pip install --no-deps -e .
```

View File

@@ -1,10 +1,7 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup
@@ -93,24 +90,9 @@ def parse_requirements():
return _install_requires, _dependency_links
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_
install_requires, dependency_links = parse_requirements()
setup(
version=get_package_version(),
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
@@ -125,7 +107,7 @@ setup(
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.16.1",
"deepspeed==0.15.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -1,3 +1,8 @@
"""Axolotl - Train and fine-tune large language models"""
__version__ = "0.6.0"
try:
from importlib.metadata import version
__version__ = version("axolotl")
except ImportError:
__version__ = "unknown"

View File

@@ -14,22 +14,17 @@ import os
import sys
from abc import abstractmethod
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import torch.nn.functional as F
import transformers
from datasets import Dataset
from liger_kernel.chunked_loss.fused_linear_preference import (
LigerFusedLinearPreferenceBase,
)
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import amp, nn
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -1082,15 +1077,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.dataset_tags = dataset_tags
self.optimizer = None
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
self.liger_loss = LigerFusedLinearDPOLoss(
ignore_index=self.label_pad_token_id,
beta=self.beta,
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
use_ref_model=not self.reference_free,
)
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
@@ -1194,309 +1180,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics using Liger kernel."""
# return super().get_batch_loss_metrics(model, batch, train_eval)
if not self.liger_loss:
raise ValueError("Liger loss not initialized")
metrics = {}
model_output = self.concatenated_forward(model, batch)
# Get the lm_head weights and bias
lin_weight = model.lm_head.weight
lin_bias = getattr(model.lm_head, "bias", None)
hidden_states = model_output["hidden_states"]
labels = model_output["labels"]
if not self.reference_free:
# Adapted from DPO's compute_ref_log_probs
compte_ref_context_manager = (
amp.autocast("cuda")
if self._peft_has_been_casted_to_bf16
else nullcontext()
)
with torch.no_grad(), compte_ref_context_manager: # type: ignore
if self.ref_model is None:
with self.null_ref_context():
ref_model_output = self.concatenated_forward(self.model, batch)
ref_weight = self.model.lm_head.weight
ref_bias = getattr(self.model.lm_head, "bias", None)
ref_hidden_states = ref_model_output["hidden_states"]
else:
ref_model_output = self.concatenated_forward(self.ref_model, batch)
ref_weight = self.ref_model.lm_head.weight
ref_bias = getattr(self.ref_model.lm_head, "bias", None)
ref_hidden_states = ref_model_output["hidden_states"]
(
ref_chosen_logps,
ref_rejected_logps,
_ref_chosen_logits,
_ref_rejected_logits,
_ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk=ref_hidden_states,
weight=ref_weight,
target_chunk=labels,
bias=ref_bias,
# ignore_index=ignore_index,
compute_nll_loss=False,
)
else:
ref_hidden_states = None
ref_weight = None
ref_bias = None
# Compute loss using Liger kernel
loss, return_vars = self.liger_loss(
lin_weight=lin_weight,
_input=hidden_states,
target=labels,
bias=lin_bias, # TODO: check whether to pass bias as FCLE doesn't
ref_input=ref_hidden_states,
ref_weight=ref_weight,
ref_bias=ref_bias,
)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits_mean,
policy_rejected_logits_mean,
policy_nll_loss,
) = return_vars
# Calculate rewards
if not self.reference_free:
chosen_rewards = (
self.beta * (policy_chosen_logps - (ref_chosen_logps)).detach()
)
rejected_rewards = (
self.beta * (policy_rejected_logps - (ref_rejected_logps)).detach()
)
else:
chosen_rewards = self.beta * policy_chosen_logps
rejected_rewards = self.beta * policy_rejected_logps
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics.update(
{
f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
.mean()
.cpu(),
f"{prefix}logps/chosen": policy_chosen_logps.mean().cpu(),
f"{prefix}logps/rejected": policy_rejected_logps.mean().cpu(),
f"{prefix}logits/chosen": policy_chosen_logits_mean.cpu(),
f"{prefix}logits/rejected": policy_rejected_logits_mean.cpu(),
}
)
if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu()
# TODO: Handle use_weighting, aux_loss_enabled as in upstream
return loss, metrics
def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
):
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
Overridden base function to return the hidden states and labels for the loss calculation.
"""
num_examples = batch["prompt_input_ids"].shape[0] # type: ignore
concatenated_batch = self.concatenated_inputs(
batch, padding_value=self.padding_value
)
model_kwargs = {}
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True
# Add to get the hidden states for the loss
model_kwargs["output_hidden_states"] = True
# Add the pixel values and attention masks for vision models
if "pixel_values" in concatenated_batch:
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
if "pixel_attention_mask" in concatenated_batch:
model_kwargs["pixel_attention_mask"] = concatenated_batch[
"pixel_attention_mask"
]
if "image_sizes" in concatenated_batch:
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
prompt_input_ids = concatenated_batch["prompt_input_ids"]
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_input_ids = concatenated_batch["completion_input_ids"]
completion_attention_mask = concatenated_batch["completion_attention_mask"]
if self.is_encoder_decoder:
labels = completion_input_ids
labels[completion_attention_mask == 0] = self.label_pad_token_id
outputs = model(
input_ids=prompt_input_ids,
attention_mask=prompt_attention_mask,
labels=labels, # we need the labels for the logits to be returned
**model_kwargs,
)
logits = outputs.logits
hidden_states = outputs.decoder_hidden_states[-1]
loss_mask = completion_attention_mask.bool()
else:
# Concatenate the prompt and completion inputs
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
attention_mask = torch.cat(
(prompt_attention_mask, completion_attention_mask), dim=1
)
# Mask the prompt but not the completion for the loss
loss_mask = torch.cat(
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
dim=1,
)
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
for i in range(attention_mask.size(0)):
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) # type: ignore
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) # type: ignore
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) # type: ignore
# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = (
torch.nonzero(empty_cols)[0].item()
if empty_cols.any()
else attention_mask.size(1)
)
input_ids = input_ids[:, :first_empty_col] # type: ignore
attention_mask = attention_mask[:, :first_empty_col] # type: ignore
loss_mask = loss_mask[:, :first_empty_col] # type: ignore
# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]
# if self.use_num_logits_to_keep:
# # Compute num_logits_to_keep based on loss_mask pattern:
# # [[0, 0, 0, x, x, x, x],
# # [0, 0, 0, x, x, x, 0]]
# # ^ start computing logits from here ([:, -(7-3+1):])
# first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
# num_logits_to_keep = loss_mask.shape[1] - first_compute_index
# model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label
outputs = model(
input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
)
# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
hidden_states = outputs.hidden_states[-1][:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()
# if self.use_num_logits_to_keep:
# # Align labels with logits
# # logits: -, -, [x2, x3, x4, x5, x6]
# # ^ --------- ^ after logits[:, :-1, :]
# # labels: [y0, y1, y2, y3, y4, y5, y6]
# # ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
# # loss_mask: [0, 0, 0, 1, 1, 1, 1]
# labels = labels[:, -num_logits_to_keep:]
# loss_mask = loss_mask[:, -num_logits_to_keep:]
# hidden_states = hidden_states[:, -num_logits_to_keep:, :]
if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]
hidden_states = hidden_states[:, -seq_len:]
# Compute the log probabilities of the labels
labels[
~loss_mask
] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(
logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)
output = {}
if self.use_weighting:
with torch.no_grad():
# Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
logprobs = F.log_softmax(logits, dim=-1)
weights_adjustment_factor = torch.logsumexp(
2 * logprobs, dim=-1
) # same as sum(probs**2) in log space
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = (per_token_logps_adjusted * loss_mask).sum(
-1
) / loss_mask.sum(-1)
chosen_weights = all_weights[:num_examples]
rejected_weights = all_weights[num_examples:]
output["policy_weights"] = torch.clamp(
torch.exp(chosen_weights + rejected_weights), max=1
)
if self.args.rpo_alpha is not None:
# Only use the chosen logits for the RPO loss
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]
# Compute the log probabilities of the labels
output["nll_loss"] = F.cross_entropy(
torch.flatten(chosen_logits, end_dim=1),
torch.flatten(chosen_labels, end_dim=1),
ignore_index=0,
)
if self.loss_type == "ipo":
all_logps = all_logps / loss_mask.sum(-1)
output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]
output["mean_chosen_logits"] = logits[:num_examples][
loss_mask[:num_examples]
].mean()
output["mean_rejected_logits"] = logits[num_examples:][
loss_mask[num_examples:]
].mean()
output["hidden_states"] = hidden_states
output["labels"] = labels
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss
return output
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1685,6 +1368,8 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
@@ -1692,6 +1377,7 @@ class TrainerBuilderBase(abc.ABC):
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
@@ -2480,14 +2166,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["report_to"] = report_to
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,

View File

@@ -204,87 +204,3 @@ def patch_forward_for_ga():
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -28,8 +28,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:return:
"""
max_length = self.prompter.max_length
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -41,16 +39,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -64,18 +52,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
)
rejected_tokenized = super().tokenize_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -104,9 +80,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": (
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
),
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {

View File

@@ -42,7 +42,6 @@ class ChatTemplatePrompter(Prompter):
"gpt": "assistant",
"system": "system",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
@@ -54,9 +53,21 @@ class ChatTemplatePrompter(Prompter):
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
conversation,
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
@@ -65,6 +76,8 @@ class ChatTemplatePrompter(Prompter):
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
@@ -75,7 +88,9 @@ class ChatTemplatePrompter(Prompter):
return batch
return self.tokenizer.apply_chat_template(
conversation,
turns,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
@@ -200,14 +215,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
self.images = "images"
@@ -254,28 +262,30 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = self.get_conversation_thread(prompt)
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get("role")
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
)
LOG.debug(f"Should train: {should_train}")
@@ -283,9 +293,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
conversation_ids=input_ids, turn=index, turn_content=turn
)
if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -306,9 +313,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(f"Labels after processing turn {index}: {labels}")
@@ -346,73 +351,52 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
def find_turn(self, conversation_ids, turn, turn_content):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get("content")
content = turn_content.get(self.prompter.message_field_content, "")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
if not content_ids:
LOG.warning(f"Empty content for turn {turn}")
return -1, -1
# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# For first turn, start from beginning
if turn == 0:
start_search_idx = 0
# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else:
# For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
end_idx = -1
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break
# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1
if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1
# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)
return -1, -1
return start_idx, end_idx
def get_conversation_thread(self, prompt):
turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]
if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return turns
return prompt[self.messages]
def get_images(self, prompt):
return prompt.get(self.images, None)

View File

@@ -259,7 +259,14 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
try:
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
@@ -267,22 +274,16 @@ def train(
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
dataset_tags = [
model_card_kwarg["dataset_name"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
dataset_tags = [
model_card_kwarg["dataset_tags"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -66,7 +66,10 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
control.should_evaluate = True
return control

View File

@@ -1475,27 +1475,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):
if data.get("rl") == "kto":
if data.get("sample_packing") or data.get("eval_sample_packing"):
raise ValueError("sample_packing is not supported with kto")
if data.get("remove_unused_columns") is not False:
raise ValueError("Set `remove_unused_columns: False` when using kto")
if data.get("gradient_checkpointing") and not (
data.get("gradient_checkpointing_kwargs")
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
):
raise ValueError(
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -386,12 +386,6 @@ class ModelLoader:
)
patch_training_loop_for_fsdp()
elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
from axolotl.monkeypatch.trainer_grad_accum import (
patch_training_loop_for_deepspeed_0_16_x,
)
patch_training_loop_for_deepspeed_0_16_x()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

View File

@@ -1,104 +0,0 @@
"""
dynamic requirements for axolotl
"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools.command.build_py import build_py as _build_py
# pylint: disable=duplicate-code
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
class BuildPyCommand(_build_py):
"""
custom build_py command to parse dynamic requirements
"""
def finalize_options(self):
super().finalize_options()
install_requires, _ = parse_requirements()
self.distribution.install_requires = install_requires

View File

@@ -1,10 +0,0 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli
def test_print_version(cli_runner):
"""Test that version is printed when --version is used."""
result = cli_runner.invoke(cli, ["--version"])
assert result.exit_code == 0
assert "axolotl, version " in result.output

View File

@@ -120,15 +120,9 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
@@ -137,8 +131,6 @@ def cleanup_monkeypatches():
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
@@ -146,25 +138,16 @@ def cleanup_monkeypatches():
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
),
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
spec = importlib.util.spec_from_file_location(
module_name, sys.modules[module_name].__file__
)
sys.modules[module_name] = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sys.modules[module_name])
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:

View File

@@ -71,11 +71,7 @@ class TestCutCrossEntropyIntegration:
@pytest.mark.parametrize(
"attention_type",
[
"flash_attention",
"sdp_attention",
# "xformers_attention",
],
["flash_attention", "sdp_attention", "xformers_attention"],
)
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault(

View File

@@ -9,7 +9,6 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
@@ -54,7 +53,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -62,7 +61,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -85,13 +83,9 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
[1, 4],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -118,15 +112,14 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -149,10 +142,6 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -191,7 +180,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -200,7 +189,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -223,10 +211,6 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -265,8 +249,8 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
@@ -274,7 +258,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -297,13 +280,9 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
[1, 4],
)
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -322,8 +301,8 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"max_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -344,7 +323,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -367,10 +345,6 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
@@ -394,7 +368,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -416,7 +390,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": fsdp_state_dict_type,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -439,10 +412,6 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -475,7 +444,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -497,7 +466,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -520,41 +488,12 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
[1, 4],
)
@pytest.mark.parametrize(
"deepspeed",
[
"deepspeed_configs/zero3_bf16.json",
"deepspeed_configs/zero3_bf16_cpuoffload_all.json",
# "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero3_packed(
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
):
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -572,17 +511,15 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
**adapter,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
@@ -605,35 +542,19 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
if qlora:
adapter = {
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
@@ -647,17 +568,15 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
**adapter,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
@@ -679,82 +598,3 @@ class TestMultiGPULlama:
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
**adapter,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)

View File

@@ -4,6 +4,7 @@ E2E tests for lora llama
import logging
import os
from importlib import reload
from pathlib import Path
import pytest
@@ -21,6 +22,14 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.fixture(autouse=True)
def reload_transformers():
import transformers.models.llama.modeling_llama
yield
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama:
"""
Test case for Llama models using LoRA w multipack

View File

@@ -7,7 +7,6 @@ import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
@@ -22,7 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("FIXME, mostly underused functionality")
class TestFusedLlama(unittest.TestCase):
"""
Test case for Llama models using Fused layers