Compare commits

..

33 Commits

Author SHA1 Message Date
Wing Lian
105c65390e add q-galore optimizer 2024-07-14 19:28:13 -04:00
Wing Lian
78e12f8ca5 add basic support for the optimi adamw optimizer (#1727)
* add support for optimi_adamw optimizer w kahan summation

* pydantic validator for optimi_adamw

* workaround for setting optimizer for fsdp

* make sure to install optimizer packages

* make sure to have parity for model parameters passed to optimizer

* add smoke test for optimi_adamw optimizer

* don't use foreach optimi by default
2024-07-14 19:12:57 -04:00
Wing Lian
98af5388ba bump flash attention 2.5.8 -> 2.6.1 (#1738)
* bump flash attention 2.5.8 -> 2.6.1

* use triton implementation of cross entropy from flash attn

* add smoke test for flash attn cross entropy patch

* fix args to xentropy.apply

* handle tuple from triton loss fn

* ensure the patch tests run independently

* use the wrapper already built into flash attn for cross entropy

* mark pytest as forked for patches

* use pytest xdist instead of forked, since cuda doesn't like forking

* limit to 1 process and use dist loadfile for pytest

* change up pytest for fixture to reload transformers w monkeypathc
2024-07-14 19:11:31 -04:00
RodriMora
219cd0d3c5 Fix eval_sample_packing in llama-3 lora example (#1716) [skip ci]
* Fix eval_sample_packing in llama-3 lora example

* Update examples/llama-3/lora-8b.yml

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-07-13 14:34:44 -04:00
David Meikle
634f384e06 Changed URL for dataset docs (#1744) 2024-07-13 14:34:28 -04:00
Akshaya Shanbhogue
4512738a73 bump xformers to 0.0.27 (#1740)
* Update requirements.txt

Preserve compatibility with torch 2.3.1. [Reference](https://github.com/facebookresearch/xformers/issues/1052)

* fix setup.py to extract the current xformers dep from requirements for replacement

* xformers 0.0.27 wheels not built for torch 2.3.0

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-07-13 14:04:31 -04:00
Wing Lian
1e57b4c562 update to pytorch 2.3.1 (#1746) [skip ci] 2024-07-13 13:28:17 -04:00
Wing Lian
a4a5bf057f fixes to prevent vram spike when train starts (#1742) 2024-07-13 09:53:13 -04:00
Wing Lian
137d84d1b4 add torch 2.3.1 base image (#1745) 2024-07-13 09:41:51 -04:00
Oliver Klingefjord
18abdb447a typo (#1685) [skip ci]
* typo

* typo 2

---------

Co-authored-by: mhenrichsen <mads.gade.henrichsen@live.dk>
2024-07-12 21:24:01 -04:00
Wing Lian
47e1916484 add tests so CI can catch updates where patches will break with unsloth (#1737) [skip ci] 2024-07-11 16:43:19 -04:00
mhenrichsen
1194c2e0b1 github urls (#1734)
Co-authored-by: Henrichsen, Mads (ext) <mads.henrichsen.ext@siemens-energy.com>
2024-07-11 09:19:29 -04:00
Wing Lian
a159724e44 bump trl and accelerate for latest releases (#1730)
* bump trl and accelerate for latest releases

* ensure that the CI runs on new gh org

* drop kto_pair support since removed upstream
2024-07-10 11:15:44 -04:00
Josh Bleecher Snyder
b3f680d305 sanity check ranges in freeze.py (#1686)
* sanity check ranges in freeze.py

this will catch problems earlier and more clearly.

in my case, it appears that deepspeed zero3 sets layer tensor shapes
to [0], which doesn't play well with automatically inferred ranges.
through a bit of luck, inverting ranges still appears to work correctly.

* simplify chained comparison
2024-07-05 09:24:07 -04:00
Wing Lian
c69b7eb2b5 full weights fsdp training seems broken with fsdp_cpu_ram_efficient_loading, disabling for now (#1726) 2024-07-05 09:15:36 -04:00
Wing Lian
c6d83a87c4 add support for .env files for env vars (#1724) 2024-07-02 13:17:40 -04:00
Wing Lian
5370cedf0c support for gemma2 w sample packing (#1718) 2024-06-29 01:38:55 -04:00
Josh Bleecher Snyder
f2480a1d91 improve Pre-Tokenized Dataset docs (#1684) [skip ci]
Fixes #1661
2024-06-26 13:13:21 -07:00
DavidFarago
559562d790 Allow "weight: 0" in messages to mask them (#1703)
Allow in message objects the additional key `weight`, which can be set
to 0 (or 1) to cause that message to be masked out (or left unmasked)
for training (similar to [1]). This is helpful for training the model to be robust and
capable of error recovery upon a bad assistant message.
A missing `weight` key defaults to weight 1, to guarantee downward compatibility.

[1]: https://github.com/mistralai/mistral-finetune
2024-06-20 10:05:16 -04:00
Wing Lian
4de4b4089f add support for multipack for deepseek_v2 (#1712) 2024-06-20 10:02:55 -04:00
Wing Lian
3f1f5e3312 drop length column for issues with eval without packing (#1711) 2024-06-18 23:32:29 -04:00
Wing Lian
5783839c6e download model weights on preprocess step (#1693) 2024-06-09 20:10:17 -04:00
Wing Lian
cbbf039a46 verbose failure message (#1694) 2024-06-09 20:09:36 -04:00
Wing Lian
851ccb1237 bump deepspeed for fix for grad norm compute putting tensors on different devices (#1699) 2024-06-09 17:13:28 -04:00
Wing Lian
18cabc0c46 fix for when sample_packing and eval_sample_packing are different (#1695) 2024-06-08 09:48:30 -04:00
Wing Lian
ed8ef65371 add back packing efficiency estimate so epochs and multi-gpu works properly (#1697) 2024-06-08 09:48:10 -04:00
Wing Lian
00ac3022a1 add qwen2-72b fsdp example (#1696) 2024-06-07 16:38:29 -04:00
Wing Lian
9c1af1a9c0 ensure explicit eval_sample_packing to avoid mismatch issues (#1692) 2024-06-07 11:28:43 -04:00
Aaditya Ura (looking for PhD Fall’24)
a82a711522 Create phi3-ft-fsdp.yml (#1580)
rename to be fsdp specific and tweak settings a bit
2024-06-04 16:20:25 -04:00
Brian Fitzgerald
cf64284a04 Phi-3 conversation format, example training script and perplexity metric (#1582)
* phi-3 support and perplexity metric

* phi-3 chat template

* metrics updates

* chore: lint

* fix assertion on Tensor

* fix tests since tokenization happens in the metric

* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-06-04 16:11:56 -04:00
Wing Lian
c996881ec2 add support for rpo_alpha (#1681)
* add support for rpo_alpha

* Add smoke test for dpo + nll loss
2024-06-04 16:09:51 -04:00
Wing Lian
1f151c0d52 re-enable DPO for tests in modal ci (#1374)
* re-enable DPO for tests in modal ci

* workaround for training args

* don't mixin AxolotlTrainingArguments

* fix mixin order so MRO doesn't result in

 TypeError: non-default argument follows default argument error

* use smaller datasets for dpo tests
2024-06-03 12:50:44 -04:00
Saeed Esmaili
5cde06587a Fix the broken link in README (#1678) [skip ci] 2024-06-03 09:38:44 -04:00
66 changed files with 1281 additions and 174 deletions

View File

@@ -21,12 +21,12 @@ All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT
## Getting Started
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
Bugs? Please check for open issue else create a new [Issue](https://github.com/axolotl-ai-cloud/axolotl/issues/new).
PRs are **greatly welcome**!
1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/axolotl-ai-cloud/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run below to setup env
@@ -42,11 +42,11 @@ pytest tests/
### Reporting Bugs
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
### Suggesting Enhancements
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
### Submitting Pull Requests

View File

@@ -15,7 +15,7 @@ body:
label: "Please check that this issue hasn't been reported before."
description: "The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
- label: "I searched previous [Bug Reports](https://github.com/axolotl-ai-cloud/axolotl/labels/bug) didn't find any similar reports."
required: true
- type: textarea

View File

@@ -1,7 +1,7 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
url: https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/q-a
about: Ask questions and discuss with other community members
- name: Discuss the Project in Discord
url: https://discord.gg/HhrNrHJPRb

View File

@@ -10,7 +10,7 @@ body:
value: |
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/axolotl-ai-cloud/axolotl/issues).
- type: textarea
attributes:
label: What piece of documentation is affected?

View File

@@ -8,9 +8,9 @@ body:
label: "⚠️ Please check that this feature request hasn't been suggested before."
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
- label: "I searched previous [Ideas in Discussions](https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
required: true
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
- label: "I searched previous [Issues](https://github.com/axolotl-ai-cloud/axolotl/labels/enhancement) didn't find any similar feature requests."
required: true
- type: textarea

View File

@@ -5,7 +5,7 @@ on:
jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
strategy:
@@ -37,6 +37,11 @@ jobs:
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3

View File

@@ -8,7 +8,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -19,7 +19,6 @@ jobs:
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -33,8 +32,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -70,7 +70,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -128,7 +128,7 @@ jobs:
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -136,7 +136,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -7,7 +7,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -18,7 +18,6 @@ jobs:
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -32,8 +31,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -70,7 +70,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -58,7 +58,7 @@ jobs:
pytest --ignore=tests/e2e/ tests/
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
@@ -87,7 +87,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
pytorch: 2.3.1
num_gpus: 1
steps:
- name: Checkout

6
.gitignore vendored
View File

@@ -176,3 +176,9 @@ qlora-out/*
mlruns/*
/.quarto/
prepared-datasets/
submit.sh
*.out*
typings/
out/

View File

@@ -67,8 +67,8 @@ Features:
<p>
Go ahead and Axolotl questions!!
</p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>
@@ -107,7 +107,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip3 install packaging ninja
@@ -132,7 +132,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
## Advanced Setup
@@ -333,7 +333,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config
@@ -609,7 +609,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
## Debugging Axolotl
@@ -626,10 +626,10 @@ Need dedicated support? Please contact us at [✉wing@openaccessaicollective.
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Community Showcase
@@ -647,7 +647,7 @@ PocketDoc Labs
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
@@ -665,7 +665,7 @@ pre-commit run --all-files
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>

View File

@@ -14,7 +14,7 @@ website:
- icon: twitter
href: https://twitter.com/axolotl_ai
- icon: github
href: https://github.com/OpenAccess-AI-Collective/axolotl/
href: https://github.com/axolotl-ai-cloud/axolotl/
- icon: discord
href: https://discord.gg/7m9sfhzaf3

View File

@@ -14,7 +14,7 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -15,16 +15,16 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -16,7 +16,7 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair'
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing

View File

@@ -4,9 +4,25 @@ description: How to use a custom pre-tokenized dataset.
order: 5
---
- Do not pass a `type:` in your axolotl config.
- Pass an empty `type:` in your axolotl config.
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
- For pretraining, do not truncate/pad documents to the context window length.
- For instruction training, documents must be truncated/padded as desired.
Sample config:
```{.yaml filename="config.yml"}
- path: ...
datasets:
- path: /path/to/your/file.jsonl
ds_type: json
type:
```
Sample jsonl:
```jsonl
{"input_ids":[271,299,99],"attention_mask":[1,1,1],"labels":[271,-100,99]}
{"input_ids":[87,227,8383,12],"attention_mask":[1,1,1,1],"labels":[87,227,8383,12]}
```

View File

@@ -192,7 +192,7 @@ Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl
On the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
```

View File

@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
@@ -29,7 +29,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [PR #1378](https://github.com/axolotl-ai-cloud/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FDSP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )

View File

@@ -25,7 +25,7 @@ description: "Template-free prompt construction with the `input_output` format"
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
setting the following configuration value:
@@ -33,7 +33,7 @@ setting the following configuration value:
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.

View File

@@ -44,7 +44,7 @@
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
]
@@ -171,7 +171,7 @@
},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
@@ -188,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]

68
examples/gemma2/qlora.yml Normal file
View File

@@ -0,0 +1,68 @@
base_model: google/gemma-2-9b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora

View File

@@ -0,0 +1,83 @@
base_model: microsoft/Phi-3-mini-4k-instruct
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./phi-sft-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
trust_remote_code: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project: phi3
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 12
num_epochs: 2
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Phi3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"

64
examples/phi/phi3-ft.yml Normal file
View File

@@ -0,0 +1,64 @@
base_model: microsoft/Phi-3-mini-4k-instruct
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
chat_template: phi_3
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca:phi
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 5.0e-6
train_on_inputs: false
group_by_length: false
bf16: auto
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: True
early_stopping_patience: 3
logging_steps: 1
flash_attention: true
eval_steps: 1000
save_steps: 5000
eval_table_size: 2
eval_batch_size: 2
eval_sample_packing: false
eval_max_new_tokens: 32
eval_causal_lm_metrics: ["perplexity"]
do_causal_lm_eval: true
warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true

View File

@@ -0,0 +1,75 @@
base_model: Qwen/Qwen2-7B
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:

View File

@@ -1 +1,2 @@
pytest
pytest-xdist

View File

@@ -1,22 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.41.1
transformers==4.42.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.30.1
deepspeed==0.14.2
accelerate==0.32.0
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
datasets==2.19.1
flash-attn==2.5.8
flash-attn==2.6.1
sentencepiece
wandb
einops
xformers==0.0.26.post1
xformers==0.0.27
optimum==1.16.2
hf_transfer
colorama
@@ -31,6 +31,7 @@ art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
mamba-ssm==1.2.0.post1
@@ -39,6 +40,6 @@ s3fs
gcsfs
# adlfs
trl==0.8.6
trl==0.9.6
zstandard==0.22.0
fastcore

View File

@@ -11,7 +11,7 @@ Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace a
```
cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip install --no-deps -e .
```

View File

@@ -29,9 +29,10 @@ def parse_requirements():
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -49,12 +50,14 @@ def parse_requirements():
raise ValueError("Invalid version format")
if (major, minor) >= (2, 3):
pass
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
@@ -77,13 +80,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.5.8",
"flash-attn==2.6.1",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed==0.14.2",
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -101,5 +104,12 @@ setup(
"galore": [
"galore_torch",
],
"optimizers": [
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"q-galore-torch==1.0",
"torch-optimi==0.2.1",
],
},
)

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import (
do_inference,
@@ -33,4 +34,5 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
@@ -48,4 +49,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -7,7 +7,10 @@ from typing import Union
import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli import (
check_accelerate_default_config,
@@ -71,6 +74,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
@@ -79,4 +87,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -7,6 +7,7 @@ from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
@@ -40,4 +41,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from typing import Tuple, Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -67,4 +68,5 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -40,6 +40,7 @@ class PreprocessCliArgs:
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
def load_model_and_tokenizer(

View File

@@ -30,7 +30,7 @@ from transformers import (
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
@dataclass
@@ -238,6 +244,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
@@ -278,25 +291,66 @@ class AxolotlTrainer(Trainer):
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "q_galore_adamw8bit"]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "q_galore_adamw8bit":
from q_galore_torch import QGaLoreAdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
QGaLoreAdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -380,6 +434,7 @@ class AxolotlTrainer(Trainer):
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
@@ -405,6 +460,7 @@ class AxolotlTrainer(Trainer):
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
@@ -450,6 +506,8 @@ class AxolotlTrainer(Trainer):
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
@@ -1080,6 +1138,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)
if warmup_steps == 1:
warmup_steps = 2
logging_steps = (
self.cfg.logging_steps
@@ -1383,6 +1443,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer in ["optimi_adamw", "q_galore_adamw8bit"]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
@@ -1608,7 +1673,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = AxolotlTrainingArguments
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
@@ -1655,8 +1722,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
@@ -1665,7 +1730,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
@@ -1680,7 +1745,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl == "kto":
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
else:

View File

View File

@@ -104,17 +104,12 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
# skip only if explicitly disabled
if rms_norm:
@@ -130,7 +125,7 @@ def replace_llama_attn_with_flash_attn(
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
@@ -826,7 +821,6 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

View File

@@ -145,7 +145,7 @@ def flashattn_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -422,6 +422,9 @@ def mistral_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -16,8 +16,10 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"falcon",
"phi",
"gemma",
"gemma2",
"gemmoe",
"starcoder2",
"deepseek_v2",
]
@@ -48,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -56,6 +62,8 @@ def patch_for_multipack(model_type, model_name=None):
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
def patch_remote(model_name, config_name, modeling_name):

View File

@@ -80,8 +80,9 @@ def get_forward_code() -> str:
return forward
def test_cel_is_patchable() -> bool:
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
return forward
def test_self_attn_is_patchable() -> bool:
def check_self_attn_is_patchable() -> bool:
qkv = get_self_attn_code()
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
qkv, _ = detab_code(qkv)
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch():

View File

@@ -2,9 +2,12 @@
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg):
try:
@@ -22,5 +25,8 @@ def load(strategy, tokenizer, cfg, ds_cfg):
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None

View File

@@ -23,6 +23,7 @@ class ChatTemplatePrompter(Prompter):
message_field_role: str = "from",
message_field_content: str = "value",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -39,6 +40,7 @@ class ChatTemplatePrompter(Prompter):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
@@ -49,6 +51,9 @@ class ChatTemplatePrompter(Prompter):
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -111,6 +116,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
drop_system_message = (
ds_cfg["drop_system_message"]
if ds_cfg and "drop_system_message" in ds_cfg
else False
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
@@ -119,6 +129,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
drop_system_message=drop_system_message,
),
tokenizer,
cfg.train_on_inputs,

View File

@@ -56,7 +56,9 @@ class ORPODatasetParsingStrategy:
messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
messages.append(Message(role="user", content=prompt["prompt"], label=False))
messages.append(
Message(role="user", content=prompt["chosen"][0]["content"], label=False)
)
messages.append(
Message(
role="assistant", content=prompt["chosen"][1]["content"], label=True
@@ -70,7 +72,9 @@ class ORPODatasetParsingStrategy:
messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
messages.append(Message(role="user", content=prompt["prompt"], label=False))
messages.append(
Message(role="user", content=prompt["rejected"][0]["content"], label=False)
)
messages.append(
Message(
role="assistant", content=prompt["rejected"][1]["content"], label=True
@@ -152,8 +156,8 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
# pass the rejected prompt/row to the Prompter to get the formatted prompt
prompt_len = 0
rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
prompt
rejected_message_list: MessageList = (
self.dataset_parser.get_rejected_conversation_thread(prompt)
)
input_ids = []
labels = []
@@ -174,7 +178,9 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
rejected_input_ids = input_ids
rejected_labels = labels
# pass the chosen prompt/row to the Prompter to get the formatted prompt
chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
chosen_message_list: MessageList = (
self.dataset_parser.get_chosen_conversation_thread(prompt)
)
input_ids = []
labels = []
for _, (part, label) in enumerate(

View File

@@ -143,6 +143,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
"weight": 1
if "weight" not in t or t["weight"] is None
else t["weight"],
}
for t in conversations
]

View File

@@ -377,7 +377,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
role, content = part
if len(part) <= 2:
role, content = part
weight = 1
else:
role, content, weight = part
# Uses "in" because role contains extra characters
input_turn = any(r.lower() in role.lower() for r in input_roles)
@@ -403,7 +407,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
@@ -439,13 +443,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
if weight == 0:
# everything from this is masked out from the labels
# (role is masked out too because it makes no sense if contents is masked out)
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels

View File

@@ -20,6 +20,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
PHI = "phi"
class Prompter:
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
prompt_style: Optional[str] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()
@@ -52,16 +53,20 @@ class AlpacaPrompter(Prompter):
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
elif self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
elif self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
self.system_format = "<|system|>{system}\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
@@ -314,6 +319,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conv = self._conversation.copy()
original_source = source.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
@@ -355,8 +361,27 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
return conv.get_turns()
turns = list(conv.get_turns())
original_source_length = len(original_source)
assert len(turns) in [
original_source_length - 1,
original_source_length,
original_source_length + 1,
]
if len(turns) == original_source_length + 1:
original_source = [{"weight": None}] + original_source
elif len(turns) == original_source_length - 1:
original_source = original_source[1:]
return [
(*turn, weight)
for turn, weight in zip(
turns,
[
1 if "weight" not in e or e["weight"] is None else e["weight"]
for e in original_source
],
)
]
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
@@ -381,12 +406,14 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
role_key_tool=role_key_tool,
roles=roles,
)

View File

@@ -144,7 +144,7 @@ def train(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import math
import os
import traceback
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List
@@ -30,6 +31,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
@@ -374,10 +376,14 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
else:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics
def on_evaluate(
@@ -421,13 +427,20 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
# Only pass the kwargs that are in the metric's feature list
metric_kwargs = {
k: kwargs[k]
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}
metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
traceback.print_exc()
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
@@ -443,11 +456,12 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
if score is None:
score = compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores

View File

@@ -0,0 +1,76 @@
"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional
import torch
from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
class Perplexity:
"""
Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
This is a custom variant that doesn't re-tokenize the input or re-load the model.
"""
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"
def _feature_names(self) -> List[str]:
return ["references"]
def compute(
self,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
Compute perplexity in a fixed length sliding window across the sequence.
"""
assert references is not None, "Missing parameter: references"
references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(self.device)
sequence_length = input_ids.size(1)
losses = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
labels_slice = input_ids_slice.clone()
labels_slice[:, :-trg_len] = -100
with torch.no_grad():
outputs: CausalLMOutput = self.model(
input_ids=input_ids_slice, labels=labels_slice
)
losses.append(outputs.loss)
prev_end_loc = end_loc
if end_loc == sequence_length:
break
perplexity = torch.exp(torch.stack(losses).mean()).item()
return {
"score": perplexity,
}

View File

@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
}
if user_choice in templates:

View File

@@ -10,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
SUPPORTED_METRICS,
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
)
if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
# TODO

View File

@@ -17,6 +17,8 @@ from axolotl.utils.config.models.internals import GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -114,6 +116,7 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
class UserDefinedDPOType(BaseModel):
@@ -162,7 +165,6 @@ class RLType(str, Enum):
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -219,8 +222,6 @@ class LoraConfig(BaseModel):
peft_layers_to_transform: Optional[List[int]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_mora: Optional[bool] = None
peft_mora_type: Optional[int] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
@@ -340,7 +341,10 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch"]]
Union[
OptimizerNames,
Literal["lion_pytorch", "optimi_adamw", "q_galore_adamw8bit"],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -621,6 +625,7 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None
kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
@@ -897,6 +902,26 @@ class AxolotlInputConfig(
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if (
data.get("sample_packing")
and data.get("eval_sample_packing") is None
and not data.get("eval_table_size")
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True
if (
data.get("sample_packing")
and data.get("eval_sample_packing") is False
and data.get("remove_unused_columns") is None
):
LOG.info(
"setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match"
)
data["remove_unused_columns"] = False
return data
@model_validator(mode="before")
@@ -1074,13 +1099,12 @@ class AxolotlInputConfig(
)
if data.get("eval_causal_lm_metrics"):
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
return data

View File

@@ -474,12 +474,16 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)
if split == "train" and cfg.val_set_size:
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)
if split == "train" and val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "train"
+ "|"
@@ -488,7 +492,7 @@ def load_prepare_datasets(
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "test"
+ "|"
@@ -498,9 +502,7 @@ def load_prepare_datasets(
test_fingerprint = md5(to_hash_test)
dataset = dataset.train_test_split(
test_size=int(cfg.val_set_size)
if cfg.val_set_size == int(cfg.val_set_size)
else cfg.val_set_size,
test_size=val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}
LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features

View File

@@ -120,6 +120,9 @@ def _merge_ranges(
processed_ranges = [
(start, end if end is not None else layer_size) for start, end in given_ranges
]
for start, end in processed_ranges:
if start < 0 or end > layer_size > 0 or start >= end:
raise ValueError(f"invalid unfreeze range: start={start}, end={end}")
# No need to merge if there's only one or no ranges
if len(processed_ranges) <= 1:

View File

@@ -371,6 +371,12 @@ def load_model(
rms_norm=cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -569,9 +575,11 @@ def load_model(
try:
skip_move_to_device = False
if (
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
) and not qlora_fsdp:
if ( # pylint: disable=condition-evals-to-constant)
(cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
and not qlora_fsdp
and False
):
model = load_sharded_model(
base_model,
model_config,
@@ -597,9 +605,12 @@ def load_model(
and not cfg.trust_remote_code
and not cfg.gptq
):
from transformers import LlamaForCausalLM
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = LlamaForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
@@ -632,7 +643,11 @@ def load_model(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
elif (
model_type
and model_type != "AutoModelForCausalLM"
and not cfg.trust_remote_code
):
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
@@ -673,6 +688,7 @@ def load_model(
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
@@ -803,11 +819,7 @@ def load_model(
if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
and not cfg.merge_lora
):
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora:
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
@@ -953,8 +965,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if cfg.lora_alpha:
lora_config_kwargs["lora_alpha"] = cfg.lora_alpha
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
@@ -962,14 +972,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_use_mora and cfg.peft_mora_type is not None:
lora_config_kwargs["use_mora"] = cfg.peft_use_mora
lora_config_kwargs["mora_type"] = cfg.peft_mora_type
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
lora_dropout=cfg.lora_dropout,

View File

@@ -427,7 +427,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -0,0 +1,87 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from importlib import reload
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.fixture(autouse=True)
def reload_transformers():
import transformers.models.llama.modeling_llama
yield
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA w multipack
"""
@with_temp_dir
def test_lora_packing_fa_cross_entropy(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"flash_attn_cross_entropy": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="FIXME?")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
"""
Test case for Llama models using S2 Attn

View File

@@ -0,0 +1,25 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.unsloth_ import (
check_cel_is_patchable,
check_self_attn_is_patchable,
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_self_attn_is_patchable(),
"HF transformers self attention code has changed and isn't patchable",
)

View File

@@ -21,7 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="doesn't seem to work on modal")
class TestDPOLlamaLora(unittest.TestCase):
"""
Test case for DPO Llama models using LoRA
@@ -45,8 +44,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"rl": "dpo",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
@@ -71,6 +70,52 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -89,8 +134,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"rl": "kto_pair",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
@@ -133,8 +178,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"rl": "ipo",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"type": "chatml.intel",
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
@@ -180,7 +225,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"path": "argilla/distilabel-capybara-dpo-7k-binarized",
"type": "chat_template.argilla",
"split": "train",
},
@@ -206,6 +251,7 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip(reason="Fix the implementation")
@with_temp_dir
def test_kto_lora(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
"type": "alpaca",
},
],
"num_epochs": 2,
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,

View File

@@ -0,0 +1,109 @@
"""
E2E tests for custom optimizers using Llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomOptimizers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_optimi_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_q_galore_adamw8bit(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "q_galore_adamw8bit",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -52,6 +52,51 @@ def fixture_sharegpt_dataset():
)
@pytest.fixture(name="sharegpt_dataset_with_weights")
def fixture_sharegpt_dataset_with_weights():
return Dataset.from_list(
[
{
"conversations": [
{
"from": "system",
"value": "repeat",
},
{
"from": "human",
"value": "hello",
"weight": 1,
},
{
"from": "gpt",
"value": "hello",
"weight": 0,
},
{
"from": "human",
"value": "rehello",
"weight": 0,
},
{
"from": "gpt",
"value": "rehello",
"weight": 1,
},
{
"from": "human",
"value": "goodbye",
},
{
"from": "gpt",
"value": "goodbye",
"weight": 0,
},
]
}
]
)
@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
return Dataset.from_list(
@@ -162,6 +207,46 @@ class TestSharegptLlama3:
]
# fmt: on
def test_tokenization_with_weights(
self, sharegpt_dataset_with_weights, llama3_tokenizer
):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="llama3",
role_key_model=None,
role_key_human=None,
),
llama3_tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset_with_weights, process_count=1
)
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header
271, 31724, 128009, # sys prompt, eot
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 11310, 4896, 128009,
128006, 78191, 128007,
271, 11310, 4896, 128009,
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
class TestSharegptChatML:
"""
@@ -197,7 +282,40 @@ class TestSharegptChatML:
]
# fmt: on
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
def test_no_double_im_end_with_weights(
self, sharegpt_dataset_with_weights, tokenizer
):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset_with_weights, process_count=1
)
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
# 28705, 13, is " \n"
1, # bos
32001, 1587, 13, 25997, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 32000, 28705, 13, # human
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
]
# fmt: on
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
@@ -225,7 +343,39 @@ class TestSharegptChatML:
]
# fmt: on
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
def test_no_train_on_input_with_weights(
self, sharegpt_dataset_with_weights, tokenizer
):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset_with_weights, process_count=1
)
labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, -100, -100, # human
-100, -100, -100, -100, -100, -100, -100, # gpt with weight zero
-100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 267, 21558, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight zero
]
# fmt: on
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
@@ -253,6 +403,38 @@ class TestSharegptChatML:
]
# fmt: on
def test_w_train_on_input_with_weights(
self, sharegpt_dataset_with_weights, tokenizer
):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
True, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset_with_weights, process_count=1
)
labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
1, # bos
32001, 1587, 13, 25997, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 32000, 28705, 13, # human
-100, -100, -100, -100, -100, -100, -100, # gpt with weight 0
-100, -100, -100, -100, -100, -100, -100, -100, # human with weight 0
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight 0
]
# fmt: on
def test_chatml_glaive(self, glaive_dataset, tokenizer):
strategy = GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(

41
tests/test_perplexity.py Normal file
View File

@@ -0,0 +1,41 @@
"""unit tests for perplexity eval callback"""
# pylint: disable=redefined-outer-name
from pytest import fixture
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from axolotl.utils.callbacks.perplexity import Perplexity
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@fixture()
def metric(tokenizer):
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
return Perplexity(model, tokenizer, 512)
@fixture()
def tokenizer():
return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
def test_perplexity_longer_than_stride(metric):
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
sample_text = """
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
"""
result = metric.compute([sample_text])
ppl = result["score"]
assert round(ppl, 2) == 5.37
def test_perplexity_short(metric):
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
result = metric.compute([sample_text])
ppl = result["score"]
assert round(ppl, 2) == 10.02