Compare commits
1 Commits
fa-261
...
llama-mult
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
469e15607d |
8
.github/CONTRIBUTING.md
vendored
8
.github/CONTRIBUTING.md
vendored
@@ -21,12 +21,12 @@ All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT
|
|||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/axolotl-ai-cloud/axolotl/issues/new).
|
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
||||||
|
|
||||||
PRs are **greatly welcome**!
|
PRs are **greatly welcome**!
|
||||||
|
|
||||||
1. Fork the repository and clone it to your local machine.
|
1. Fork the repository and clone it to your local machine.
|
||||||
2. Set up the development environment by following the instructions in the [README.md](https://github.com/axolotl-ai-cloud/axolotl/tree/main/README.md) file.
|
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
|
||||||
3. Explore the codebase, run tests, and verify that everything works as expected.
|
3. Explore the codebase, run tests, and verify that everything works as expected.
|
||||||
|
|
||||||
Please run below to setup env
|
Please run below to setup env
|
||||||
@@ -42,11 +42,11 @@ pytest tests/
|
|||||||
|
|
||||||
### Reporting Bugs
|
### Reporting Bugs
|
||||||
|
|
||||||
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
|
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
|
||||||
|
|
||||||
### Suggesting Enhancements
|
### Suggesting Enhancements
|
||||||
|
|
||||||
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
|
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
|
||||||
|
|
||||||
### Submitting Pull Requests
|
### Submitting Pull Requests
|
||||||
|
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -15,7 +15,7 @@ body:
|
|||||||
label: "Please check that this issue hasn't been reported before."
|
label: "Please check that this issue hasn't been reported before."
|
||||||
description: "The **Label filters** may help make your search more focussed."
|
description: "The **Label filters** may help make your search more focussed."
|
||||||
options:
|
options:
|
||||||
- label: "I searched previous [Bug Reports](https://github.com/axolotl-ai-cloud/axolotl/labels/bug) didn't find any similar reports."
|
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/config.yml
vendored
2
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,7 +1,7 @@
|
|||||||
blank_issues_enabled: false
|
blank_issues_enabled: false
|
||||||
contact_links:
|
contact_links:
|
||||||
- name: Ask a question
|
- name: Ask a question
|
||||||
url: https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/q-a
|
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
|
||||||
about: Ask questions and discuss with other community members
|
about: Ask questions and discuss with other community members
|
||||||
- name: Discuss the Project in Discord
|
- name: Discuss the Project in Discord
|
||||||
url: https://discord.gg/HhrNrHJPRb
|
url: https://discord.gg/HhrNrHJPRb
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/docs.yml
vendored
2
.github/ISSUE_TEMPLATE/docs.yml
vendored
@@ -10,7 +10,7 @@ body:
|
|||||||
value: |
|
value: |
|
||||||
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
|
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
|
||||||
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
|
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
|
||||||
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/axolotl-ai-cloud/axolotl/issues).
|
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: What piece of documentation is affected?
|
label: What piece of documentation is affected?
|
||||||
|
|||||||
4
.github/ISSUE_TEMPLATE/feature-request.yaml
vendored
4
.github/ISSUE_TEMPLATE/feature-request.yaml
vendored
@@ -8,9 +8,9 @@ body:
|
|||||||
label: "⚠️ Please check that this feature request hasn't been suggested before."
|
label: "⚠️ Please check that this feature request hasn't been suggested before."
|
||||||
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
|
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
|
||||||
options:
|
options:
|
||||||
- label: "I searched previous [Ideas in Discussions](https://github.com/axolotl-ai-cloud/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
|
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
|
||||||
required: true
|
required: true
|
||||||
- label: "I searched previous [Issues](https://github.com/axolotl-ai-cloud/axolotl/labels/enhancement) didn't find any similar feature requests."
|
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
|
|||||||
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -5,7 +5,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
strategy:
|
strategy:
|
||||||
@@ -37,11 +37,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
16
.github/workflows/main.yml
vendored
16
.github/workflows/main.yml
vendored
@@ -8,7 +8,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -19,6 +19,7 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
|
is_latest: true
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -32,9 +33,8 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud:
|
build-axolotl-cloud:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -80,6 +80,7 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -93,9 +94,8 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -128,7 +128,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud-no-tmux:
|
build-axolotl-cloud-no-tmux:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
12
.github/workflows/nightlies.yml
vendored
12
.github/workflows/nightlies.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -18,6 +18,7 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
|
is_latest: true
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -31,9 +32,8 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
|
|
||||||
build-axolotl-cloud:
|
build-axolotl-cloud:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -80,6 +80,7 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -93,9 +94,8 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
pytest --ignore=tests/e2e/ tests/
|
pytest --ignore=tests/e2e/ tests/
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
@@ -87,7 +87,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
18
README.md
18
README.md
@@ -67,8 +67,8 @@ Features:
|
|||||||
<p>
|
<p>
|
||||||
Go ahead and Axolotl questions!!
|
Go ahead and Axolotl questions!!
|
||||||
</p>
|
</p>
|
||||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||||
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging ninja
|
pip3 install packaging ninja
|
||||||
@@ -132,7 +132,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
|
|
||||||
# remote yaml files - the yaml config can be hosted on a public URL
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
# Note: the yaml config must directly link to the **raw** yaml
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced Setup
|
## Advanced Setup
|
||||||
@@ -333,7 +333,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
|
|||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|
||||||
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
|
|
||||||
@@ -626,10 +626,10 @@ Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.
|
|||||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||||
|
|
||||||
```markdown
|
```markdown
|
||||||
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||||
```
|
```
|
||||||
|
|
||||||
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||||
|
|
||||||
## Community Showcase
|
## Community Showcase
|
||||||
|
|
||||||
@@ -647,7 +647,7 @@ PocketDoc Labs
|
|||||||
|
|
||||||
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
||||||
|
|
||||||
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
|
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.
|
||||||
|
|
||||||
PRs are **greatly welcome**!
|
PRs are **greatly welcome**!
|
||||||
|
|
||||||
@@ -665,7 +665,7 @@ pre-commit run --all-files
|
|||||||
|
|
||||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||||
|
|
||||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
|
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
||||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ website:
|
|||||||
- icon: twitter
|
- icon: twitter
|
||||||
href: https://twitter.com/axolotl_ai
|
href: https://twitter.com/axolotl_ai
|
||||||
- icon: github
|
- icon: github
|
||||||
href: https://github.com/axolotl-ai-cloud/axolotl/
|
href: https://github.com/OpenAccess-AI-Collective/axolotl/
|
||||||
- icon: discord
|
- icon: discord
|
||||||
href: https://discord.gg/7m9sfhzaf3
|
href: https://discord.gg/7m9sfhzaf3
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install -r requirements-tests.txt
|
RUN pip install pytest
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ RUN apt-get update && \
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto'
|
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
|||||||
@@ -4,25 +4,9 @@ description: How to use a custom pre-tokenized dataset.
|
|||||||
order: 5
|
order: 5
|
||||||
---
|
---
|
||||||
|
|
||||||
- Pass an empty `type:` in your axolotl config.
|
- Do not pass a `type:` in your axolotl config.
|
||||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||||
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
|
|
||||||
- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
|
|
||||||
- For pretraining, do not truncate/pad documents to the context window length.
|
|
||||||
- For instruction training, documents must be truncated/padded as desired.
|
|
||||||
|
|
||||||
Sample config:
|
|
||||||
|
|
||||||
```{.yaml filename="config.yml"}
|
```{.yaml filename="config.yml"}
|
||||||
datasets:
|
- path: ...
|
||||||
- path: /path/to/your/file.jsonl
|
|
||||||
ds_type: json
|
|
||||||
type:
|
|
||||||
```
|
|
||||||
|
|
||||||
Sample jsonl:
|
|
||||||
|
|
||||||
```jsonl
|
|
||||||
{"input_ids":[271,299,99],"attention_mask":[1,1,1],"labels":[271,-100,99]}
|
|
||||||
{"input_ids":[87,227,8383,12],"attention_mask":[1,1,1,1],"labels":[87,227,8383,12]}
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl
|
|||||||
On the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:
|
On the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
1. Set `adapter: qlora` in your axolotl config file.
|
||||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||||
|
|
||||||
## Example Config
|
## Example Config
|
||||||
@@ -29,7 +29,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [PR #1378](https://github.com/axolotl-ai-cloud/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
||||||
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
||||||
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
||||||
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ description: "Template-free prompt construction with the `input_output` format"
|
|||||||
### Masking Inputs
|
### Masking Inputs
|
||||||
|
|
||||||
One of the most popular features of
|
One of the most popular features of
|
||||||
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
|
||||||
setting the following configuration value:
|
setting the following configuration value:
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ setting the following configuration value:
|
|||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
```
|
```
|
||||||
|
|
||||||
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
|
||||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||||
labels so that your model can focus on predicting the outputs only.
|
labels so that your model can focus on predicting the outputs only.
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install torch==\"2.1.2\"\n",
|
"!pip install torch==\"2.1.2\"\n",
|
||||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
"!pip install flash-attn==\"2.5.0\"\n",
|
||||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||||
]
|
]
|
||||||
@@ -171,7 +171,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# By using the ! the comand will be executed as a bash command\n",
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -188,7 +188,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# By using the ! the comand will be executed as a bash command\n",
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
base_model: google/gemma-2-9b
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: gemma
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
chat_template: gemma
|
|
||||||
drop_system_message: true
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -15,7 +15,6 @@ output_dir: ./outputs/lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -1,2 +1 @@
|
|||||||
pytest
|
pytest
|
||||||
pytest-xdist
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.42.3
|
transformers==4.41.1
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.30.1
|
||||||
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
@@ -12,11 +12,11 @@ fire
|
|||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.19.1
|
datasets==2.19.1
|
||||||
flash-attn==2.6.1
|
flash-attn==2.5.8
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.27
|
xformers==0.0.26.post1
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -31,7 +31,6 @@ art
|
|||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
@@ -40,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.9.6
|
trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace a
|
|||||||
```
|
```
|
||||||
cd /workspace
|
cd /workspace
|
||||||
rm -rf /workspace/axolotl
|
rm -rf /workspace/axolotl
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip install --no-deps -e .
|
pip install --no-deps -e .
|
||||||
```
|
```
|
||||||
|
|||||||
15
setup.py
15
setup.py
@@ -29,10 +29,9 @@ def parse_requirements():
|
|||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
# don't install xformers on MacOS
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
# detect the version of torch already installed
|
||||||
# and set it so dependencies don't clobber the torch version
|
# and set it so dependencies don't clobber the torch version
|
||||||
@@ -50,14 +49,12 @@ def parse_requirements():
|
|||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 3):
|
||||||
if patch == 0:
|
pass
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.26.post1")
|
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
else:
|
else:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
@@ -80,10 +77,10 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.6.1",
|
"flash-attn==2.5.8",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
do_inference,
|
do_inference,
|
||||||
@@ -34,5 +33,4 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
@@ -49,5 +48,4 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import fire
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
from dotenv import load_dotenv
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
@@ -87,5 +86,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
@@ -41,5 +40,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
@@ -68,5 +67,4 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
fire.Fire(do_cli)
|
||||||
|
|||||||
@@ -1091,8 +1091,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
else:
|
else:
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
if warmup_steps == 1:
|
|
||||||
warmup_steps = 2
|
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
@@ -1670,6 +1668,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
|
elif self.cfg.rl == "kto_pair":
|
||||||
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1678,7 +1678,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
@@ -1693,7 +1693,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
elif self.cfg.rl in ["kto"]:
|
elif self.cfg.rl == "kto":
|
||||||
trainer_cls = AxolotlKTOTrainer
|
trainer_cls = AxolotlKTOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -104,12 +104,17 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
try:
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
CrossEntropyLoss, inplace_backward=True
|
CrossEntropyLoss, inplace_backward=True
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
|
)
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if rms_norm:
|
if rms_norm:
|
||||||
@@ -125,7 +130,7 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.warning(
|
LOG.info(
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -821,6 +826,7 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ def flashattn_forward(
|
|||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
@@ -422,9 +422,6 @@ def mistral_model_forward(
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[ # pylint: disable=unused-argument
|
|
||||||
torch.LongTensor
|
|
||||||
] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
|||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
|
"llama",
|
||||||
"mixtral",
|
"mixtral",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
@@ -30,6 +30,10 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
)
|
)
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
patch_mixtral_moe_forward_zero3()
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
elif model_type == "llama":
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
elif model_type == "qwen2":
|
elif model_type == "qwen2":
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -50,10 +54,6 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
elif model_type == "gemma2":
|
|
||||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "starcoder2":
|
elif model_type == "starcoder2":
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
|
|||||||
@@ -80,9 +80,8 @@ def get_forward_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def check_cel_is_patchable() -> bool:
|
def test_cel_is_patchable() -> bool:
|
||||||
forward = get_forward_code()
|
forward = get_forward_code()
|
||||||
forward, _ = detab_code(forward)
|
|
||||||
return ORIGINAL_CEL_CODE in forward
|
return ORIGINAL_CEL_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
@@ -91,10 +90,9 @@ def get_self_attn_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def check_self_attn_is_patchable() -> bool:
|
def test_self_attn_is_patchable() -> bool:
|
||||||
qkv = get_self_attn_code()
|
qkv = get_self_attn_code()
|
||||||
qkv, _ = detab_code(qkv)
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
def integrate_cross_entropy_loss_patch():
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
message_field_content: str = "value",
|
message_field_content: str = "value",
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
drop_system_message: bool = False,
|
|
||||||
):
|
):
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
@@ -40,7 +39,6 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
turns = [
|
turns = [
|
||||||
@@ -51,9 +49,6 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
for t in conversation
|
for t in conversation
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.drop_system_message and turns[0]["role"] == "system":
|
|
||||||
turns = turns[1:]
|
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
@@ -116,11 +111,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
else "value"
|
else "value"
|
||||||
)
|
)
|
||||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||||
drop_system_message = (
|
|
||||||
ds_cfg["drop_system_message"]
|
|
||||||
if ds_cfg and "drop_system_message" in ds_cfg
|
|
||||||
else False
|
|
||||||
)
|
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
@@ -129,7 +119,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
message_field_role=message_field_role,
|
message_field_role=message_field_role,
|
||||||
message_field_content=message_field_content,
|
message_field_content=message_field_content,
|
||||||
roles=roles,
|
roles=roles,
|
||||||
drop_system_message=drop_system_message,
|
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ def train(
|
|||||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||||
)
|
)
|
||||||
|
|
||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
if getattr(cfg, "axolotl_config_path"):
|
if getattr(cfg, "axolotl_config_path"):
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ class SFTDataset(BaseModel):
|
|||||||
message_field_content: Optional[str] = None
|
message_field_content: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
drop_system_message: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
@@ -165,6 +164,7 @@ class RLType(str, Enum):
|
|||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
@@ -120,9 +120,6 @@ def _merge_ranges(
|
|||||||
processed_ranges = [
|
processed_ranges = [
|
||||||
(start, end if end is not None else layer_size) for start, end in given_ranges
|
(start, end if end is not None else layer_size) for start, end in given_ranges
|
||||||
]
|
]
|
||||||
for start, end in processed_ranges:
|
|
||||||
if start < 0 or end > layer_size > 0 or start >= end:
|
|
||||||
raise ValueError(f"invalid unfreeze range: start={start}, end={end}")
|
|
||||||
|
|
||||||
# No need to merge if there's only one or no ranges
|
# No need to merge if there's only one or no ranges
|
||||||
if len(processed_ranges) <= 1:
|
if len(processed_ranges) <= 1:
|
||||||
|
|||||||
@@ -371,12 +371,6 @@ def load_model(
|
|||||||
rms_norm=cfg.flash_attn_rms_norm,
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
use_shifted_sparse_attn=True,
|
use_shifted_sparse_attn=True,
|
||||||
)
|
)
|
||||||
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
|
|
||||||
replace_llama_attn_with_flash_attn(
|
|
||||||
packed=False,
|
|
||||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
|
||||||
rms_norm=cfg.flash_attn_rms_norm,
|
|
||||||
)
|
|
||||||
elif cfg.xformers_attention:
|
elif cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
@@ -575,11 +569,9 @@ def load_model(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
if ( # pylint: disable=condition-evals-to-constant)
|
if (
|
||||||
(cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
and not qlora_fsdp
|
) and not qlora_fsdp:
|
||||||
and False
|
|
||||||
):
|
|
||||||
model = load_sharded_model(
|
model = load_sharded_model(
|
||||||
base_model,
|
base_model,
|
||||||
model_config,
|
model_config,
|
||||||
@@ -605,12 +597,9 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
from transformers import LlamaForCausalLM
|
||||||
skip_move_to_device = True
|
|
||||||
if "device_map" in model_kwargs:
|
|
||||||
del model_kwargs["device_map"]
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -643,11 +632,7 @@ def load_model(
|
|||||||
base_model,
|
base_model,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif (
|
elif model_type and not cfg.trust_remote_code:
|
||||||
model_type
|
|
||||||
and model_type != "AutoModelForCausalLM"
|
|
||||||
and not cfg.trust_remote_code
|
|
||||||
):
|
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -688,7 +673,6 @@ def load_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
# disabling either of these two still leads to VRAM spike before setting back down
|
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
@@ -819,7 +803,11 @@ def load_model(
|
|||||||
if not reference_model or cfg.lora_model_dir:
|
if not reference_model or cfg.lora_model_dir:
|
||||||
# if we're not loading the reference model, then we're loading the model for training
|
# if we're not loading the reference model, then we're loading the model for training
|
||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora:
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
|
||||||
|
and not cfg.merge_lora
|
||||||
|
):
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|||||||
@@ -427,7 +427,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for lora llama
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from importlib import reload
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reload_transformers():
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
|
|
||||||
yield
|
|
||||||
reload(transformers.models.llama.modeling_llama)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFAXentropyLlama(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for Llama models using LoRA w multipack
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_lora_packing_fa_cross_entropy(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"flash_attn_cross_entropy": True,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 32,
|
|
||||||
"lora_alpha": 64,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.2,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
|
||||||
@@ -7,8 +7,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -21,7 +19,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
|||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="FIXME?")
|
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using S2 Attn
|
Test case for Llama models using S2 Attn
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import (
|
|
||||||
check_cel_is_patchable,
|
|
||||||
check_self_attn_is_patchable,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestUnslothIntegration(unittest.TestCase):
|
|
||||||
"""Unsloth monkeypatch integration tests."""
|
|
||||||
|
|
||||||
def test_is_cel_patchable(self):
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_cel_is_patchable(),
|
|
||||||
"HF transformers loss code has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_is_self_attn_patchable(self):
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_self_attn_is_patchable(),
|
|
||||||
"HF transformers self attention code has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
@@ -115,7 +115,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||||
|
|
||||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_kto_pair_lora(self, temp_dir):
|
def test_kto_pair_lora(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
Reference in New Issue
Block a user