Compare commits
21 Commits
llama-mult
...
dpo-spawn-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e86dd76154 | ||
|
|
5f58555bd0 | ||
|
|
cfc533a7f7 | ||
|
|
e1725aef2b | ||
|
|
78e12f8ca5 | ||
|
|
98af5388ba | ||
|
|
219cd0d3c5 | ||
|
|
634f384e06 | ||
|
|
4512738a73 | ||
|
|
1e57b4c562 | ||
|
|
a4a5bf057f | ||
|
|
137d84d1b4 | ||
|
|
18abdb447a | ||
|
|
47e1916484 | ||
|
|
1194c2e0b1 | ||
|
|
a159724e44 | ||
|
|
b3f680d305 | ||
|
|
c69b7eb2b5 | ||
|
|
c6d83a87c4 | ||
|
|
5370cedf0c | ||
|
|
f2480a1d91 |
8
.github/CONTRIBUTING.md
vendored
8
.github/CONTRIBUTING.md
vendored
@@ -21,12 +21,12 @@ All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT
|
|||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
Bugs? Please check for open issue else create a new [Issue](https://github.com/axolotl-ai-cloud/axolotl/issues/new).
|
||||||
|
|
||||||
PRs are **greatly welcome**!
|
PRs are **greatly welcome**!
|
||||||
|
|
||||||
1. Fork the repository and clone it to your local machine.
|
1. Fork the repository and clone it to your local machine.
|
||||||
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
|
2. Set up the development environment by following the instructions in the [README.md](https://github.com/axolotl-ai-cloud/axolotl/tree/main/README.md) file.
|
||||||
3. Explore the codebase, run tests, and verify that everything works as expected.
|
3. Explore the codebase, run tests, and verify that everything works as expected.
|
||||||
|
|
||||||
Please run below to setup env
|
Please run below to setup env
|
||||||
@@ -42,11 +42,11 @@ pytest tests/
|
|||||||
|
|
||||||
### Reporting Bugs
|
### Reporting Bugs
|
||||||
|
|
||||||
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
|
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
|
||||||
|
|
||||||
### Suggesting Enhancements
|
### Suggesting Enhancements
|
||||||
|
|
||||||
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
|
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
|
||||||
|
|
||||||
### Submitting Pull Requests
|
### Submitting Pull Requests
|
||||||
|
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -15,7 +15,7 @@ body:
|
|||||||
label: "Please check that this issue hasn't been reported before."
|
label: "Please check that this issue hasn't been reported before."
|
||||||
description: "The **Label filters** may help make your search more focussed."
|
description: "The **Label filters** may help make your search more focussed."
|
||||||
options:
|
options:
|
||||||
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
|
- label: "I searched previous [Bug Reports](https://github.com/axolotl-ai-cloud/axolotl/labels/bug) didn't find any similar reports."
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/config.yml
vendored
2
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,7 +1,7 @@
|
|||||||
blank_issues_enabled: false
|
blank_issues_enabled: false
|
||||||
contact_links:
|
contact_links:
|
||||||
- name: Ask a question
|
- name: Ask a question
|
||||||
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
|
url: https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/q-a
|
||||||
about: Ask questions and discuss with other community members
|
about: Ask questions and discuss with other community members
|
||||||
- name: Discuss the Project in Discord
|
- name: Discuss the Project in Discord
|
||||||
url: https://discord.gg/HhrNrHJPRb
|
url: https://discord.gg/HhrNrHJPRb
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/docs.yml
vendored
2
.github/ISSUE_TEMPLATE/docs.yml
vendored
@@ -10,7 +10,7 @@ body:
|
|||||||
value: |
|
value: |
|
||||||
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
|
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
|
||||||
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
|
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
|
||||||
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
|
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/axolotl-ai-cloud/axolotl/issues).
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: What piece of documentation is affected?
|
label: What piece of documentation is affected?
|
||||||
|
|||||||
4
.github/ISSUE_TEMPLATE/feature-request.yaml
vendored
4
.github/ISSUE_TEMPLATE/feature-request.yaml
vendored
@@ -8,9 +8,9 @@ body:
|
|||||||
label: "⚠️ Please check that this feature request hasn't been suggested before."
|
label: "⚠️ Please check that this feature request hasn't been suggested before."
|
||||||
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
|
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
|
||||||
options:
|
options:
|
||||||
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
|
- label: "I searched previous [Ideas in Discussions](https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
|
||||||
required: true
|
required: true
|
||||||
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
|
- label: "I searched previous [Issues](https://github.com/axolotl-ai-cloud/axolotl/labels/enhancement) didn't find any similar feature requests."
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
|
|||||||
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -5,7 +5,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
strategy:
|
strategy:
|
||||||
@@ -37,6 +37,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.1
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
16
.github/workflows/main.yml
vendored
16
.github/workflows/main.yml
vendored
@@ -8,7 +8,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -19,7 +19,6 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -33,8 +32,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud:
|
build-axolotl-cloud:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -80,7 +80,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -94,8 +93,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -128,7 +128,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud-no-tmux:
|
build-axolotl-cloud-no-tmux:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
12
.github/workflows/nightlies.yml
vendored
12
.github/workflows/nightlies.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -18,7 +18,6 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -32,8 +31,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud:
|
build-axolotl-cloud:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -80,7 +80,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -94,8 +93,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
@@ -57,8 +57,12 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pytest --ignore=tests/e2e/ tests/
|
pytest --ignore=tests/e2e/ tests/
|
||||||
|
|
||||||
|
- name: cleanup pip cache
|
||||||
|
run: |
|
||||||
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
@@ -87,7 +91,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -99,7 +103,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal jinja2
|
pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
18
README.md
18
README.md
@@ -67,8 +67,8 @@ Features:
|
|||||||
<p>
|
<p>
|
||||||
Go ahead and Axolotl questions!!
|
Go ahead and Axolotl questions!!
|
||||||
</p>
|
</p>
|
||||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging ninja
|
pip3 install packaging ninja
|
||||||
@@ -132,7 +132,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
|
|
||||||
# remote yaml files - the yaml config can be hosted on a public URL
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
# Note: the yaml config must directly link to the **raw** yaml
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced Setup
|
## Advanced Setup
|
||||||
@@ -333,7 +333,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
|
|||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|
||||||
See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
|
|
||||||
@@ -626,10 +626,10 @@ Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.
|
|||||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||||
|
|
||||||
```markdown
|
```markdown
|
||||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
```
|
```
|
||||||
|
|
||||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
|
||||||
## Community Showcase
|
## Community Showcase
|
||||||
|
|
||||||
@@ -647,7 +647,7 @@ PocketDoc Labs
|
|||||||
|
|
||||||
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
||||||
|
|
||||||
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.
|
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
|
||||||
|
|
||||||
PRs are **greatly welcome**!
|
PRs are **greatly welcome**!
|
||||||
|
|
||||||
@@ -665,7 +665,7 @@ pre-commit run --all-files
|
|||||||
|
|
||||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||||
|
|
||||||
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
|
||||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ website:
|
|||||||
- icon: twitter
|
- icon: twitter
|
||||||
href: https://twitter.com/axolotl_ai
|
href: https://twitter.com/axolotl_ai
|
||||||
- icon: github
|
- icon: github
|
||||||
href: https://github.com/OpenAccess-AI-Collective/axolotl/
|
href: https://github.com/axolotl-ai-cloud/axolotl/
|
||||||
- icon: discord
|
- icon: discord
|
||||||
href: https://discord.gg/7m9sfhzaf3
|
href: https://discord.gg/7m9sfhzaf3
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install pytest
|
RUN pip install -r requirements-tests.txt
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest /workspace/axolotl/tests/e2e/patched/
|
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
|
||||||
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -15,16 +15,16 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
# use RL training: 'dpo', 'ipo', 'kto'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
|||||||
@@ -4,9 +4,25 @@ description: How to use a custom pre-tokenized dataset.
|
|||||||
order: 5
|
order: 5
|
||||||
---
|
---
|
||||||
|
|
||||||
- Do not pass a `type:` in your axolotl config.
|
- Pass an empty `type:` in your axolotl config.
|
||||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||||
|
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
|
||||||
|
- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
|
||||||
|
- For pretraining, do not truncate/pad documents to the context window length.
|
||||||
|
- For instruction training, documents must be truncated/padded as desired.
|
||||||
|
|
||||||
|
Sample config:
|
||||||
|
|
||||||
```{.yaml filename="config.yml"}
|
```{.yaml filename="config.yml"}
|
||||||
- path: ...
|
datasets:
|
||||||
|
- path: /path/to/your/file.jsonl
|
||||||
|
ds_type: json
|
||||||
|
type:
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample jsonl:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{"input_ids":[271,299,99],"attention_mask":[1,1,1],"labels":[271,-100,99]}
|
||||||
|
{"input_ids":[87,227,8383,12],"attention_mask":[1,1,1,1],"labels":[87,227,8383,12]}
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl
|
|||||||
On the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:
|
On the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
1. Set `adapter: qlora` in your axolotl config file.
|
||||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
|
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||||
|
|
||||||
## Example Config
|
## Example Config
|
||||||
@@ -29,7 +29,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
- [PR #1378](https://github.com/axolotl-ai-cloud/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
||||||
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
||||||
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
||||||
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ description: "Template-free prompt construction with the `input_output` format"
|
|||||||
### Masking Inputs
|
### Masking Inputs
|
||||||
|
|
||||||
One of the most popular features of
|
One of the most popular features of
|
||||||
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
|
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
||||||
setting the following configuration value:
|
setting the following configuration value:
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ setting the following configuration value:
|
|||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
```
|
```
|
||||||
|
|
||||||
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
|
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
||||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||||
labels so that your model can focus on predicting the outputs only.
|
labels so that your model can focus on predicting the outputs only.
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install torch==\"2.1.2\"\n",
|
"!pip install torch==\"2.1.2\"\n",
|
||||||
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
"!pip install flash-attn==\"2.5.0\"\n",
|
||||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||||
]
|
]
|
||||||
@@ -171,7 +171,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
"# By using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -188,7 +188,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
"# By using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||||
]
|
]
|
||||||
|
|||||||
68
examples/gemma2/qlora.yml
Normal file
68
examples/gemma2/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: google/gemma-2-9b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: gemma
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
chat_template: gemma
|
||||||
|
drop_system_message: true
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
pytest
|
pytest
|
||||||
|
pytest-xdist
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.41.1
|
transformers==4.42.3
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.30.1
|
accelerate==0.32.0
|
||||||
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
@@ -12,11 +12,11 @@ fire
|
|||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.19.1
|
datasets==2.19.1
|
||||||
flash-attn==2.5.8
|
flash-attn==2.6.1
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.26.post1
|
xformers==0.0.27
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -31,6 +31,7 @@ art
|
|||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
@@ -39,6 +40,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
|
trl==0.9.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace a
|
|||||||
```
|
```
|
||||||
cd /workspace
|
cd /workspace
|
||||||
rm -rf /workspace/axolotl
|
rm -rf /workspace/axolotl
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip install --no-deps -e .
|
pip install --no-deps -e .
|
||||||
```
|
```
|
||||||
|
|||||||
21
setup.py
21
setup.py
@@ -29,9 +29,10 @@ def parse_requirements():
|
|||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
# don't install xformers on MacOS
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
# detect the version of torch already installed
|
||||||
# and set it so dependencies don't clobber the torch version
|
# and set it so dependencies don't clobber the torch version
|
||||||
@@ -49,12 +50,14 @@ def parse_requirements():
|
|||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 3):
|
||||||
pass
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
else:
|
else:
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
@@ -77,10 +80,10 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.5.8",
|
"flash-attn==2.6.1",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
||||||
@@ -101,5 +104,11 @@ setup(
|
|||||||
"galore": [
|
"galore": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
],
|
],
|
||||||
|
"optimizers": [
|
||||||
|
"galore_torch",
|
||||||
|
"lion-pytorch==0.1.2",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
do_inference,
|
do_inference,
|
||||||
@@ -33,4 +34,5 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
@@ -48,4 +49,5 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import fire
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
from dotenv import load_dotenv
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
@@ -86,4 +87,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Union
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
@@ -40,4 +41,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
@@ -67,4 +68,5 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from abc import abstractmethod
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from multiprocessing import set_start_method
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Optional, Type, Union
|
from typing import Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
@@ -226,6 +227,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
)
|
)
|
||||||
|
alternate_optimizer: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -284,26 +291,72 @@ class AxolotlTrainer(Trainer):
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.torch_compile:
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
|
256
|
||||||
|
)
|
||||||
|
model = torch.compile(
|
||||||
|
model,
|
||||||
|
backend=self.args.torch_compile_backend,
|
||||||
|
mode=self.args.torch_compile_mode,
|
||||||
|
)
|
||||||
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.alternate_optimizer != "optimi_adamw"
|
||||||
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n not in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
loraplus_lr_embedding = getattr(
|
||||||
opt_model,
|
self.args, "loraplus_lr_embedding", None
|
||||||
optimizer_cls,
|
)
|
||||||
optimizer_kwargs,
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
loraplus_lr_ratio,
|
opt_model,
|
||||||
loraplus_lr_embedding,
|
optimizer_cls,
|
||||||
)
|
optimizer_kwargs,
|
||||||
|
loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW(
|
||||||
|
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1091,6 +1144,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
else:
|
else:
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
if warmup_steps == 1:
|
||||||
|
warmup_steps = 2
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
@@ -1394,6 +1449,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "optimi_adamw":
|
||||||
|
# Set default so transformers doesn't throw
|
||||||
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
||||||
|
|
||||||
if self.cfg.optimizer == "lion_pytorch":
|
if self.cfg.optimizer == "lion_pytorch":
|
||||||
from lion_pytorch import Lion
|
from lion_pytorch import Lion
|
||||||
|
|
||||||
@@ -1668,8 +1728,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1678,7 +1736,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
@@ -1693,7 +1751,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
elif self.cfg.rl == "kto":
|
elif self.cfg.rl in ["kto"]:
|
||||||
trainer_cls = AxolotlKTOTrainer
|
trainer_cls = AxolotlKTOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
@@ -1713,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|
||||||
|
# prevents multiprocessing issues for datasets on multiple GPUs
|
||||||
|
set_start_method("spawn")
|
||||||
|
|
||||||
return dpo_trainer
|
return dpo_trainer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
src/axolotl/integrations/__init__.py
Normal file
0
src/axolotl/integrations/__init__.py
Normal file
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
|
|||||||
set_module_name(model, name, qkv)
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama_cross_entropy():
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama_rms_norm():
|
||||||
|
try:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class LlamaRMSNorm(RMSNorm):
|
||||||
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning(
|
||||||
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(
|
def replace_llama_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
@@ -104,35 +131,11 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
try:
|
patch_llama_cross_entropy()
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if rms_norm:
|
if rms_norm:
|
||||||
try:
|
patch_llama_rms_norm()
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
|
||||||
"""Patched LLamaRMSNorm"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__(hidden_size, eps=eps)
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedAttention(LlamaAttention):
|
class FusedAttention(LlamaAttention):
|
||||||
@@ -826,7 +829,6 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
padding_mask=padding_mask,
|
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ def flashattn_forward(
|
|||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
@@ -422,6 +422,9 @@ def mistral_model_forward(
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[ # pylint: disable=unused-argument
|
||||||
|
torch.LongTensor
|
||||||
|
] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"gemma",
|
"gemma",
|
||||||
|
"gemma2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
@@ -54,6 +55,10 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
elif model_type == "gemma2":
|
||||||
|
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
elif model_type == "starcoder2":
|
elif model_type == "starcoder2":
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
|
|||||||
@@ -80,8 +80,9 @@ def get_forward_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def test_cel_is_patchable() -> bool:
|
def check_cel_is_patchable() -> bool:
|
||||||
forward = get_forward_code()
|
forward = get_forward_code()
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
return ORIGINAL_CEL_CODE in forward
|
return ORIGINAL_CEL_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def test_self_attn_is_patchable() -> bool:
|
def check_self_attn_is_patchable() -> bool:
|
||||||
qkv = get_self_attn_code()
|
qkv = get_self_attn_code()
|
||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
qkv, _ = detab_code(qkv)
|
||||||
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
def integrate_cross_entropy_loss_patch():
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
message_field_content: str = "value",
|
message_field_content: str = "value",
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
@@ -39,6 +40,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
turns = [
|
turns = [
|
||||||
@@ -49,6 +51,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
for t in conversation
|
for t in conversation
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if self.drop_system_message and turns[0]["role"] == "system":
|
||||||
|
turns = turns[1:]
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
@@ -111,6 +116,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
else "value"
|
else "value"
|
||||||
)
|
)
|
||||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
drop_system_message = (
|
||||||
|
ds_cfg["drop_system_message"]
|
||||||
|
if ds_cfg and "drop_system_message" in ds_cfg
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
@@ -119,6 +129,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
message_field_role=message_field_role,
|
message_field_role=message_field_role,
|
||||||
message_field_content=message_field_content,
|
message_field_content=message_field_content,
|
||||||
roles=roles,
|
roles=roles,
|
||||||
|
drop_system_message=drop_system_message,
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
|
|||||||
@@ -52,6 +52,13 @@ class TrainDatasetMeta:
|
|||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
|
# enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
# torch_version = torch.__version__.split(".")
|
||||||
|
# torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
|
# if torch_major == 2 and torch_minor >= 2:
|
||||||
|
# if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
|
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
@@ -144,7 +151,7 @@ def train(
|
|||||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||||
)
|
)
|
||||||
|
|
||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
if getattr(cfg, "axolotl_config_path"):
|
if getattr(cfg, "axolotl_config_path"):
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class SFTDataset(BaseModel):
|
|||||||
message_field_content: Optional[str] = None
|
message_field_content: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
|
drop_system_message: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
@@ -164,7 +165,6 @@ class RLType(str, Enum):
|
|||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[OptimizerNames, Literal["lion_pytorch"]]
|
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||||
@@ -1112,6 +1112,31 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("either datasets or pretraining_dataset is required")
|
raise ValueError("either datasets or pretraining_dataset is required")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_xentropy_patch_conflicts(cls, data):
|
||||||
|
if data.get("flash_attn_cross_entropy") and data.get(
|
||||||
|
"unsloth_cross_entropy_loss"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_qlora_unsloth(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
@@ -1163,3 +1188,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("deepspeed") and data.get("fsdp"):
|
if data.get("deepspeed") and data.get("fsdp"):
|
||||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_multigpu_unsloth(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
capabilities = data.get("capabilities")
|
||||||
|
if capabilities and capabilities.get("num_gpus") > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -120,6 +120,9 @@ def _merge_ranges(
|
|||||||
processed_ranges = [
|
processed_ranges = [
|
||||||
(start, end if end is not None else layer_size) for start, end in given_ranges
|
(start, end if end is not None else layer_size) for start, end in given_ranges
|
||||||
]
|
]
|
||||||
|
for start, end in processed_ranges:
|
||||||
|
if start < 0 or end > layer_size > 0 or start >= end:
|
||||||
|
raise ValueError(f"invalid unfreeze range: start={start}, end={end}")
|
||||||
|
|
||||||
# No need to merge if there's only one or no ranges
|
# No need to merge if there's only one or no ranges
|
||||||
if len(processed_ranges) <= 1:
|
if len(processed_ranges) <= 1:
|
||||||
|
|||||||
@@ -347,6 +347,27 @@ def load_model(
|
|||||||
and cfg.sample_packing
|
and cfg.sample_packing
|
||||||
):
|
):
|
||||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model:
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
patch_llama_cross_entropy,
|
||||||
|
patch_llama_rms_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attn_cross_entropy:
|
||||||
|
patch_llama_cross_entropy()
|
||||||
|
if cfg.flash_attn_rms_norm:
|
||||||
|
patch_llama_rms_norm()
|
||||||
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
|
integrate_cross_entropy_loss_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
integrate_cross_entropy_loss_patch()
|
||||||
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora()
|
||||||
elif cfg.is_llama_derived_model:
|
elif cfg.is_llama_derived_model:
|
||||||
# Modify all llama derived models in one block
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
@@ -371,6 +392,12 @@ def load_model(
|
|||||||
rms_norm=cfg.flash_attn_rms_norm,
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
use_shifted_sparse_attn=True,
|
use_shifted_sparse_attn=True,
|
||||||
)
|
)
|
||||||
|
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
|
||||||
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=False,
|
||||||
|
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||||
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
|
)
|
||||||
elif cfg.xformers_attention:
|
elif cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
@@ -569,9 +596,11 @@ def load_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
if (
|
if ( # pylint: disable=condition-evals-to-constant)
|
||||||
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
(cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
||||||
) and not qlora_fsdp:
|
and not qlora_fsdp
|
||||||
|
and False
|
||||||
|
):
|
||||||
model = load_sharded_model(
|
model = load_sharded_model(
|
||||||
base_model,
|
base_model,
|
||||||
model_config,
|
model_config,
|
||||||
@@ -597,9 +626,12 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
from transformers import LlamaForCausalLM
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
skip_move_to_device = True
|
||||||
|
if "device_map" in model_kwargs:
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -632,7 +664,11 @@ def load_model(
|
|||||||
base_model,
|
base_model,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif (
|
||||||
|
model_type
|
||||||
|
and model_type != "AutoModelForCausalLM"
|
||||||
|
and not cfg.trust_remote_code
|
||||||
|
):
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -673,6 +709,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
# disabling either of these two still leads to VRAM spike before setting back down
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
@@ -803,11 +840,7 @@ def load_model(
|
|||||||
if not reference_model or cfg.lora_model_dir:
|
if not reference_model or cfg.lora_model_dir:
|
||||||
# if we're not loading the reference model, then we're loading the model for training
|
# if we're not loading the reference model, then we're loading the model for training
|
||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if (
|
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora:
|
||||||
cfg.adapter
|
|
||||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
|
|
||||||
and not cfg.merge_lora
|
|
||||||
):
|
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|||||||
@@ -427,7 +427,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
87
tests/e2e/patched/test_fa_xentropy.py
Normal file
87
tests/e2e/patched/test_fa_xentropy.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from importlib import reload
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reload_transformers():
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
yield
|
||||||
|
reload(transformers.models.llama.modeling_llama)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFAXentropyLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA w multipack
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora_packing_fa_cross_entropy(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"flash_attn_cross_entropy": True,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.2,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
@@ -7,6 +7,8 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="FIXME?")
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using S2 Attn
|
Test case for Llama models using S2 Attn
|
||||||
|
|||||||
25
tests/e2e/patched/test_unsloth_integration.py
Normal file
25
tests/e2e/patched/test_unsloth_integration.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
|
check_cel_is_patchable,
|
||||||
|
check_self_attn_is_patchable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnslothIntegration(unittest.TestCase):
|
||||||
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_is_cel_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_cel_is_patchable(),
|
||||||
|
"HF transformers loss code has changed and isn't patchable",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_is_self_attn_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_self_attn_is_patchable(),
|
||||||
|
"HF transformers self attention code has changed and isn't patchable",
|
||||||
|
)
|
||||||
@@ -115,6 +115,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_kto_pair_lora(self, temp_dir):
|
def test_kto_pair_lora(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 64,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.1,
|
||||||
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 2,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 8,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
67
tests/e2e/test_optimizers.py
Normal file
67
tests/e2e/test_optimizers.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for custom optimizers using Llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomOptimizers(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_optimi_adamw(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "optimi_adamw",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
Reference in New Issue
Block a user