Compare commits

117 commits · feature/re… → benchmark-…

c3de28942c, 45848a9285, d6cea18034, 606846e0a5, a6c9223114, 8b16ecd448, f5db88a10d, 99d844f215,
aefd4d74fa, 24b0e93235, 2455254b92, 918e040601, ef062d8fcb, d4c8b66f3d, 64e9824d3e, 1134654c98,
2fc756c289, 943b84c490, 6f166464d8, e3b07402a7, 8d3c8a3eab, c30120e684, 9aed60fa54, 98bf76e236,
4c37bd0b54, f144e98a32, 3a011ea1ef, 1f613e5aa7, f319b0bc67, 7fd662dd89, 9e699683d7, 35130711d6,
3fc9006298, ad8be435ad, fe4d6baf92, f31301063d, 868530c39c, d03887fad5, 17605b85d8, a184549e4c,
f311df9462, c500d02517, 31f3e71764, 56c4a94caf, c29117a0d7, 0b7ba57ec4, 71bd06243c, cb9797ef5a,
bde3c5a478, 55c23c7bcb, c69faee7a7, d5dcf9c350, f4746507f6, 96deb6bd67, 50682a3c06, 5a1985ba24,
5e9c6afa10, a213d9972a, fbf49a4770, 58cf7e7fed, 04a42b6db1, 919f4cac90, ee262818ef, 9d629d8bff,
d2e7f27240, d21318dfb9, f733d0f31e, 008505c8ae, b3f5e00ff5, 5247c5004e, cf6654769a, 06edf175ac,
0a228479b3, 82e111aba9, 8cace80175, 1b7e8604bb, 3d1f203b62, d3d6fd6ae6, b7449a997f, 5f80b3560b,
24959091d7, 7af816699e, f806e86a6e, 2b990eb628, bd8cab49c9, c01015f33f, 72fe3f8e3d, 47961fdb8b,
7ad37cb6d7, 29241cf1e4, 31db0ecce4, da10af03e9, 85cf4f8e2c, 2e22404d2d, be294fd605, fc2d6be96d,
1687be6a35, 41ecb451c2, 3c2ad00d07, 5d48a10548, 73a0b6ead5, 63fdb5a7fb, fdffef5940, 919246fbc1,
ffac902c1b, 15f6e57eaa, 729c299256, 86a91e260b, 094fc2c6e6, 2dafa730ef, 343ac84e5a, 0c967279ce,
efb3b2c95e, 7b55fe6419, e029ab34ea, 8cec513447, a13e45d548
129  .github/CODE_OF_CONDUCT.md  (vendored, new file)
@@ -0,0 +1,129 @@

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement on Discord
at https://discord.gg/QYF8QrtEUm

All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
76  .github/CONTRIBUTING.md  (vendored, new file)
@@ -0,0 +1,76 @@

# Contributing to axolotl

First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.

## Table of Contents

- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [How to Contribute](#how-to-contribute)
- [Reporting Bugs](#reporting-bugs)
- [Suggesting Enhancements](#suggesting-enhancements)
- [Submitting Pull Requests](#submitting-pull-requests)
- [Style Guidelines](#style-guidelines)
- [Code Style](#code-style)
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)

## Code of Conduct

All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.

## Getting Started

Bugs? Please check for an open issue, or else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).

PRs are **greatly welcome**!

1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.

Please run the following to set up the environment:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install

# test
pytest tests/
```
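
While iterating on a change, it can be quicker to run a subset of the suite. This is plain pytest usage rather than anything axolotl-specific, and the match expression is illustrative:

```bash
# run only the tests whose names match the -k expression
pytest tests/ -k "prompt"
```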

## How to Contribute

### Reporting Bugs

If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.

### Suggesting Enhancements

We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.

### Submitting Pull Requests

1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.
2. Make your changes, following the [Style Guidelines](#style-guidelines) below.
3. Test your changes and ensure that they don't introduce new issues or break existing functionality.
4. Commit your changes, following the [commit message guidelines](#commit-messages).
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.

## Style Guidelines

### Code Style

axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.

### Commit Messages

Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".

## Additional Resources

- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})

Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
13  .github/FUNDING.yml  (vendored, new file)
@@ -0,0 +1,13 @@

# These are supported funding model platforms

github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
105  .github/ISSUE_TEMPLATE/bug-report.yaml  (vendored, new file)
@@ -0,0 +1,105 @@

name: Bug Report
description: File a bug report
labels: ["bug", "needs triage"]
body:
  - type: markdown
    attributes:
      value: |
        ## Before you start
        Please **make sure you are on the latest version.**
        If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.

  - type: checkboxes
    id: no-duplicate-issues
    attributes:
      label: "Please check that this issue hasn't been reported before."
      description: "The **Label filters** may help make your search more focused."
      options:
        - label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) and didn't find any similar reports."
          required: true

  - type: textarea
    id: expected
    attributes:
      label: Expected Behavior
      description: Tell us what **should** happen.
    validations:
      required: true

  - type: textarea
    id: what-happened
    attributes:
      label: Current Behavior
      description: |
        Tell us what happens instead of the expected behavior.
        Provide a stacktrace and/or screenshots.
    validations:
      required: true

  - type: textarea
    id: reproduce
    attributes:
      label: Steps to reproduce
      description: |
        Which exact steps can a developer take to reproduce the issue?
        The more detail you provide, the easier it will be to narrow down and fix the bug.
        Please paste in tasks and/or queries **as text, not screenshots**.
      placeholder: |
        Example of the level of detail needed to reproduce any bugs efficiently and reliably.
        1. Go to the '...' page.
        2. Click on the '...' button.
        3. Scroll down to '...'.
        4. Observe the error.
    validations:
      required: true

  - type: textarea
    id: possible-solution
    attributes:
      label: Possible solution
      description: |
        Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.

  - type: checkboxes
    id: operating-systems
    attributes:
      label: Which Operating Systems are you using?
      description: You may select more than one.
      options:
        - label: Linux
        - label: macOS
        - label: Windows

  - type: input
    id: Python-version
    attributes:
      label: Python Version
      description: Which Python version are you using?
      placeholder: 3.10 / please change accordingly
    validations:
      required: true

  - type: input
    id: axolotl-branch-commit
    attributes:
      label: axolotl branch-commit
      description: On which branch/commit are you?
      placeholder: main/4d6490b
    validations:
      required: true

  - type: checkboxes
    id: acknowledgements
    attributes:
      label: 'Acknowledgements'
      description: 'Please confirm the following:'
      options:
        - label: 'My issue title is concise, descriptive, and in title casing.'
          required: true
        - label: 'I have searched the existing issues to make sure this bug has not been reported yet.'
          required: true
        - label: 'I am using the latest version of axolotl.'
          required: true
        - label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'
          required: true
7  .github/ISSUE_TEMPLATE/config.yml  (vendored, new file)
@@ -0,0 +1,7 @@

blank_issues_enabled: false
contact_links:
  - name: Ask a question
    url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
    about: Ask questions and discuss with other community members
  - name: Discuss the Project in Discord
    url: https://discord.gg/HhrNrHJPRb
46  .github/ISSUE_TEMPLATE/docs.yml  (vendored, new file)
@@ -0,0 +1,46 @@

name: Documentation Improvement / Clarity
description: Make a suggestion to improve the project documentation.
labels: ['needs triage', 'docs']
body:
  - type: markdown
    attributes:
      value: '## :book: Documentation :book:'
  - type: markdown
    attributes:
      value: |
        * Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
        * Before you file an issue, read the [Contributing guide](./CONTRIBUTING.md).
        * Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
  - type: textarea
    attributes:
      label: What piece of documentation is affected?
      description: Please link to the article you'd like to see updated.
    validations:
      required: true
  - type: textarea
    attributes:
      label: What part(s) of the article would you like to see updated?
      description: |
        - Give as much detail as you can to help us understand the change you want to see.
        - Why should the docs be changed? What use cases does it support?
        - What is the expected outcome?
    validations:
      required: true
  - type: textarea
    attributes:
      label: Additional Information
      description: Add any other context or screenshots about the request here.
    validations:
      required: false
  - type: checkboxes
    id: acknowledgements
    attributes:
      label: 'Acknowledgements'
      description: 'Please confirm the following:'
      options:
        - label: 'My issue title is concise, descriptive, and in title casing.'
          required: true
        - label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
          required: true
        - label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
          required: true
63  .github/ISSUE_TEMPLATE/feature-request.yaml  (vendored, new file)
@@ -0,0 +1,63 @@

name: Feature Request / Enhancement
description: Suggest a new feature or feature enhancement for the project
labels: ["enhancement", "needs triage"]
body:
  - type: checkboxes
    id: no-duplicate-issues
    attributes:
      label: "⚠️ Please check that this feature request hasn't been suggested before."
      description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focused."
      options:
        - label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) and didn't find any similar feature requests."
          required: true
        - label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) and didn't find any similar feature requests."
          required: true

  - type: textarea
    id: feature-description
    validations:
      required: true
    attributes:
      label: "🔖 Feature description"
      description: "A clear and concise description of what the feature request is."
      placeholder: "You should add ..."

  - type: textarea
    id: solution
    validations:
      required: true
    attributes:
      label: "✔️ Solution"
      description: "A clear and concise description of what you want to happen, and why."
      placeholder: "In my use-case, ..."

  - type: textarea
    id: alternatives
    validations:
      required: false
    attributes:
      label: "❓ Alternatives"
      description: "A clear and concise description of any alternative solutions or features you've considered."
      placeholder: "I have considered ..."

  - type: textarea
    id: additional-context
    validations:
      required: false
    attributes:
      label: "📝 Additional Context"
      description: "Add any other context or screenshots about the feature request here."
      placeholder: "..."

  - type: checkboxes
    id: acknowledgements
    attributes:
      label: 'Acknowledgements'
      description: 'Please confirm the following:'
      options:
        - label: 'My issue title is concise, descriptive, and in title casing.'
          required: true
        - label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
          required: true
        - label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
          required: true
22  .github/PULL_REQUEST_TEMPLATE/pull_request_template_simple.md  (vendored, new file)
@@ -0,0 +1,22 @@

<!--- Provide a general summary of your changes in the Title above -->

# Description

<!--- Describe your changes in detail -->

## Motivation and Context

<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->

## How has this been tested?

<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->

## Screenshots (if appropriate)

## Types of changes

<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
9  .github/SECURITY.md  (vendored, new file)
@@ -0,0 +1,9 @@

# Security Policy

## Supported Versions

Due to the fast pace of development in this project, only the latest released version can be supported.

## Reporting a Vulnerability

If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue, to allow us some time to fix it before it becomes known to others.
10  .github/SUPPORT.md  (vendored, new file)
@@ -0,0 +1,10 @@

# Support

If you need help with this project or have questions, please:

1. Check the documentation.
2. Search the existing issues and pull requests.
3. Create a new issue if your question is not answered or your problem is not solved.
4. Ask in the [Discord server](https://discord.gg/HhrNrHJPRb).

Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.
17  .github/workflows/main.yml  (vendored)
@@ -4,7 +4,6 @@ on:
  push:
    branches:
      - "main"
      - "dev"

jobs:
  build-axolotl:
@@ -14,17 +13,17 @@ jobs:
      fail-fast: false
      matrix:
        include:
-         - cuda: cu118
+         - cuda: 118
            cuda_version: 11.8.0
            python_version: "3.9"
            pytorch: 2.0.1
            axolotl_extras:
-         - cuda: cu118
+         - cuda: 118
            cuda_version: 11.8.0
            python_version: "3.10"
            pytorch: 2.0.1
            axolotl_extras:
-         - cuda: cu118
+         - cuda: 118
            cuda_version: 11.8.0
            python_version: "3.9"
            pytorch: 2.0.1
@@ -50,10 +49,11 @@ jobs:
        with:
          context: .
          build-args: |
-           BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
+           BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile
          push: ${{ github.event_name != 'pull_request' }}
-         tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+         tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
          labels: ${{ steps.metadata.outputs.labels }}
  build-axolotl-runpod:
    needs: build-axolotl
@@ -72,6 +72,7 @@ jobs:
            python_version: "3.10"
            pytorch: 2.0.1
            axolotl_extras:
+           is_latest: true
          - cuda: 118
            cuda_version: 11.8.0
            python_version: "3.9"
@@ -102,5 +103,7 @@ jobs:
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile-runpod
          push: ${{ github.event_name != 'pull_request' }}
-         tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+         tags: |
+           ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+           ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
          labels: ${{ steps.metadata.outputs.labels }}
2  .github/workflows/tests.yml  (vendored)
@@ -24,7 +24,7 @@ jobs:

      - name: Install dependencies
        run: |
-         pip install -e .
+         pip install -e .[peft]
          pip install -r requirements-tests.txt

      - name: Run tests
174  README.md
@@ -1,10 +1,40 @@

# Axolotl

Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.

<table>
<tr>
<td>

## Table of Contents
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Docker Installation](#environment)
- [Conda/Pip venv Installation](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)

</td>
<td>

<div align="center">
  <img src="image/axolotl.png" alt="axolotl" width="160">
  <div>
    <p>
      <b>One repo to finetune them all! </b>
      <b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
    </p>
    <p>
      Go ahead and axolotl questions!!
@@ -14,27 +44,34 @@
    </p>
  </div>
</div>

</td>
</tr>
</table>

## Axolotl supports

-| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
-|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
-| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
-| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
-| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
-| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
-| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ❓ | ✅ |
+| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
+|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
+| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
+| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
+| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
+| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
+| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
+| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |

## Quickstart ⚡

Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.

**Requirements**: Python >=3.9 and Pytorch >=2.0.

```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl

pip3 install -e .
pip3 install -e .[flash-attn]
pip3 install -U git+https://github.com/huggingface/peft.git

# finetune lora
@@ -63,7 +100,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```

- Conda/Pip venv
-  1. Install python **3.9**
+  1. Install python >=**3.9**

  2. Install pytorch stable https://pytorch.org/get-started/locally/

@@ -116,9 +153,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \

  pip3 install -e . # change depend on needs
  pip3 install protobuf==3.20.3
- pip3 install -U requests
- pip3 install -U --ignore-installed psutil
- pip3 install -U scipy
+ pip3 install -U --ignore-installed requests Pillow psutil scipy
  pip3 install git+https://github.com/huggingface/peft.git # not for gptq
  ```

@@ -130,13 +165,14 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \

### Dataset

+Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):

- `alpaca`: instruction; input(optional)
  ```json
  {"instruction": "...", "input": "...", "output": "..."}
  ```
-- `sharegpt:chat`: conversations
+- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
  ```json
  {"conversations": [{"from": "...", "value": "..."}]}
  ```
@@ -221,10 +257,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
  ```json
  {"conversations": [{"role": "...", "value": "..."}]}
  ```
- `metharme`: instruction, adds additional eos tokens
  ```json
  {"prompt": "...", "generation": "..."}
  ```
+- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
+  ```json
+  {"conversations": [{"role": "...", "value": "..."}]}
+  ```
+- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
+  ```json
+  {"conversations": [{"from": "...", "value": "..."}]}
+  ```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
  ```json
  {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -234,11 +278,29 @@ Have dataset(s) in one of the following format (JSONL recommended):

#### How to add custom prompts

-1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
+Using yaml. Example:
+```yaml
+datasets:
+  - path: repo
+    type:
+      system_prompt: ""
+      no_input_format: |-
+        User: {instruction}<|end_of_turn|>
+        Assistant:
+      format: |-
+        User: {instruction}
+        {input}<|end_of_turn|>
+        Assistant:
+```

-Optionally, download some datasets, see [data/README.md](data/README.md)
+Using file:
+1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
+2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.

#### How to use your custom pretokenized dataset

- Do not pass a `type:`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
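
For illustration, a minimal sketch of such a config (the dataset path is hypothetical): since no `type:` is given, the pretokenized columns are used as-is.

```yaml
datasets:
  # hypothetical repo path to a dataset that already provides
  # input_ids, attention_mask and labels columns
  - path: your-org/your-pretokenized-dataset
    # note: no `type:` key on purpose
```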

### Config

@@ -268,9 +330,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod

  # local
  datasets:
-   - path: json
-     data_files: data.jsonl # or json
-     type: alpaca # format from earlier
+   - path: data.jsonl # or json
+     ds_type: json # see other options below
+     type: alpaca
  ```

- loading
@@ -322,6 +384,8 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
+# Whether to use the legacy tokenizer setting, defaults to True
+tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -349,10 +413,29 @@ datasets:
  - path: vicgalle/alpaca-gpt4
    # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
    ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
    data_files: # path to source data files
    shards: # number of shards to split data into
    name: # name of dataset configuration to load

+  # custom user prompt
+  - path: repo
+    type:
+      # the below are defaults. only set what's needed.
+      system_prompt: ""
+      field_system: system
+      field_instruction: instruction
+      field_output: input
+
+      # customizable to be single line or multi-line
+      system_format: "{system}"
+      # 'format' can include {input}
+      format: |-
+        User: {instruction} {input}
+        Assistant:
+      # 'no_input_format' cannot include {input}
+      no_input_format: "{instruction} "

# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -360,10 +443,13 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
-# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
+# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
@@ -407,6 +493,12 @@ lora_modules_to_save:
lora_out_dir:
lora_fan_in_fan_out: false

+# ReLoRA configuration
+# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
+relora_steps: # number of steps per ReLoRA restart
+relora_warmup_steps: # number of per-restart warmup steps
+relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
+
# wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name
@@ -425,10 +517,13 @@ eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
+lr_quadratic_warmup:
logging_steps:
-save_steps:
-eval_steps:
-save_total_limit:
+save_strategy: # set to `no` to skip checkpoint saves
+save_steps: # leave empty to save at each epoch
+eval_steps: # leave empty to eval at each epoch
+save_total_limit: # checkpoints saved at a time
max_steps:

# save model as safetensors (require safetensors package)
save_safetensors:
@@ -473,8 +568,8 @@ max_grad_norm:
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
-# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
-flash_attention: # require a100 for llama
+# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
+flash_attention:
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
@@ -509,7 +604,7 @@ tokens:
fsdp:
fsdp_config:

-# Deepspeed
+# Deepspeed config path
deepspeed:

# Path to torch distx for optim 'adamw_anyprecision'
@@ -537,7 +632,7 @@ strict:

Run
```bash
-accelerate launch scripts/finetune.py configs/your_config.yml
+accelerate launch scripts/finetune.py your_config.yml
```

#### Multi-GPU
@@ -560,7 +655,10 @@ fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

-- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
+- llama Deepspeed
+  ```yaml
+  deepspeed: deepspeed/zero3.json
+  ```
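
As a usage sketch (the config filename is illustrative): with the `deepspeed` key set as above, the usual launch command picks it up.

```bash
# sketch: deepspeed is enabled via the `deepspeed:` key in the config
accelerate launch scripts/finetune.py your_config.yml
```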

##### Weights & Biases Logging

@@ -608,7 +706,7 @@ CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...

## Common Errors 🧰

-> Cuda out of memory
+> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:

Please reduce any of the below (a sketch follows this list):
- `micro_batch_size`
@@ -616,6 +714,12 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
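
A sketch of what such a reduction can look like in the config; the values are illustrative starting points, not recommendations:

```yaml
# illustrative values only -- tune for your GPU
micro_batch_size: 1            # e.g. down from 2
gradient_accumulation_steps: 1
sequence_len: 2048             # e.g. down from 4096
```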

+> `failed (exitcode: -9)`
+
+Usually means your system has run out of system memory.
+Similarly, you should consider reducing the same settings as when you run out of VRAM.
+Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.

> RuntimeError: expected scalar type Float but found Half

Try setting `fp16: true`
@@ -644,6 +748,8 @@ Building something cool with Axolotl? Consider adding a badge to your model card

## Community Showcase

+Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
+
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
@@ -654,7 +760,9 @@ PocketDoc Labs

## Contributing 🤝

-Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
+Please read the [contributing guide](./.github/CONTRIBUTING.md)
+
+Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.

PRs are **greatly welcome**!
data/README.md  (deleted)
@@ -1,24 +0,0 @@

## Download some datasets

```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
```

## Convert the JSON data files to JSONL.

```shell
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---

Using JSONL makes it easier to subset the data if you want a smaller training set, i.e. get 2000 random examples.

```shell
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```
1  data/raw/.gitignore  (vendored, deleted)
@@ -1 +0,0 @@
-**
46  deepspeed/zero2.json  (new file)
@@ -0,0 +1,46 @@

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    },
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "bf16": {
    "enabled": "auto"
  },
  "fp16": {
    "enabled": "auto",
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": [
        0.9,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
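
As a usage note, and an assumption based on the README's Deepspeed section (which wires up `deepspeed/zero3.json` the same way): this stage-2 config would be referenced from an axolotl config like so.

```yaml
# sketch: point the trainer at the ZeRO stage-2 config above
deepspeed: deepspeed/zero2.json
```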
@@ -16,9 +16,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
    if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-       pip install -e .[$AXOLOTL_EXTRAS]; \
+       pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
    else \
-       pip install -e .; \
+       pip install -e .[flash-attn]; \
    fi

# fix so that git fetch/pull from remote works
@@ -31,26 +31,6 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA

-FROM base-builder AS flash-attn-builder
-
-WORKDIR /workspace
-
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-
-RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
-    cd flash-attention && \
-    git checkout v2.0.1 && \
-    python3 setup.py bdist_wheel && \
-    cd csrc/fused_dense_lib && \
-    python3 setup.py bdist_wheel && \
-    cd ../xentropy && \
-    python3 setup.py bdist_wheel && \
-    cd ../rotary && \
-    python3 setup.py bdist_wheel && \
-    cd ../layer_norm && \
-    python3 setup.py bdist_wheel
-
FROM base-builder AS deepspeed-builder

ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
@@ -90,13 +70,8 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
-COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels

-RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
+RUN pip3 install wheels/deepspeed-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \
67  examples/code-llama/13b/lora.yml  (new file)
@@ -0,0 +1,67 @@

base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 100000
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
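
As a usage note, following the launch pattern shown in examples/code-llama/README.md later in this diff, a config like the one above is run with:

```bash
accelerate launch scripts/finetune.py examples/code-llama/13b/lora.yml
```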
69  examples/code-llama/13b/qlora.yml  (new file)
@@ -0,0 +1,69 @@

base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 100000
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
67  examples/code-llama/34b/lora.yml  (new file)
@@ -0,0 +1,67 @@

base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 100000
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
69  examples/code-llama/34b/qlora.yml  (new file)
@@ -0,0 +1,69 @@

base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 100000
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
67  examples/code-llama/7b/lora.yml  (new file)
@@ -0,0 +1,67 @@

base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

sequence_len: 100000
sample_packing: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
69 examples/code-llama/7b/qlora.yml Normal file
@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 100000
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
22 examples/code-llama/README.md Normal file
@@ -0,0 +1,22 @@
# Overview

This is an example CodeLlama configuration for the 7b, 13b, and 34b variants.

The 7b variant fits on any 24 GB VRAM GPU and uses about 17 GB of VRAM during training with qlora, or about 20 GB with lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.

The 13b variant will fit if you change these settings to the following values (see the sketch after the launch commands below):
gradient_accumulation_steps: 2
micro_batch_size: 1

The 34b variant does not fit on 24 GB of VRAM - you will need a GPU with 40+ GB of VRAM that also supports flash attention v2, such as an A6000 or A100.

```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
```

or

```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
```
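A sketch of deriving the 13b config the README describes: patch the two memory-sensitive settings and write a new file. Only the two overridden values come from the README; the 13b model name and the output path here are hypothetical.

```python
# Derive a 13b qlora config from the 7b example (hypothetical paths/names).
import yaml

with open("examples/code-llama/7b/qlora.yml", encoding="utf-8") as fin:
    cfg = yaml.safe_load(fin)

cfg["base_model"] = "codellama/CodeLlama-13b-hf"          # assumed model id
cfg["base_model_config"] = "codellama/CodeLlama-13b-hf"   # assumed model id
cfg["gradient_accumulation_steps"] = 2  # per the README
cfg["micro_batch_size"] = 1             # per the README

with open("examples/code-llama/13b-qlora.yml", "w", encoding="utf-8") as fout:
    yaml.safe_dump(cfg, fout, sort_keys=False)
```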
@@ -57,7 +57,7 @@ weight_decay: 0.0001
 fsdp:
 fsdp_config:
 tokens:
-  pad_token: "[PAD]"
+  pad_token: "<pad>"
 bos_token: "<s>"
 eos_token: "</s>"
 unk_token: "<unk>"
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true

 load_in_8bit: true
 load_in_4bit: false
@@ -15,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
-max_packed_sequence_len: 4096
+sample_packing: true

 adapter: lora
 lora_model_dir:
@@ -49,8 +50,8 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+xformers_attention:
+flash_attention: true

 warmup_steps: 10
 eval_steps: 20
@@ -64,4 +65,3 @@ special_tokens:
 bos_token: "<s>"
 eos_token: "</s>"
 unk_token: "<unk>"
-pad_token: "<pad>"
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true

 load_in_8bit: false
 load_in_4bit: true
@@ -18,7 +19,8 @@ adapter: qlora
 lora_model_dir:

 sequence_len: 4096
-max_packed_sequence_len: 4096
+sample_packing: true

 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -50,8 +52,8 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+xformers_attention:
+flash_attention: true

 warmup_steps: 10
 eval_steps: 20
@@ -65,4 +67,3 @@ special_tokens:
 bos_token: "<s>"
 eos_token: "</s>"
 unk_token: "<unk>"
-pad_token: "<pad>"
73 examples/llama-2/relora.yml Normal file
@@ -0,0 +1,73 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./relora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true

lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

relora_steps: 150
relora_warmup_steps: 10
relora_cpu_offload: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 20
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
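For orientation, the three `relora_*` keys above schedule periodic merge-and-restart cycles for the low-rank adapter: every `relora_steps` optimizer steps the adapter is folded into the base weights and re-initialized, with the learning rate re-warmed for `relora_warmup_steps` steps after each reset. A minimal scheduling sketch of that cadence (plain Python illustrating the config semantics, not axolotl's implementation):

```python
# Hypothetical sketch of the ReLoRA cadence configured above
# (relora_steps: 150, relora_warmup_steps: 10); not axolotl's actual API.
RELORA_STEPS = 150
RELORA_WARMUP_STEPS = 10


def relora_should_reset(global_step: int) -> bool:
    """True on the optimizer steps where the adapter is merged and re-initialized."""
    return global_step > 0 and global_step % RELORA_STEPS == 0


def relora_lr_scale(global_step: int) -> float:
    """Linear LR re-warmup for the first few steps after each reset."""
    steps_into_cycle = global_step % RELORA_STEPS
    if steps_into_cycle < RELORA_WARMUP_STEPS:
        return steps_into_cycle / RELORA_WARMUP_STEPS
    return 1.0
```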
@@ -1,20 +1,23 @@
 packaging
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.41.1
+accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
 addict
 evaluate
 fire
-PyYAML==6.0
+PyYAML>=6.0
 datasets
-accelerate>=0.19.0
 flash-attn>=2.0.8
 sentencepiece
 wandb
 einops
 xformers
 optimum
 hf_transfer
 colorama
 numba
-numpy==1.24.4
+numpy>=1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
@@ -1,52 +0,0 @@
"""Module to convert json file to jsonl"""

import os
import sys
from pathlib import Path
from typing import Optional, Union

import fire

from axolotl.convert import (
    FileReader,
    FileWriter,
    JsonlSerializer,
    JsonParser,
    JsonToJsonlConverter,
    StdoutWriter,
)
from axolotl.logging_config import configure_logging

configure_logging()

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)


def main(
    file: Path,
    output: Optional[Path] = None,
    to_stdout: Optional[bool] = False,
):
    """
    Convert a json file to jsonl
    """

    file_reader = FileReader()
    writer: Union[StdoutWriter, FileWriter]
    if to_stdout or output is None:
        writer = StdoutWriter()
    else:
        writer = FileWriter(output)
    json_parser = JsonParser()
    jsonl_serializer = JsonlSerializer()

    converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)

    converter.convert(file, output)


if __name__ == "__main__":
    fire.Fire(main)
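The deleted helper above can be reproduced with just the standard library; a minimal sketch (the file paths at the bottom are hypothetical):

```python
# Minimal stand-in for the removed converter: read a JSON array of records
# and emit one JSON object per line.
import json
from pathlib import Path


def json_to_jsonl(src: Path, dst: Path) -> None:
    records = json.loads(src.read_text(encoding="utf-8"))
    with dst.open("w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    json_to_jsonl(Path("data.json"), Path("data.jsonl"))  # hypothetical paths
```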
@@ -18,18 +18,13 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer

 from axolotl.logging_config import configure_logging
-from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import (
-    calculate_total_num_steps,
-    process_datasets_for_packing,
-    setup_trainer,
-)
-from axolotl.utils.validation import validate_config
+from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars

 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -39,30 +34,21 @@ sys.path.insert(0, src_dir)
 configure_logging()
 LOG = logging.getLogger("axolotl.scripts")


-DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


-def choose_device(cfg):
-    def get_device():
-        try:
-            if torch.cuda.is_available():
-                return f"cuda:{cfg.local_rank}"
-
-            if torch.backends.mps.is_available():
-                return "mps"
-
-            raise SystemError("No CUDA/mps device found")
-        except Exception:  # pylint: disable=broad-exception-caught
-            return "cpu"
-
-    cfg.device = get_device()
-    if cfg.device_map != "auto":
-        if cfg.device.startswith("cuda"):
-            cfg.device_map = {"": cfg.local_rank}
-        else:
-            cfg.device_map = {"": cfg.device}
+def print_axolotl_text_art():
+    ascii_art = """
+                 dP            dP   dP
+                 88            88   88
+      .d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
+      88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
+      88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
+      `88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP
+    """
+
+    if is_main_process():
+        print(ascii_art)

 def get_multi_line_input() -> Optional[str]:
@@ -96,6 +82,8 @@ def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
         max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
     )

+    model = model.to(cfg.device)
+
     while True:
         print("=" * 80)
         # support for multiline inputs
@@ -174,6 +162,7 @@ def train(
     prepare_ds_only: bool = False,
     **kwargs,
 ):
+    print_axolotl_text_art()
     if Path(config).is_dir():
         config = choose_config(config)

@@ -194,67 +183,18 @@ def train(

     validate_config(cfg)

-    # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
-        cfg.batch_size // cfg.micro_batch_size
-    )
-    cfg.batch_size = (
-        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
-    )
-    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
-    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    choose_device(cfg)
-    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
-    if cfg.ddp:
-        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.batch_size = cfg.batch_size * cfg.world_size
+    normalize_config(cfg)

     setup_wandb_env_vars(cfg)
-    if cfg.device == "mps":
-        cfg.load_in_8bit = False
-        cfg.tf32 = False
-        if cfg.bf16:
-            cfg.fp16 = True
-        cfg.bf16 = False
-
-    if cfg.tf32:
-        torch.backends.cuda.matmul.allow_tf32 = True

     # load the tokenizer first
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
-    LOG.info(f"loading tokenizer... {tokenizer_config}")
-    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    tokenizer = load_tokenizer(cfg)

     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        if not cfg.pretraining_dataset:
-            train_dataset, eval_dataset = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
-        else:
-            train_dataset = load_pretraining_dataset(
-                cfg.pretraining_dataset,
-                tokenizer,
-                max_tokens=cfg.sequence_len,
-                seed=cfg.seed or 42,
-            )
-            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
-            train_dataset = train_dataset.with_format("torch")
-            eval_dataset = None
-
-        if is_main_process():
-            # process on rank 0 first so it gets cached so other ranks load from cache
-            train_dataset, eval_dataset = process_datasets_for_packing(
-                cfg, train_dataset, eval_dataset
-            )
-        barrier()
-        if not is_main_process():
-            train_dataset, eval_dataset = process_datasets_for_packing(
-                cfg, train_dataset, eval_dataset
-            )
-            barrier()
-        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+        train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
@@ -269,8 +209,6 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return

-    log_gpu_memory_usage(LOG, "baseline", cfg.device)
-
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)
@@ -306,6 +244,21 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return

+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
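The numeric key in the sort above matters: lexicographic ordering would pick the wrong checkpoint once step counts cross a digit boundary. A quick self-contained illustration:

```python
checkpoints = ["checkpoint-999", "checkpoint-1500", "checkpoint-100"]

# lexicographic: "checkpoint-999" sorts last, even though 1500 > 999
print(sorted(checkpoints)[-1])  # checkpoint-999

# numeric, as in the auto-resume code above
print(sorted(checkpoints, key=lambda path: int(path.split("-")[-1]))[-1])  # checkpoint-1500
```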
@@ -337,23 +290,10 @@ def train(
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
-            )

     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     tokenizer.save_pretrained(cfg.output_dir)
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -364,6 +304,13 @@ def train(

     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

+    if cfg.relora_steps:
+        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
+            model = model.merge_and_unload()
+        else:
+            # final model weights have already been saved by `ReLoRACallback.on_train_end`
+            return
+
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
@@ -371,6 +318,7 @@ def train(
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
+
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
8 setup.py
@@ -7,6 +7,7 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
     # don't include peft yet until we check the int4
     # need to manually install peft for now...
     reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
+    reqs = [r for r in reqs if "flash-attn" not in r]
     reqs = [r for r in reqs if r and r[0] != "#"]
     for r in reqs:
         install_requires.append(r)
@@ -25,9 +26,14 @@ setup(
        "gptq_triton": [
            "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
        ],
        "flash-attn": [
            "flash-attn==2.0.8",
        ],
        "extras": [
            "flash-attn",
            "deepspeed",
        ],
        "peft": [
            "peft @ git+https://github.com/huggingface/peft.git",
        ],
    },
)
@@ -1,16 +1,42 @@
-"""Logging configuration settings"""
+"""
+Common logging module for axolotl
+"""

 import os
 import sys
+from logging import Formatter
 from logging.config import dictConfig
 from typing import Any, Dict

+from colorama import Fore, Style, init
+
+
+class ColorfulFormatter(Formatter):
+    """
+    Formatter to add coloring to log messages by log type
+    """
+
+    COLORS = {
+        "WARNING": Fore.YELLOW,
+        "ERROR": Fore.RED,
+        "CRITICAL": Fore.RED + Style.BRIGHT,
+    }
+
+    def format(self, record):
+        log_message = super().format(record)
+        return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
+

 DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
     "version": 1,
     "formatters": {
         "simple": {
             "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
         },
+        "colorful": {
+            "()": ColorfulFormatter,
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
+        },
     },
     "filters": {},
     "handlers": {
@@ -20,14 +46,25 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
             "filters": [],
             "stream": sys.stdout,
         },
+        "color_console": {
+            "class": "logging.StreamHandler",
+            "formatter": "colorful",
+            "filters": [],
+            "stream": sys.stdout,
+        },
     },
     "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
     "loggers": {
-        "axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
+        "axolotl": {
+            "handlers": ["color_console"],
+            "level": "DEBUG",
+            "propagate": False,
+        },
     },
 }


 def configure_logging():
     """Configure with default logging"""
+    init()  # Initialize colorama
     dictConfig(DEFAULT_LOGGING_CONFIG)
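A short usage sketch of the reworked module: after `configure_logging()`, loggers under the `axolotl` namespace route through the colorized handler, while everything else keeps the plain console handler (the child logger name here is arbitrary):

```python
import logging

from axolotl.logging_config import configure_logging

configure_logging()

LOG = logging.getLogger("axolotl.example")  # any "axolotl.*" name works
LOG.warning("rendered yellow via ColorfulFormatter")
logging.getLogger("other").warning("plain output via the root console handler")
```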
@@ -2,142 +2,47 @@

 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py

-from typing import Optional, Tuple
+import warnings
+from typing import List, Optional, Tuple, Union

 import torch
+import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import (
+    LlamaDecoderLayer as OriginalLlamaDecoderLayer,
+)
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

 try:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+    )
 except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
+    )
     from flash_attn.flash_attn_interface import (
         flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
     )

-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
-

-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    """Input shape: Batch x Time x Channel
-
-    attention_mask: [bsz, q_len]
-    """
-    # pylint: disable=duplicate-code
-    bsz, q_len, _ = hidden_states.size()
-
-    query_states = (
-        self.q_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        self.k_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        self.v_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    # [bsz, q_len, nh, hd]
-    # [bsz, nh, q_len, hd]
-
-    kv_seq_len = key_states.shape[-2]
-    assert past_key_value is None, "past_key_value is not supported"
-
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin, position_ids
-    )
-    # [bsz, nh, t, hd]
-    assert not output_attentions, "output_attentions is not supported"
-    assert not use_cache, "use_cache is not supported"
-
-    # Flash attention codes from
-    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
-
-    # transform the data into the format required by flash attention
-    qkv = torch.stack(
-        [query_states, key_states, value_states], dim=2
-    )  # [bsz, nh, 3, q_len, hd]
-    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-    # We have disabled _prepare_decoder_attention_mask in LlamaModel
-    # the attention_mask should be the same as the key_padding_mask
-    key_padding_mask = attention_mask
-
-    if key_padding_mask is None:
-        qkv = rearrange(qkv, "b s ... -> (b s) ...")
-        max_s = q_len
-        cu_q_lens = torch.arange(
-            0,
-            (bsz + 1) * q_len,
-            step=q_len,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
-        )
-        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif position_ids.shape[0] == 1:
-        # special handling using sample packing
-        qkv = rearrange(qkv, "b s ... -> (b s) ...")
-        cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_q_lens = cu_q_lens.squeeze()
-
-        output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
-        )
-        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    else:
-        nheads = qkv.shape[-2]
-
-        # pylint: disable=invalid-name
-        x = rearrange(qkv, "b s three h d -> b s (three h d)")
-        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
-        x_unpad = rearrange(
-            x_unpad,
-            "nnz (three h d) -> nnz three h d",
-            three=3,
-            h=nheads,
-        )
-        output_unpad = flash_attn_varlen_qkvpacked_func(
-            x_unpad,
-            cu_q_lens,
-            max_s,
-            0.0,
-            softmax_scale=None,
-            causal=True,
-        )
-        output = rearrange(
-            pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
-                indices,
-                bsz,
-                q_len,
-            ),
-            "b s (h d) -> b s h d",
-            h=nheads,
-        )
-
-    return (
-        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
-        None,
-        None,
-    )
+def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if packed:
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
+        transformers.models.llama.modeling_llama.LlamaModel.forward = (
+            llama_model_forward
+        )


 # Disable the transformation of the attention mask in LlamaModel as the flash attention
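For context, the reworked entry point is meant to be applied before the model is instantiated; a minimal call-site sketch (the module path is assumed from the sibling `llama_attn_hijack_sdp.py` file added below, not stated in this hunk):

```python
# Assumed module path; apply the patch before from_pretrained() builds the model.
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# packed=True also swaps in the decoder layer / model forward that thread
# cu_seqlens and max_seqlen through for sample packing.
replace_llama_attn_with_flash_attn(packed=True)
```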
@@ -153,8 +58,541 @@ def _prepare_decoder_attention_mask(
    return attention_mask


-def replace_llama_attn_with_flash_attn():
-    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
-        _prepare_decoder_attention_mask
-    )
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward


def flashattn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    #
    # flash-attn v2 start
    #

    if self.training:
        # during training q,k,v always have same seqlen
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # turn off FA causal mask after first inference autoregressive iteration
        # only on first autoregressive step q,k,v have same seqlen
        is_causal = key_states.shape == query_states.shape

    if cu_seqlens is not None and max_seqlen is not None:
        # special handling using sample packing
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")

        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states,
            key_states,
            value_states,
            qkvpacked=True,
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
            query_padding_mask=attention_mask[:, -query_states.size(1) :]
            if attention_mask is not None
            else None,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            0.0,
            softmax_scale=None,
            causal=is_causal,
        )
        output = output_pad_fn(output_unpad)
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if attention_mask is None or attention_mask.all().item():
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
                causal=is_causal,
            )
        else:
            (  # pylint: disable=unbalanced-tuple-unpacking
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                _,
                _,
                output_pad_fn,
            ) = generate_qkv(
                query_states,
                key_states,
                value_states,
                kvpacked=True,
                key_padding_mask=attention_mask,
                query_padding_mask=attention_mask[:, -query_states.size(1) :]
                if attention_mask is not None
                else None,
            )
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                0.0,
                softmax_scale=None,
                causal=is_causal,
            )
            output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    #
    # flash-attn v2 end
    #

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value


# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )


def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    cu_seqlens = None
    max_seqlen = None
    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()
        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
        cu_seqlens = cu_seqlens.squeeze()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
    attention_mask = (
        self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            transformers.logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
                output_attentions,
                None,
                cu_seqlens,
                max_seqlen,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
    """
    patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
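To make the sample-packing path concrete: `get_cu_seqlens_from_pos_ids` recovers per-sample boundaries from `position_ids` that restart at 0 for each packed sample, which is what `flash_attn_varlen_qkvpacked_func` needs. A self-contained illustration of the idea (not the imported implementation):

```python
import torch


def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    """Cumulative sequence lengths for a 1-D packed position_ids tensor."""
    # each packed sample restarts its positions at 0
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    ends = torch.cat([starts[1:], torch.tensor([position_ids.numel()])])
    lengths = ends - starts
    cu_seqlens = torch.cat([torch.tensor([0]), lengths.cumsum(0)]).to(torch.int32)
    return cu_seqlens, int(lengths.max())


# three samples of lengths 3, 2, 4 packed into one row
pos = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
print(cu_seqlens_from_pos_ids(pos))  # (tensor([0, 3, 5, 9], dtype=torch.int32), 4)
```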
140 src/axolotl/monkeypatch/llama_attn_hijack_sdp.py Normal file
@@ -0,0 +1,140 @@
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# sdp-attn start
|
||||
#
|
||||
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
#
|
||||
# sdp-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
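As a sanity check on what the new module delegates to: `scaled_dot_product_attention` with a causal mask matches the manual softmax(QKᵀ/√d)V computation. A small CPU-only demo:

```python
import torch
import torch.nn.functional as F

bsz, nh, q_len, hd = 1, 4, 8, 16
q, k, v = (torch.randn(bsz, nh, q_len, hd) for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# manual reference computation with an explicit causal mask
scores = (q @ k.transpose(-2, -1)) / hd**0.5
scores = scores.masked_fill(
    torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1), float("-inf")
)
ref = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(out, ref, atol=1e-5)
```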
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from torch import nn
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
@@ -21,12 +21,6 @@ def hijack_llama_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -81,15 +75,15 @@ def xformers_forward(
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
@@ -102,74 +96,50 @@ def xformers_forward(
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||
key_states, self.num_key_value_groups
|
||||
)
|
||||
value_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||
value_states, self.num_key_value_groups
|
||||
)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
    value_states = value_states.transpose(1, 2)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
        # input and output should be of form (bsz, q_len, num_heads, head_dim)
        attn_output = xformers.ops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=None
        )
    else:
        # input and output should be of form (bsz, q_len, num_heads, head_dim)
        attn_output = xformers.ops.memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            # attn_bias=attention_mask,
            attn_bias=xformers.ops.LowerTriangularMask(),
        )
    attn_weights = None
    #
    # xformers-attn start
    #

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
    # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
    if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
        # input and output should be of form (bsz, q_len, num_heads, head_dim)
        attn_output = xformers.ops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=None
        )
    else:
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
    # end x-formers vs. not x-formers if-else block
    # input and output should be of form (bsz, q_len, num_heads, head_dim)
    attn_output = xformers.ops.memory_efficient_attention(
        query_states,
        key_states,
        value_states,
        # attn_bias=attention_mask,
        attn_bias=xformers.ops.LowerTriangularMask(),
    )

    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    #
    # xformers-attn end
    #

    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(

@@ -182,103 +152,4 @@ def xformers_forward

    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights, past_key_value
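Editor's note — a minimal sketch (not part of the patch) of why probing mask[0, 0, 0, 1] distinguishes the two mask shapes transformers builds: a causal additive mask has large negative values above the diagonal, while an all-zeros mask has 0 there.

import torch

# causal mask: -inf above the diagonal; zeros mask: 0 everywhere
causal = torch.full((1, 1, 4, 4), torch.finfo(torch.float32).min).triu(1)
zeros = torch.zeros((1, 1, 4, 4))
assert causal[0, 0, 0, 1] != 0  # causal -> use LowerTriangularMask
assert zeros[0, 0, 0, 1] == 0   # all zeros -> attn_bias=None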


def sdp_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    (
        query_states,
        key_states,
    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # We only apply sdp attention if we don't need to output the whole attention matrix
    if not output_attentions:
        with torch.backends.cuda.sdp_kernel():
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                is_causal=False,
            )
        attn_weights = None
    else:
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights, past_key_value
    return attn_output, None, past_key_value
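Editor's note — a quick equivalence check (a sketch, not from the patch) for the fast path above: F.scaled_dot_product_attention should match the manual matmul/softmax formulation the else branch uses.

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)  # (bsz, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
manual = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(8), dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, fused, atol=1e-5)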
393 src/axolotl/monkeypatch/relora.py Normal file
@@ -0,0 +1,393 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence

import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.relora")


def reset_optimizer(optimizer: torch.optim.Optimizer):
    for group in optimizer.param_groups:
        for param in group["params"]:
            param_state = optimizer.state[param]
            for key in param_state:
                if "qmap" in key:
                    continue

                if key == "step" and isinstance(param_state[key], int):
                    param_state[key] = 0
                else:
                    param_state[key] = torch.zeros_like(param_state[key])
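Editor's note — resetting the optimizer state at each LoRA restart prevents stale Adam momentum from dragging the freshly re-initialized adapters toward the previous solution. A toy illustration (not from the patch; on recent torch the "step" entry is a 0-dim tensor rather than an int, which the zeros_like branch also handles):

import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([param])
param.grad = torch.ones(4)
opt.step()
reset_optimizer(opt)
state = opt.state[param]
# first/second moment estimates are zeroed after the reset
assert state["exp_avg"].abs().sum() == 0 and state["exp_avg_sq"].abs().sum() == 0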
class ReLoRACallback(TrainerCallback):
    """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""

    def __init__(self, cfg: DictDefault):
        self.relora_steps = cfg.relora_steps
        self.cpu_offload = cfg.relora_cpu_offload
        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
        self.last_full_model = cfg.base_model
        self.resume_from_checkpoint = cfg.resume_from_checkpoint

        if not os.path.exists(self.last_full_model):
            self.last_full_model = str(Path(snapshot_download(cfg.base_model)))

        assert os.path.exists(
            self.last_full_model
        ), "for ReLORA base_model must be a local path"

        self.num_lora_restarts = 0
        self.need_full_save = False

    def on_train_begin(
        self,
        _args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        if self.resume_from_checkpoint:
            weight_path = os.path.join(self.resume_from_checkpoint, "relora")
            if not os.path.exists(weight_path):
                LOG.warning(
                    "Resuming ReLoRA from checkpoint, but no full-weight save found"
                )
            else:
                LOG.info(f"Loading adjusted base weights from {weight_path}")
                load_weight_checkpoint(model, weight_path)
        return control

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        optimizer: torch.optim.Optimizer,
        **_kwargs,
    ):
        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
                "relora",
            )

            with torch.no_grad():
                merge_and_save(
                    model,
                    self.last_full_model,
                    checkpoint_folder,
                    reinit=True,
                    quantized=self.quantized,
                    actually_save=is_main_process(),
                    cpu_offload=self.cpu_offload,
                )
            reset_optimizer(optimizer)

            if self.quantized:
                self.last_full_model = checkpoint_folder
            self.num_lora_restarts += 1

        return control

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
        )
        if (
            state.global_step >= self.relora_steps
            and state.global_step % self.relora_steps != 0
        ):
            if self.quantized:
                if is_main_process() and self.last_full_model != checkpoint_folder:
                    # ensure the latest full parameter save is in the latest checkpoint
                    # folder, so that automatic pruning of checkpoints does not remove it
                    LOG.info(f"moving last full parameter save to {checkpoint_folder}")
                    os.makedirs(checkpoint_folder, exist_ok=True)
                    chunks = glob.glob(
                        f"{self.last_full_model}/model*.safetensors"
                    ) + glob.glob(f"{self.last_full_model}/model*.index.json")
                    for path in chunks:
                        new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
                        try:
                            os.symlink(new_path, path)
                        except OSError:
                            # probably on windows without permission to symlink
                            pass

                    self.last_full_model = checkpoint_folder
            else:
                model.model.save_pretrained(checkpoint_folder, safe_serialization=True)

        return control

    def on_log(
        self,
        _args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, float],
        **_kwargs,
    ):
        logs["num_lora_restarts"] = self.num_lora_restarts
        return control

    def on_train_end(
        self,
        args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        if self.quantized:
            # perform final merge and save
            with torch.no_grad():
                merge_and_save(
                    model,
                    self.last_full_model,
                    args.output_dir,
                    reinit=False,
                    quantized=self.quantized,
                    actually_save=is_main_process(),
                    cpu_offload=self.cpu_offload,
                )
        # no need to save if unquantized, as finetune.py will call merge_and_unload()
        return control

class ReLoRAScheduler(LRScheduler):
    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        relora_steps: int,
        warmup_steps: int,
        min_lr_scale: float = 0.001,
    ) -> None:
        self.inner_schedule = inner_schedule
        self.relora_steps = relora_steps
        self.warmup_steps = warmup_steps
        self.min_lr_scale = min_lr_scale
        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

    def get_lr(self) -> float:
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch
        if step < self.relora_steps:
            scale = 1
        else:
            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
            return [lr * scale for lr in original]
        return original * scale

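Editor's note — a worked example of the re-warmup scale above, with illustrative values (relora_steps=200, warmup_steps=10, min_lr_scale=0.001): the first cycle is untouched, and after each restart the LR climbs from 0.1% of the inner schedule back to 100% over warmup_steps steps.

def relora_scale(step, relora_steps=200, warmup_steps=10, min_lr_scale=0.001):
    # mirrors ReLoRAScheduler.get_lr for a single LR value
    if step < relora_steps:
        return 1
    cycle_t = min(1.0, (step % relora_steps) / warmup_steps)
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

assert relora_scale(150) == 1        # first cycle: inner schedule unchanged
assert relora_scale(200) == 0.001    # just restarted: LR nearly zero
assert relora_scale(205) == 0.5005   # halfway through the re-warmup
assert relora_scale(210) == 1.0      # fully re-warmed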
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
    model_name = "model.safetensors"
    if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
        str(Path(path) / f"{model_name}.index.json")
    ):
        model_name = "pytorch_model.bin"

    index_path = str(Path(path) / f"{model_name}.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        return data["weight_map"]
    return {(module_name + ".weight"): model_name for module_name in module_names}


def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
    if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
        adapter = layer.active_adapter
        return (
            peft.utils.transpose(
                layer.lora_B[adapter].weight.detach().to(device)
                @ layer.lora_A[adapter].weight.detach().to(device),
                getattr(layer, "fan_in_fan_out", False),
            )
            * layer.scaling[adapter]
        )

    return layer.get_delta_weight().to(device)


def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
    modules: Dict[str, peft.tuners.lora.LoraLayer] = {}

    key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
    for key in key_list:
        try:
            # pylint: disable=protected-access
            _parent, target, _target_name = peft.utils._get_submodules(model.model, key)
        except AttributeError:
            continue

        if isinstance(target, peft.tuners.lora.LoraLayer):
            modules[key] = target

    return modules


def update_weights(
    target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
):
    if reinit:
        for adapter_name in target.lora_A:
            target.reset_lora_parameters(adapter_name)
        for adapter_name in target.lora_embedding_A:
            target.reset_lora_parameters(adapter_name)

    if isinstance(target, peft.tuners.lora.Linear4bit):
        # This could be faster, but the quantization of Linear4bit weights occurs
        # when the module is moved from cpu to gpu. Without meddling *too* deeply in
        # PEFT's innards or maintaining a duplicate of that codepath, this is good
        # enough for now.
        target.weight.quant_state = None
        target.weight.data = new_weight.cpu()
        target.to(device)
    elif isinstance(target, peft.tuners.lora.Linear8bitLt):
        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
    else:
        target.weight.data = new_weight.to(device)


def merge_and_save(
    model: peft.LoraModel,
    model_src: str,
    model_dst: str,
    reinit: bool = False,
    quantized: bool = False,
    cpu_offload: bool = False,
    actually_save: bool = True,
):
    modules = find_lora_modules(model)

    if not quantized:
        for module_name, target in modules.items():
            update = target.get_delta_weight(target.active_adapter).detach()
            target.weight.data += update

            if reinit:
                for adapter_name in target.lora_A:
                    target.reset_lora_parameters(adapter_name)
                for adapter_name in target.lora_embedding_A:
                    target.reset_lora_parameters(adapter_name)
        return

    os.makedirs(model_dst, exist_ok=True)
    shard_paths = sharded_paths(model_src, modules.keys())
    out_shard_paths = {}

    unique_shards = list(set(shard_paths.values()))
    for shard_path in unique_shards:
        out_tensors = {}
        if shard_path.endswith(".safetensors"):
            in_tensors = st.load_file(str(Path(model_src) / shard_path))
        else:
            in_tensors = torch.load(Path(model_src) / shard_path)
            if "state_dict" in in_tensors:
                in_tensors = in_tensors["state_dict"]

        for module_name, target in modules.items():
            key = module_name + ".weight"
            if key not in shard_paths or shard_paths[key] != shard_path:
                continue

            orig_weight = in_tensors[key]
            old_dev = target.weight.device
            math_dev = "cpu" if cpu_offload else old_dev

            delta_weight = lora_delta_weight(target, math_dev)
            new_weight = orig_weight.to(math_dev) + delta_weight
            del delta_weight

            if actually_save:
                out_tensors[key] = new_weight.half().cpu()

            update_weights(target, new_weight, reinit=reinit, device=old_dev)

        if actually_save:
            out_shard_name = shard_path
            if out_shard_name.startswith("pytorch_model"):
                out_shard_name = (
                    out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
                    + ".safetensors"
                )

            for module_name in in_tensors:
                if module_name not in out_tensors:
                    out_tensors[module_name] = in_tensors[module_name].half()
                out_shard_paths[module_name] = out_shard_name

            shard_fn = str(Path(model_dst) / out_shard_name)
            LOG.info(f"saving tensors to {shard_fn}")
            st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})

        del in_tensors
        del out_tensors
        torch.cuda.empty_cache()

    if actually_save and len(unique_shards) > 1:
        with open(
            str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
        ) as file:
            json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)


def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
    modules = find_lora_modules(model)
    shard_paths = sharded_paths(checkpoint_path, modules.keys())
    unique_shards = list(set(shard_paths.values()))

    for shard_path in unique_shards:
        tensors = st.load_file(os.path.join(checkpoint_path, shard_path))

        for module_name, target in modules.items():
            key = module_name + ".weight"
            if key not in shard_paths or shard_paths[key] != shard_path:
                continue

            new_weight = tensors[key]
            update_weights(
                target, new_weight, reinit=False, device=target.weight.device
            )
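Editor's note — a rough wiring sketch of how the two pieces fit together (hypothetical setup, not verbatim from axolotl; `optimizer`, `cfg`, and `cfg.relora_warmup_steps` are assumed names): the callback handles merge/restart and checkpointing, while the scheduler wraps whatever LR schedule the trainer already built.

from torch.optim.lr_scheduler import LinearLR

callbacks = [ReLoRACallback(cfg)]  # merge + restart every cfg.relora_steps
inner = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000)
scheduler = ReLoRAScheduler(optimizer, inner, cfg.relora_steps, cfg.relora_warmup_steps)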
@@ -2,8 +2,10 @@

import importlib

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

def load(strategy, tokenizer, cfg):
def load(strategy, tokenizer, cfg, ds_cfg):
    try:
        load_fn = "load"
        if strategy.split(".")[-1].startswith("load_"):

@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
            strategy = ".".join(strategy.split(".")[:-1])
        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
        func = getattr(mod, load_fn)
        return func(tokenizer, cfg)
        load_kwargs = {}
        if strategy == "user_defined":
            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
        return func(tokenizer, cfg, **load_kwargs)
    except Exception:  # pylint: disable=broad-exception-caught
        return None
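Editor's note — a sketch of the new call shape (hypothetical field values): for the user_defined strategy the dataset's own config dict is forwarded as ds_cfg and unpacked into UserDefinedDatasetConfig.

strategy = load(
    "user_defined",
    tokenizer,
    cfg,
    {"field_instruction": "question", "field_output": "answer"},
)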
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
    Alpaca Style Prompter that uses system prompts from the dataset
    """

    system_format: str = "### System:\n{system}\n\n"

    def build_prompt_w_system(
        self,
        system: str,

@@ -92,8 +94,9 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
    def match_prompt_style(self):
        # pylint: disable=duplicate-code
        if self.prompt_style == PromptStyle.INSTRUCT.value:
            self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
            self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
            self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n"
            self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n"
            self.system_format = "### System:\n{system}\n"
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
76 src/axolotl/prompt_strategies/metharme.py Normal file
@@ -0,0 +1,76 @@
"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter class"""

import logging
from typing import Tuple

from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter

LOG = logging.getLogger("axolotl")

IGNORE_TOKEN_ID = -100

# pylint: disable=duplicate-code


class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
        # If there's already an EOS token there, subtract from the number added
        if result["input_ids"][-1] == self.tokenizer.eos_token_id:
            num_eos_tokens -= 1

        if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
            for _ in range(num_eos_tokens):
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.
    """

    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"

    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        pass


def load(tokenizer, cfg):
    return MetharmePromptTokenizingStrategy(
        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
98 src/axolotl/prompt_strategies/user_defined.py Normal file
@@ -0,0 +1,98 @@
"""
User Defined prompts with configuration from the YML config
"""

from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple

from axolotl.prompt_strategies.alpaca_w_system import (
    InstructionWSystemPromptTokenizingStrategy,
    SystemDataPrompter,
)


@dataclass
class UserDefinedDatasetConfig:
    """
    dataclass configuration representing a user-defined dataset type
    """

    system_prompt: str = ""
    field_system: str = "system"
    field_instruction: str = "instruction"
    field_input: str = "input"
    field_output: str = "output"
    format: str = "{instruction} {input} "
    no_input_format: str = "{instruction} "
    system_format: str = "{system}"

    def __getitem__(self, item):
        return getattr(self, item)


class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Prompt Tokenization Strategy for user defined prompts
    """


def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
    if not ds_cfg:
        raise ValueError("Missing dataset prompt configuration")

    system_prompt = ""
    if ds_cfg.system_prompt:
        system_prompt = ds_cfg.system_prompt

    def parse_instruction_fields(
        field_instruction,
        field_input,
        field_output,
        field_system,
        system_prompt,
        prompt,
    ) -> Tuple[str, str, str, str]:
        return (
            prompt[field_instruction],
            prompt[field_input] if field_input in prompt else "",
            prompt[field_output] if field_output in prompt else "",
            prompt[field_system] if field_system in prompt else system_prompt,
        )

    turn_format = ds_cfg.format
    turn_no_input_format = ds_cfg.no_input_format
    system_format = ds_cfg.system_format

    class UserDefinedPrompter(SystemDataPrompter):
        """
        Prompter for user defined prompts
        """

        def match_prompt_style(self):
            self.turn_format = turn_format
            self.turn_no_input_format = turn_no_input_format
            self.system_format = system_format

    prompter = UserDefinedPrompter()

    strat = UserDefinedPromptTokenizationStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

    setattr(
        strat,
        "parse_instruction_fields",
        partial(
            parse_instruction_fields,
            ds_cfg.field_instruction,
            ds_cfg.field_input,
            ds_cfg.field_output,
            ds_cfg.field_system,
            system_prompt,
        ),
    )
    return strat
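Editor's note — an illustration of how a config maps onto a rendered prompt (hypothetical dataset row and format strings): with format="Q: {instruction}\nA: ", a row {"instruction": "2+2?", "output": "4"} renders as "Q: 2+2?\nA: 4".

ds_cfg = UserDefinedDatasetConfig(
    format="Q: {instruction}\nA: ",
    no_input_format="Q: {instruction}\nA: ",
)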
@@ -13,7 +13,7 @@ from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")

IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"  # nosec
LLAMA_DEFAULT_PAD_TOKEN = "<pad>"  # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>"  # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>"  # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"  # nosec

@@ -74,15 +74,22 @@ class PromptTokenizingStrategy(abc.ABC):
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            len(result["input_ids"]) > 0
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
        if (
            len(result["input_ids"]) > 0
            and result["input_ids"][0] == self.tokenizer.bos_token_id
            and strip_bos_token
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]
@@ -26,7 +26,7 @@ class AlpacaPrompter:

    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    system_format: str
    system_format: str = "{system}"
    turn_format: str
    turn_no_input_format: str
    prompt_style: Optional[PromptStyle] = None

@@ -63,13 +63,17 @@ class AlpacaPrompter:
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = self.system_prompt + self.turn_format.format(
                instruction=instruction, input=input
            )
            res = (
                self.system_format.format(system=self.system_prompt)
                if self.system_prompt
                else ""
            ) + self.turn_format.format(instruction=instruction, input=input)
        else:
            res = self.system_no_input_prompt + self.turn_no_input_format.format(
                instruction=instruction
            )
            res = (
                self.system_format.format(system=self.system_no_input_prompt)
                if self.system_prompt
                else ""
            ) + self.turn_no_input_format.format(instruction=instruction)
        if output:
            res = f"{res}{output}"
        yield res

@@ -312,7 +316,9 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError
            raise IndexError(
                f"A conversation entry has less than 2 messages :\n{source}"
            )

        conv = self._conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -4,13 +4,23 @@ import pynvml
import torch


def gpu_memory_usage(device):
def gpu_memory_usage(device=0):
    return torch.cuda.memory_allocated(device) / 1024.0**3


def gpu_memory_usage_all(device=0):
    usage = torch.cuda.memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    smi = gpu_memory_usage_smi(device)
    return usage, reserved - usage, max(0, smi - reserved)


def gpu_memory_usage_smi(device=0):
    if isinstance(device, torch.device):
        device = device.index
    if isinstance(device, str) and device.startswith("cuda:"):
        device = int(device[5:])

    # NB torch.cuda.memory_usage returns zero so we use lower level api
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)

@@ -18,6 +28,16 @@ def gpu_memory_usage(device):


def log_gpu_memory_usage(log, msg, device):
    if not torch.cuda.is_available():
        return (0, 0, 0)

    usage, cache, misc = gpu_memory_usage_all(device)
    extras = []
    if cache > 0:
        extras.append(f"+{cache:.03f}GB cache")
    if misc > 0:
        extras.append(f"+{misc:.03f}GB misc")
    log.info(
        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
    )
    return usage, cache, misc
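Editor's note — how to read the triple returned above (a sketch with illustrative output): `usage` is memory allocated by live tensors, `cache` is what the allocator has reserved beyond that, and `misc` is anything else NVML attributes to the device (CUDA context, other processes), each in GB.

usage, cache, misc = gpu_memory_usage_all(0)
print(f"{usage:.2f}GB tensors, {cache:.2f}GB cache, {misc:.2f}GB misc")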
@@ -1,9 +1,19 @@
"""Callbacks for Trainer class"""

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Dict, List

import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
    TrainerCallback,
    TrainerControl,

@@ -13,8 +23,19 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
    barrier,
    gather_scalar_from_all_ranks,
    get_world_size,
    is_main_process,
    zero_first,
)

if TYPE_CHECKING:
    from axolotl.utils.trainer import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100


class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods

@@ -33,7 +54,9 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
        )

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        kwargs["model"].save_pretrained(
            peft_model_path, save_safetensors=args.save_safetensors
        )

        return control

@@ -74,10 +97,10 @@ class SaveBetterTransformerModelCallback(
        return control


class PrintGPUStatsCallback(
class GPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """Callback to print GPU utilization"""
    """Callback to track GPU utilization"""

    def __init__(self, cfg):
        self.cfg = cfg

@@ -90,7 +113,196 @@ class PrintGPUStatsCallback(
        control: TrainerControl,
        **kwargs,
    ):
        if not self.logged:
        if not self.logged and state.global_step > 1:
            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
            self.logged = True
        return control


def bench_eval_callback_factory(trainer, tokenizer):
    accuracy = evaluate.load("accuracy")
    abcd_idx = [
        tokenizer("A", add_special_tokens=False).input_ids[0],
        tokenizer("B", add_special_tokens=False).input_ids[0],
        tokenizer("C", add_special_tokens=False).input_ids[0],
        tokenizer("D", add_special_tokens=False).input_ids[0],
        tokenizer("E", add_special_tokens=False).input_ids[0],
        tokenizer("F", add_special_tokens=False).input_ids[0],
        tokenizer("G", add_special_tokens=False).input_ids[0],
    ]
    bench_split = "eval"

    def transform_bench_subject(example):
        # Split on ':' and trim whitespace
        parts = example["subject"].split(":")
        first_part = (
            parts[0].strip().lower().replace("-", "_")
        )  # Lowercase the first part
        second_part = (
            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
        )  # Replace hyphens with underscores

        # Return the transformed values
        return {"name": first_part, "subject": second_part}

    if trainer.args.bench_dataset == "mmlu-zs":
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "zero_shot_mmlu_val.json",
                "test": "zero_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns("subject")
    # MMLU Five-shot (Eval/Test only)
    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "five_shot_mmlu_val.json",
                "test": "five_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns('subject')
    elif "/" in trainer.args.bench_dataset:
        bench_ds = trainer.args.bench_dataset
        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
        bench_dataset = load_dataset(
            bench_ds_name,
            data_files={
                "eval": bench_ds_data_file,
            },
        )
        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
    else:
        raise ValueError(
            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
        )
    bench_dataset = bench_dataset[trainer.args.bench_split]
    if trainer.args.max_bench_samples is not None:
        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))

    def tokenize_evals(example):
        source = f"{tokenizer.bos_token}{example['input']}"
        target = f"{example['output']}{tokenizer.eos_token}"

        tokenized_source = tokenizer(
            source,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_target = tokenizer(
            target,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
            "input_ids"
        ]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "subject": example["subject"],
        }

    with zero_first(is_main_process()):
        bench_dataset = bench_dataset.map(tokenize_evals)
        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)

    class BenchEvalCallback(TrainerCallback):
        """
        TrainerCallback that runs the MMLU evals
        """

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,
            state: TrainerState,  # pylint: disable=unused-argument
            control: TrainerControl,  # pylint: disable=unused-argument
            metrics: Dict[str, float],  # pylint: disable=unused-argument
            **kwargs,  # pylint: disable=unused-argument
        ):
            data_loader = trainer.get_bench_dataloader(
                bench_dataset.remove_columns(["input", "subject", "output", "name"])
            )
            trainer.model.eval()
            preds, refs = [], []
            loss_bench = 0
            for batch in tqdm(data_loader, total=len(data_loader)):
                (loss, logits, labels) = trainer.prediction_step(
                    trainer.model,
                    batch,
                    prediction_loss_only=False,
                )
                # There are two tokens, the output, and eos token.
                for i, logit in enumerate(logits):
                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
                        0
                    ][0]
                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
                    preds.append(torch.argmax(logit_abcd).item())
                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
                refs += [
                    abcd_idx.index(label) if label in abcd_idx else -1
                    for label in labels.tolist()
                ]
                loss_bench += loss.item()
            # Extract results by subject.
            bench_name = bench_dataset["name"]
            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
                bench_names[s]["preds"].append(p)
                bench_names[s]["refs"].append(r)
            barrier()
            local_bench_names = bench_names
            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
            # Gather results from all GPUs to GPU 0

            loss_bench_ranks = gather_scalar_from_all_ranks(
                lambda: loss_bench, get_world_size()
            )
            len_data_loader_ranks = gather_scalar_from_all_ranks(
                lambda: len(data_loader), get_world_size()
            )

            if not is_main_process():
                dist.gather_object(local_bench_names, dst=0)
            else:
                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
                results = {"bench_loss": bench_loss}

                # Combine results from all GPUs
                combined_bench_names: Dict[str, Dict[str, List]] = {}
                for bench_name in gathered_bench_names:
                    for name, data in bench_name.items():
                        if name not in combined_bench_names:
                            combined_bench_names[name] = {"refs": [], "preds": []}
                        combined_bench_names[name]["refs"].extend(data["refs"])
                        combined_bench_names[name]["preds"].extend(data["preds"])

                bench_scores = []
                for (
                    bench_name
                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
                    bench_score = accuracy.compute(
                        references=combined_bench_names[bench_name]["refs"],
                        predictions=combined_bench_names[bench_name]["preds"],
                    )["accuracy"]
                    if not pd.isna(bench_score):
                        results[
                            f"bench_{bench_split}_accuracy_{bench_name}"
                        ] = bench_score
                        bench_scores.append(bench_score)
                    else:
                        results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
                        bench_scores.append(0.0)
                results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
                trainer.log(results)

    return BenchEvalCallback
@@ -1,12 +1,77 @@
"""Module for validating config files"""
"""Module for working with config dicts"""

import logging
import os

import torch

from axolotl.utils.bench import log_gpu_memory_usage

LOG = logging.getLogger("axolotl")


def choose_device(cfg):
    def get_device():
        try:
            if torch.cuda.is_available():
                return f"cuda:{cfg.local_rank}"

            if torch.backends.mps.is_available():
                return "mps"

            raise SystemError("No CUDA/mps device found")
        except Exception:  # pylint: disable=broad-exception-caught
            return "cpu"

    cfg.device = get_device()
    if cfg.device_map != "auto":
        if cfg.device.startswith("cuda"):
            cfg.device_map = {"": cfg.local_rank}
        else:
            cfg.device_map = {"": cfg.device}

    # in `accelerate launch`, we need to not pass through any device map and let
    # accelerate figure out which parts of the model to put on which gpu
    accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
    if accelerate_vars:
        cfg.device_map = None


def normalize_config(cfg):
    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
        cfg.batch_size // cfg.micro_batch_size
    )
    cfg.batch_size = (
        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
    )
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.batch_size = cfg.batch_size * cfg.world_size

    if cfg.device == "mps":
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False
    else:
        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False

    if cfg.bf16 or cfg.bfloat16:
        cfg.torch_dtype = torch.bfloat16
    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        cfg.torch_dtype = torch.float16
    else:
        cfg.torch_dtype = torch.float32

    log_gpu_memory_usage(LOG, "baseline", cfg.device)


def validate_config(cfg):
    if cfg.max_packed_sequence_len and cfg.sample_packing:
        raise ValueError(

@@ -61,6 +126,19 @@ def validate_config(cfg):
    if not cfg.load_in_8bit and cfg.adapter == "lora":
        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

    if cfg.relora_steps:
        if cfg.adapter not in ("lora", "qlora"):
            raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")

        if cfg.fsdp:
            raise ValueError("fsdp not supported with ReLoRA")

        if cfg.deepspeed:
            raise ValueError("deepspeed not supported with ReLoRA")

        if cfg.lr_scheduler == "one_cycle":
            raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")

    if cfg.trust_remote_code:
        LOG.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."

@@ -89,7 +167,7 @@ def validate_config(cfg):
                "You should probably set bfloat16 or float16 to true to "
                "load the model in float16 for BetterTransformers"
            )
        if int(torch.__version__.split(".")[0]) < 2:
        if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
            LOG.warning("torch>=2.0.0 required")
            raise ValueError(
                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
@@ -41,9 +41,46 @@ from axolotl.prompters import (
    ShareGPTPrompter,
    SummarizeTLDRPrompter,
)
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


def prepare_dataset(cfg, tokenizer):
    if not cfg.pretraining_dataset:
        with zero_first(is_main_process()):
            train_dataset, eval_dataset = load_prepare_datasets(
                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
            )
    else:
        train_dataset = load_pretraining_dataset(
            cfg.pretraining_dataset,
            tokenizer,
            max_tokens=cfg.sequence_len,
            seed=cfg.seed or 42,
        )
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None

    with zero_first(is_main_process()):
        train_dataset, eval_dataset = process_datasets_for_packing(
            cfg, train_dataset, eval_dataset
        )
    if cfg.max_steps:
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
    return train_dataset, eval_dataset, total_num_steps


def load_tokenized_prepared_datasets(

@@ -125,8 +162,15 @@ def load_tokenized_prepared_datasets(
                    split=None,
                )
            elif local_path.is_file():
                ds_type = "json"
                if d.ds_type:
                    ds_type = d.ds_type
                elif ".parquet" in d.path:
                    ds_type = "parquet"
                elif ".arrow" in d.path:
                    ds_type = "arrow"
                ds = load_dataset(
                    "json",
                    ds_type,
                    name=d.name,
                    data_files=d.path,
                    streaming=False,

@@ -163,13 +207,27 @@ def load_tokenized_prepared_datasets(
                )
            else:
                ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)

        d_base_type = d_prompt_style = None
        d_type = d.type
        d_type_split = d_type.split(":")
        d_base_type = d_type_split[0]
        d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
        if isinstance(d_type, str):
            d_type_split = d_type.split(":")
            d_base_type = d_type_split[0]
            d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
        if "train" in ds:
            ds = ds["train"]
        if ds_strategy := load(d.type, tokenizer, cfg):
        if (
            "input_ids" in ds.features
            and "attention_mask" in ds.features
            and "labels" in ds.features
        ):
            # dataset is already tokenized, just drop it straight in
            datasets.append(ds)
        elif isinstance(d.type, DictDefault):
            ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif ds_strategy := load(d.type, tokenizer, cfg, d):
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif d_base_type == "alpaca":

@@ -440,7 +498,7 @@ def load_prepare_datasets(
            to_hash_test.encode(), usedforsecurity=False
        ).hexdigest()

        if is_main_process():
        with zero_first(is_main_process()):
            dataset = dataset.train_test_split(
                test_size=cfg.val_set_size,
                shuffle=False,

@@ -448,16 +506,6 @@ def load_prepare_datasets(
                train_new_fingerprint=train_fingerprint,
                test_new_fingerprint=test_fingerprint,
            )
        barrier()
        if not is_main_process():
            dataset = dataset.train_test_split(
                test_size=cfg.val_set_size,
                shuffle=False,
                seed=cfg.seed or 42,
                train_new_fingerprint=train_fingerprint,
                test_new_fingerprint=test_fingerprint,
            )
        barrier()

        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
@@ -243,6 +243,18 @@ class MultipackDistributedDataloader:
                len_remaining -= 1
                if not len_remaining:
                    return
        # yield a no-op for cases where we don't have any data left to pack
        for i in range(0, len_remaining):
            yield self.collate_fn(
                [
                    {
                        "input_ids": [0],
                        "labels": [-100],
                        "attention_mask": [True],
                        "position_ids": [0],
                    }
                ]
            )

    def _len_est(self):
        lengths_sum = np.sum(self.lengths)

@@ -10,3 +10,6 @@ class DictDefault(Dict):

    def __missing__(self, key):
        return None

    def __or__(self, other):
        return DictDefault(super().__or__(other))
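Editor's note — an example of what the new __or__ enables (a sketch; assumes the addict-style attribute access DictDefault inherits, and Python 3.9+ dict-union semantics): merging two configs with | keeps the DictDefault type, so missing keys still resolve to None.

base = DictDefault({"lr": 3e-4, "adapter": "lora"})
override = DictDefault({"lr": 1e-4})
merged = base | override
assert merged.lr == 1e-4 and merged.adapter == "lora"
assert merged.missing_key is None  # __missing__ still applies after the merge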
@@ -1,6 +1,10 @@
"""
utility helpers for distributed checks
"""
import os
from contextlib import contextmanager

import torch
import torch.distributed as dist
from accelerate import Accelerator

@@ -39,3 +43,51 @@ def is_main_process():
    if not is_distributed():
        return True
    return dist.get_rank() == 0


def get_world_size():
    return int(os.getenv("WORLD_SIZE", "1"))


@contextmanager
def zero_first(is_main):
    """
    runs the wrapped context so that rank 0 runs first before other ranks
    """
    if not is_main:  # other ranks wait first
        barrier()
    yield
    if is_main:  # then rank 0 waits after it has run the context
        barrier()


def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on the specified rank.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - rank (int, optional): The rank that gathers the values. Default is 0.
    - world_size (int, optional): Total number of processes in the current distributed setup.

    Returns:
    - A list of computed values from all ranks if on the gathering rank, otherwise None.
    """
    value_scalar = fn()
    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()

    if not is_main_process():
        dist.gather(value_tensor, dst=0)
    else:
        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

        # Convert tensors back to their original type (int or float)
        gathered_values = []
        for tensor in gathered_tensors:
            if tensor == tensor.int():
                gathered_values.append(int(tensor.item()))
            else:
                gathered_values.append(float(tensor.item()))
        return gathered_values
    return None
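Editor's note — the zero_first pattern in one line (a sketch; build_or_load_cached_dataset is a hypothetical helper): rank 0 does the expensive work and writes the cache first, then the remaining ranks enter the context and hit the warm cache instead of duplicating it.

with zero_first(is_main_process()):
    dataset = build_or_load_cached_dataset()  # hypothetical cached-build helper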
@@ -21,7 +21,7 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -32,44 +32,46 @@ if TYPE_CHECKING:
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
def load_tokenizer(cfg):
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
tokenizer_cls = AutoTokenizer
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
]
|
||||
and hasattr(tokenizer, "pad_token")
|
||||
and not tokenizer.pad_token
|
||||
):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
]:
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -92,7 +94,6 @@ def load_model(
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
adapter = cfg.adapter
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
@@ -109,7 +110,7 @@ def load_model(
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
@@ -118,9 +119,7 @@ def load_model(
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
|
||||
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
@@ -144,20 +143,16 @@ def load_model(
|
||||
LOG.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.is_llama_derived_model and (
|
||||
cfg.max_packed_sequence_len or cfg.sample_packing
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||
and not cfg.inference
|
||||
):
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
LOG.info("patching _expand_mask")
|
||||
hijack_expand_mask()
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
try:
|
||||
if cfg.gptq:
|
||||
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||
```diff
@@ -189,7 +184,7 @@ def load_model(
                 load_in_4bit=True,
                 llm_int8_threshold=6.0,
                 llm_int8_has_fp16_weight=False,
-                bnb_4bit_compute_dtype=torch_dtype,
+                bnb_4bit_compute_dtype=cfg.torch_dtype,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
             )
@@ -235,15 +230,20 @@ def load_model(
         elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
             from transformers import LlamaForCausalLM

+            config_kwargs = {}
+            if cfg.rope_scaling:
+                config_kwargs["rope_scaling"] = cfg.rope_scaling
             config = LlamaConfig.from_pretrained(
-                base_model_config, rope_scaling=cfg.rope_scaling
+                base_model_config,
+                **config_kwargs,
             )
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 config=config,
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
```
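The `config_kwargs` change above only forwards `rope_scaling` to `LlamaConfig` when it is actually set, instead of always passing it (possibly as `None`). For reference, transformers (4.31+) accepts a dict of this shape; the values below are illustrative, not taken from this diff:

```python
# Illustrative rope_scaling values as accepted by transformers' LlamaConfig:
linear = {"type": "linear", "factor": 2.0}    # stretch positions uniformly
dynamic = {"type": "dynamic", "factor": 2.0}  # NTK-style rescaling at runtime
```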
```diff
@@ -275,9 +275,10 @@ def load_model(
         elif model_type and not cfg.trust_remote_code:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -305,9 +306,10 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=config,
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -318,9 +320,10 @@ def load_model(
             LOG.exception(err)
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 device_map=cfg.device_map,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                torch_dtype=torch_dtype,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
```
```diff
@@ -345,6 +348,15 @@ def load_model(
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)

+    # make sure these are fp32 per Ramesh et al. (2021)
+    for name, module in model.named_modules():
+        if "norm" in name:
+            module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module.to(torch.float32)
+
+    needs_fa2_dtype = cfg.adapter or cfg.fsdp
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -353,18 +365,20 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
+        needs_fa2_dtype = True

-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module.to(torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
-                if hasattr(module, "weight"):
-                    module.to(torch_dtype)
+    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(cfg.torch_dtype)

-    model, lora_config = load_adapter(model, cfg, adapter)
+    model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
@@ -381,9 +395,6 @@ def load_model(
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()

-    if model.device.type == "cuda":
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
-
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +417,9 @@ def load_model(
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)

+    if cfg.adapter is not None:
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
```
```diff
@@ -436,7 +450,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        LOG.info("Loading pretrained LORA")
+        LOG.debug("Loading pretrained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -498,6 +512,7 @@ def load_lora(model, cfg):
     )

     if cfg.lora_model_dir:
+        LOG.debug("Loading pretrained PEFT - LoRA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
```
```diff
@@ -10,28 +10,30 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

 import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
 import transformers
 from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import (
+    DataLoader,
+    DistributedSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers.trainer_pt_utils import SequentialDistributedSampler

+from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
-    PrintGPUStatsCallback,
+    GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
+    bench_eval_callback_factory,
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.schedulers import (
-    InterpolatingLogScheduler,
-    get_cosine_schedule_with_quadratic_warmup,
-)
+from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 LOG = logging.getLogger("axolotl")
```
```diff
@@ -124,6 +126,35 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1,
         metadata={"help": "the multiplier for the max len for packed sequences"},
     )
+    relora_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for ReLoRA"},
+    )
+    relora_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
+    )
+    bench_dataset: Optional[str] = field(
+        default="pharaouk/dharma-1/dharma_1_mini.json",
+        metadata={
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+        },
+    )
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+    )
+    max_bench_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+        },
+    )
+    bench_source_max_len: int = field(
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
+    )


 class AxolotlTrainer(Trainer):
```
```diff
@@ -133,6 +164,10 @@ class AxolotlTrainer(Trainer):

     args = None  # type: AxolotlTrainingArguments

+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
```
```diff
@@ -171,6 +206,18 @@ class AxolotlTrainer(Trainer):
             )
         return super()._get_train_sampler()

+    def _get_eval_sampler(
+        self, eval_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                batch_size=self.args.per_device_eval_batch_size,
+            )
+        return super()._get_eval_sampler(eval_dataset)
+
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
```
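The new `_get_eval_sampler` falls back to `SequentialDistributedSampler` for multi-GPU packed evaluation: each rank receives one contiguous, unshuffled shard, which keeps packed batches deterministic across processes. A toy illustration of that contiguous split (the real sampler also pads the dataset so the shards divide evenly):

```python
# Toy version of a sequential distributed split: 10 samples on 2 ranks.
indices = list(range(10))
world_size = 2
shard = len(indices) // world_size
rank_shards = [indices[r * shard : (r + 1) * shard] for r in range(world_size)]
print(rank_shards)  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```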
```diff
@@ -195,6 +242,7 @@ class AxolotlTrainer(Trainer):
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )

+            eval_sampler = self._get_eval_sampler(eval_dataset)
             return self.accelerator.prepare(
                 MultipackDistributedDataloader(
```
```diff
@@ -210,6 +258,31 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_eval_dataloader(eval_dataset)

+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))

     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
```
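A sketch of how the `get_bench_dataloader` added above might be driven; `bench_ds` is a placeholder for a tokenized benchmark dataset (the diff wires the real one up through `bench_eval_callback_factory` later), and device handling is omitted:

```python
# Hypothetical driver for the bench dataloader; bench_ds is a stand-in.
bench_dataloader = trainer.get_bench_dataloader(bench_ds)
for batch in bench_dataloader:
    # Batches are padded by bench_data_collator (a DataCollatorForSeq2Seq).
    outputs = trainer.model(**batch)
```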
```diff
@@ -249,6 +322,39 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler


+class ReLoRATrainer(AxolotlTrainer):
+    """
+    Trainer subclass that uses the ReLoRA scheduler
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr_scheduler = None
+
+    def create_scheduler(
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+    ):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+
+        if self.args.relora_steps:
+            warmup_steps = (
+                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
+            )
+            self.lr_scheduler = ReLoRAScheduler(
+                optimizer,
+                lr_scheduler,
+                self.args.relora_steps,
+                warmup_steps,
+            )
+        else:
+            self.lr_scheduler = lr_scheduler
+
+        return self.lr_scheduler
+
+
 def add_position_ids(sample):
     sample["position_ids"] = torch.arange(len(sample["input_ids"]))
     return sample
```
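`ReLoRATrainer` swaps in a `ReLoRAScheduler` that wraps the base LR schedule and, per the field help-texts above, resets every `relora_steps` with a fresh warmup of `relora_warmup_steps` (default 10). A rough sketch of that wrap-and-rewarm idea; the real implementation lives in `axolotl.monkeypatch.relora` and may differ in detail:

```python
class ResettingWarmupScheduler:
    """Sketch only: re-apply a short linear warmup after every reset point."""

    def __init__(self, inner, reset_every: int, warmup_steps: int):
        self.inner = inner              # the wrapped scheduler, e.g. cosine
        self.reset_every = reset_every  # corresponds to relora_steps
        self.warmup_steps = warmup_steps
        self.step_count = 0

    def step(self) -> None:
        self.inner.step()
        self.step_count += 1

    def warmup_scale(self) -> float:
        # Linear 0 -> 1 ramp over warmup_steps after each reset boundary.
        since_reset = self.step_count % self.reset_every
        return min(1.0, since_reset / max(1, self.warmup_steps))
```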
```diff
@@ -268,15 +374,15 @@ def disable_datasets_caching():


 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
+    if eval_dataset:
+        eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
+
     if cfg.sample_packing:
-        drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
-        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
-            add_position_ids, num_proc=os.cpu_count()
-        )
+        train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
         if eval_dataset:
-            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
-                add_position_ids, num_proc=os.cpu_count()
-            )
+            eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
     return train_dataset, eval_dataset
```
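`add_position_ids` (context above) is what `process_datasets_for_packing` now maps over packed datasets: every sample gets its own zero-based position ramp, so concatenated sequences do not inherit a neighbor's positions. A quick check:

```python
import torch

sample = {"input_ids": torch.tensor([42, 7, 99, 3])}
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
print(sample["position_ids"])  # tensor([0, 1, 2, 3])
```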
```diff
@@ -355,15 +461,24 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer)

+def setup_fsdp_envs(cfg):
+    os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_offload_params:
+        os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
+    if cfg.fsdp_config.fsdp_sync_module_states:
+        os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
+    if cfg.fsdp_config.fsdp_state_dict_type:
+        os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
+    if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
+        os.environ[
+            "FSDP_TRANSFORMER_CLS_TO_WRAP"
+        ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
+    if cfg.fsdp:
+        setup_fsdp_envs(cfg)
+    elif cfg.deepspeed:
+        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+
     warmup_steps = (
         cfg.warmup_steps
         if cfg.warmup_steps is not None
@@ -411,21 +526,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)

+    # deepspeed
+    if cfg.deepspeed:
+        training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+
     if cfg.lr_quadratic_warmup is not None:
         training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup

-    # deepspeed
-    if (
-        os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
-        and torch.cuda.device_count() > 1
-    ):
-        if cfg.deepspeed:
-            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
-        else:
-            # make a guess here
-            # TODO search Path("./") for one
-            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
-
     if cfg.adam_beta1:
         training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
     if cfg.adam_beta2:
```
```diff
@@ -440,6 +547,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         training_arguments_kwargs["push_to_hub"] = True
         training_arguments_kwargs["hub_private_repo"] = True

+        if cfg.hub_strategy:
+            training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
+
     if cfg.save_safetensors:
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

@@ -448,8 +558,29 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             "sample_packing_efficiency"
         ] = cfg.sample_packing_eff_est

+    if cfg.val_set_size == 0:
+        training_arguments_kwargs["evaluation_strategy"] = "no"
+    elif cfg.eval_steps:
+        training_arguments_kwargs["evaluation_strategy"] = "steps"
+        training_arguments_kwargs["eval_steps"] = cfg.eval_steps
+    else:
+        # we have an eval set, but no steps defined, use epoch
+        training_arguments_kwargs["evaluation_strategy"] = "epoch"
+
+    if cfg.save_strategy:
+        training_arguments_kwargs["save_strategy"] = cfg.save_strategy
+    else:
+        training_arguments_kwargs["save_strategy"] = (
+            "steps" if cfg.save_steps else "epoch"
+        )
+
+    if cfg.do_bench_eval:
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
+        if cfg.bench_dataset:
+            training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
```
```diff
@@ -459,9 +590,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
-        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps" if cfg.save_steps else "epoch",
-        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
```
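Taken together, the last two hunks move the evaluation/save strategy choices out of the `AxolotlTrainingArguments` constructor call and into `training_arguments_kwargs`, where the extra `eval_steps`-without-epoch case can be expressed. As an illustration (values invented, not output from the code), a config with `val_set_size > 0`, `eval_steps = 50`, and neither `save_strategy` nor `save_steps` set now resolves to:

```python
# Illustrative outcome of the new branching for the config described above:
training_arguments_kwargs = {
    "evaluation_strategy": "steps",
    "eval_steps": 50,
    "save_strategy": "epoch",
}
```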
```diff
@@ -484,6 +612,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
         sample_packing_seq_len_multiplier=cfg.micro_batch_size,
+        relora_steps=cfg.relora_steps,
+        relora_warmup_steps=cfg.relora_warmup_steps,
         **training_arguments_kwargs,
     )
```
```diff
@@ -493,69 +623,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         if Path(cfg.torchdistx_path).exists():
             sys.path.append(cfg.torchdistx_path)
             importlib.import_module("torchdistx")
-    if (
-        cfg.optimizer == "adamw_bnb_8bit"
-        and not cfg.gptq
-        and "deepspeed" not in training_arguments_kwargs
-        and not cfg.fsdp
-    ):
-        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-        decay_parameters = [name for name in decay_parameters if "bias" not in name]
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if (n in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": training_args.weight_decay,
-            },
-            {
-                "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if (n not in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
-
-        optimizer = bnb.optim.Adam8bit(
-            optimizer_grouped_parameters,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            eps=training_args.adam_epsilon,
-            lr=training_args.learning_rate,
-        )
-
-        if cfg.lr_scheduler == "one_cycle":
-            lr_scheduler_kwargs = (
-                cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
-            )
-            lr_scheduler = OneCycleLR(
-                optimizer,
-                cfg.learning_rate,
-                total_steps=total_num_steps,
-                epochs=cfg.num_epochs,
-                div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
-                **lr_scheduler_kwargs,
-            )
-        elif cfg.lr_scheduler == "log_sweep":
-            lr_scheduler = InterpolatingLogScheduler(
-                optimizer,
-                cfg.warmup_steps,
-                cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
-                cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
-            )
-        else:
-            lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-                optimizer,
-                training_args.warmup_steps,
-                total_num_steps,
-            )
-        trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)

     callbacks = []
-    callbacks.append(PrintGPUStatsCallback(cfg))
+    callbacks.append(GPUStatsCallback(cfg))
+
+    if cfg.relora_steps:
+        callbacks.append(ReLoRACallback(cfg))
+
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
@@ -600,11 +674,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             num_proc=32,
         )

-    trainer_cls = (
-        OneCycleLRSchedulerTrainer
-        if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else AxolotlTrainer
-    )
+    trainer_cls = AxolotlTrainer
+    if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
+        trainer_cls = OneCycleLRSchedulerTrainer
+    elif cfg.relora_steps:
+        trainer_cls = ReLoRATrainer
     trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
@@ -615,8 +689,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )

+    if cfg.do_bench_eval:
+        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
+
     return trainer
```
```diff
@@ -72,6 +72,13 @@ class DictDefaultTest(unittest.TestCase):

         assert cfg.random_key is None, "DictDefault should return None for missing keys"

+    def test_dict_or(self):
+        cfg = DictDefault({}) | DictDefault({})
+
+        assert (
+            cfg.random_key is None
+        ), "DictDefault should return None for missing keys after | operation"
+
     def test_dict_nested_missingparentkey(self):
         """
         Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.
```
```diff
@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
     """

     def test_default_use_fast(self):
-        cfg = DictDefault({})
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" in tokenizer.__class__.__name__

     def test_dont_use_fast(self):
         cfg = DictDefault(
             {
                 "tokenizer_config": "huggyllama/llama-7b",
                 "tokenizer_use_fast": False,
             }
         )
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
```
```diff
@@ -6,8 +6,8 @@ from typing import Optional

 import pytest

+from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.validation import validate_config


 class ValidationTest(unittest.TestCase):
```