Compare commits

..

56 Commits

Author SHA1 Message Date
Wing Lian
64af21bcb2 set env vars trainer needs for FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-11 08:46:26 -04:00
Wing Lian
6b5cf8b5ea optimize length reducer from 9m -> <5sec 2023-08-11 08:30:30 -04:00
Wing Lian
79500f358a need to pass total num tokens to trainer too 2023-08-10 19:08:23 -04:00
Wing Lian
7e977a9b68 optimization if total_num_tokens is already known 2023-08-10 19:02:28 -04:00
Wing Lian
ac4b700daa optimization if total_num_tokens is already known 2023-08-10 19:01:17 -04:00
Wing Lian
2565c2f259 async batching for multipack 2023-08-10 18:28:15 -04:00
Wing Lian
a07f432d9c calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier 2023-08-10 17:16:01 -04:00
Wing Lian
57d9bf711c let's not cleanup the cached datasets 2023-08-08 21:27:55 -04:00
Wing Lian
26983a1974 fix sampler to prevent overfit w new epochs 2023-08-08 15:34:18 -04:00
Wing Lian
1b8747e319 use custom distributed checks 2023-08-08 13:35:04 -04:00
Wing Lian
035b3c760c add numba to requirements. 2023-08-08 10:55:29 -04:00
Wing Lian
17abbd59e1 previous accelerate is still most performant 2023-08-08 09:46:01 -04:00
Wing Lian
6ec76ddb4c fix steps calculation 2023-08-08 05:13:21 -04:00
Wing Lian
21d307b15b fix counts by accounting for num devices 2023-08-08 04:13:10 -04:00
Wing Lian
58e9dee204 fixes and go back to distributed sampler since batch sampler won't work 2023-08-08 03:49:29 -04:00
Wing Lian
4f7c04bae0 more fixes and optimizations 2023-08-08 03:16:00 -04:00
Wing Lian
1162b93b6b filter w multiple cpus 2023-08-08 00:50:56 -04:00
Wing Lian
21f445d763 more packing and dataset optimizations and fixes 2023-08-08 00:45:24 -04:00
Wing Lian
229b9165aa fix test and pylint checks 2023-08-07 09:38:05 -04:00
Wing Lian
394a65f11f add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test 2023-08-07 09:38:04 -04:00
Wing Lian
c70dae63cc add chatml 2023-08-07 09:38:04 -04:00
Wing Lian
7712955b35 fix chatml system prompt for openorca, legacy tokenizer opts 2023-08-07 09:38:04 -04:00
Wing Lian
f93f0017cd fix flash-attn, xformers, packing, support chatml 2023-08-07 09:38:04 -04:00
Wing Lian
0b01da0713 properly calculate max len 2023-08-07 09:38:04 -04:00
Wing Lian
b2f7bc7ccd use cumulative seq len with var len flash attn v2 w packing 2023-08-07 09:38:04 -04:00
Wing Lian
b8905e2a91 sample_packing_seq_len_multiplier config 2023-08-07 09:38:04 -04:00
Wing Lian
7e1edc662a make sure the chunk size is an int 2023-08-07 09:38:04 -04:00
Wing Lian
98c9bc69de seq_len_multiple for packing 2023-08-07 09:38:04 -04:00
Wing Lian
8378335dc9 limit packing to sequences of max seq len 2023-08-07 09:38:04 -04:00
Wing Lian
bdd34c7400 weighted CEL fixes 2023-08-07 09:38:04 -04:00
Wing Lian
c6cc54c7d9 weighted CE losses 2023-08-07 09:38:04 -04:00
Wing Lian
83f7362480 don't split batches when packing 2023-08-07 09:38:04 -04:00
Wing Lian
958d423e7c only process eval dataset for packing if not None 2023-08-07 09:38:04 -04:00
Wing Lian
e74eab6e73 add a test for the mask expansion for sequence packing 2023-08-07 09:38:04 -04:00
Wing Lian
487abfc769 pass sample packing efficiency to training args 2023-08-07 09:38:04 -04:00
Wing Lian
2bee646e85 fix step calc for packing 2023-08-07 09:38:04 -04:00
Wing Lian
945f2e5029 better handling so that all devices have the same dataloader len 2023-08-07 09:38:04 -04:00
Wing Lian
daed942fe9 fix rounding of len of batches to int 2023-08-07 09:38:04 -04:00
Wing Lian
df3eb645da better handling of variance in multipack dataloader length and trainer hanging when it runs out of data 2023-08-07 09:38:04 -04:00
Wing Lian
32fed7039d optimized expand mask fn 2023-08-07 09:38:04 -04:00
Wing Lian
7d7b5ebd71 more fixes for 4k and optimizations 2023-08-07 09:38:03 -04:00
Wing Lian
4b7ad9927f validation for sample packing and doc 2023-08-07 09:38:03 -04:00
Wing Lian
fedcf5a089 Update src/axolotl/utils/dataloader.py 2023-08-07 09:38:03 -04:00
Wing Lian
2f2974196d fix for position_ids w packing 2023-08-07 09:38:03 -04:00
Wing Lian
2e295c9f94 use accelerator prepare for dataloader 2023-08-07 09:38:03 -04:00
Wing Lian
4ab9ab79fd use distributed sampler, avoid accelerate prepare 2023-08-07 09:38:03 -04:00
Wing Lian
b02484a83e more fixes for sample packing 2023-08-07 09:38:03 -04:00
Wing Lian
58045f0816 more fixes, position_ids seems broken 2023-08-07 09:38:03 -04:00
Wing Lian
66774011c4 est total tokens, fix field loop 2023-08-07 09:38:03 -04:00
Wing Lian
41d4992029 more fixes for dataloader integration 2023-08-07 09:38:03 -04:00
Wing Lian
762f1b08db add position_ids back 2023-08-07 09:38:03 -04:00
Wing Lian
3aba4c5d7c use multi pack dataloader w random sampler 2023-08-07 09:38:03 -04:00
Wing Lian
ffd96839cf don't move masks to cpu 2023-08-07 09:38:03 -04:00
Wing Lian
ef9bf7ad73 fix expand mask for multiple batch items, make sure we pad position_ids 2023-08-07 09:38:03 -04:00
Wing Lian
4964b0d345 set position ids and use block diagonal attn mask 2023-08-07 09:38:03 -04:00
Wing Lian
36b0e30a9d fix attetion mask with packing 2023-08-07 09:38:03 -04:00
79 changed files with 1019 additions and 3935 deletions

View File

@@ -1,129 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement on Discord
at https://discord.gg/QYF8QrtEUm
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@@ -1,76 +0,0 @@
# Contributing to axolotl
First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.
## Table of Contents
- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [How to Contribute](#how-to-contribute)
- [Reporting Bugs](#reporting-bugs)
- [Suggesting Enhancements](#suggesting-enhancements)
- [Submitting Pull Requests](#submitting-pull-requests)
- [Style Guidelines](#style-guidelines)
- [Code Style](#code-style)
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conductcode
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.
## Getting Started
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!
1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
```
## How to Contribute
### Reporting Bugs
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
### Suggesting Enhancements
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
### Submitting Pull Requests
1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.
2. Make your changes, following the [Style Guidelines](#style-guidelines) below.
3. Test your changes and ensure that they don't introduce new issues or break existing functionality.
4. Commit your changes, following the [commit message guidelines](#commit-messages).
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
## Style Guidelines
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".
## Additional Resources
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

13
.github/FUNDING.yml vendored
View File

@@ -1,13 +0,0 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -1,105 +0,0 @@
name: Bug Report
description: File a bug report
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
## Before you start
Please **make sure you are on the latest version.**
If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "Please check that this issue hasn't been reported before."
description: "The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
required: true
- type: textarea
id: expected
attributes:
label: Expected Behavior
description: Tell us what **should** happen.
validations:
required: true
- type: textarea
id: what-happened
attributes:
label: Current behaviour
description: |
Tell us what happens instead of the expected behavior.
Provide stacktrace and/or screenshots.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: |
Which exact steps can a developer take to reproduce the issue?
The more detail you provide, the easier it will be to narrow down and fix the bug.
Please paste in tasks and/or queries **as text, not screenshots**.
placeholder: |
Example of the level of detail needed to reproduce any bugs efficiently and reliably.
1. Go to the '...' page.
2. Click on the '...' button.
3. Scroll down to '...'.
4. Observe the error.
validations:
required: true
- type: textarea
id: possible-solution
attributes:
label: Possible solution
description: |
Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.
- type: checkboxes
id: operating-systems
attributes:
label: Which Operating Systems are you using?
description: You may select more than one.
options:
- label: Linux
- label: macOS
- label: Windows
- type: input
id: Python-version
attributes:
label: Python Version
description: Which {Programming} version are you using?
placeholder: 3.10 / please change accordingly
validations:
required: true
- type: input
id: axolotl-branch-commit
attributes:
label: axolotl branch-commit
description: On which branch/commit are you?
placeholder: main/4d6490b
validations:
required: true
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this bug has not been reported yet.'
required: true
- label: 'I am using the latest version of axolotl.'
required: true
- label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'
required: true

View File

@@ -1,7 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
about: Ask questions and discuss with other community members
- name: Discuss the Project in Discord
url: https://discord.gg/HhrNrHJPRb

View File

@@ -1,46 +0,0 @@
name: Documentation Improvement / Clarity
description: Make a suggestion to improve the project documentation.
labels: ['needs triage', 'docs']
body:
- type: markdown
attributes:
value: '## :book: Documentation :book:'
- type: markdown
attributes:
value: |
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
- type: textarea
attributes:
label: What piece of documentation is affected?
description: Please link to the article you'd like to see updated.
validations:
required: true
- type: textarea
attributes:
label: What part(s) of the article would you like to see updated?
description: |
- Give as much detail as you can to help us understand the change you want to see.
- Why should the docs be changed? What use cases does it support?
- What is the expected outcome?
validations:
required: true
- type: textarea
attributes:
label: Additional Information
description: Add any other context or screenshots about the feature request here.
validations:
required: false
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -1,63 +0,0 @@
name: Feature Request / Enhancement
description: Suggest a new feature or feature enhancement for the project
labels: ["enhancement", "needs triage"]
body:
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "⚠️ Please check that this feature request hasn't been suggested before."
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
required: true
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
required: true
- type: textarea
id: feature-description
validations:
required: true
attributes:
label: "🔖 Feature description"
description: "A clear and concise description of what the feature request is."
placeholder: "You should add ..."
- type: textarea
id: solution
validations:
required: true
attributes:
label: "✔️ Solution"
description: "A clear and concise description of what you want to happen, and why."
placeholder: "In my use-case, ..."
- type: textarea
id: alternatives
validations:
required: false
attributes:
label: "❓ Alternatives"
description: "A clear and concise description of any alternative solutions or features you've considered."
placeholder: "I have considered ..."
- type: textarea
id: additional-context
validations:
required: false
attributes:
label: "📝 Additional Context"
description: "Add any other context or screenshots about the feature request here."
placeholder: "..."
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -1,22 +0,0 @@
<!--- Provide a general summary of your changes in the Title above -->
# Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How has this been tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## Screenshots (if appropriate)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->

9
.github/SECURITY.md vendored
View File

@@ -1,9 +0,0 @@
# Security Policy
## Supported Versions
Due to the nature of the fast development that is happening in this project, only the latest released version can be supported.
## Reporting a Vulnerability
If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue to allow us some time to fix it before it is a known vulnerability to others.

10
.github/SUPPORT.md vendored
View File

@@ -1,10 +0,0 @@
# Support
If you need help with this project or have questions, please:
1. Check the documentation.
2. Search the existing issues and pull requests.
3. Create a new issue if your question is not answered or your problem is not solved.
4. Have a look in the [Discord server](https://discord.gg/HhrNrHJPRb)
Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.

View File

@@ -4,6 +4,7 @@ on:
push:
branches:
- "main"
- "dev"
jobs:
build-axolotl:
@@ -13,17 +14,17 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
@@ -49,11 +50,10 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-runpod:
needs: build-axolotl
@@ -72,7 +72,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
@@ -103,7 +102,5 @@ jobs:
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -24,7 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .[peft]
pip install -e .
pip install -r requirements-tests.txt
- name: Run tests

246
README.md
View File

@@ -1,40 +1,10 @@
# Axolotl
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
<table>
<tr>
<td>
## Table of Contents
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Docker Installation](#environment)
- [Conda/Pip venv Installation](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)
</td>
<td>
<div align="center">
<img src="image/axolotl.png" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
<b>One repo to finetune them all! </b>
</p>
<p>
Go ahead and axolotl questions!!
@@ -44,34 +14,27 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
</div>
</div>
</td>
</tr>
</table>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ❓ | ✅
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e .[flash-attn]
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora
@@ -100,7 +63,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```
- Conda/Pip venv
1. Install python >=**3.9**
1. Install python **3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -153,7 +116,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
pip3 install -e . # change depend on needs
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install -U requests
pip3 install -U --ignore-installed psutil
pip3 install -U scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
```
@@ -163,18 +128,15 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```
</details>
- Windows: Please use WSL or Docker!
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
- `sharegpt:chat`: conversations
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -259,18 +221,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -280,29 +234,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts
Using yaml. Example:
```yaml
datasets:
- path: repo
type:
system_prompt: ""
no_input_format: |-
User: {instruction}<|end_of_turn|>
Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
```
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Optionally, download some datasets, see [data/README.md](data/README.md)
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
### Config
@@ -330,20 +266,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
name: enron_emails
type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local
datasets:
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
- path: json
data_files: data.jsonl # or json
type: alpaca # format from earlier
```
- loading
@@ -395,8 +322,6 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -418,39 +343,16 @@ fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with
datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
data_files: # path to source data files
shards: # number of shards to split data into
name: # name of dataset configuration to load
# custom user prompt
- path: repo
type:
# the below are defaults. only set what's needed.
system_prompt: ""
field_system: system
field_instruction: instruction
field_output: input
# customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input}
format: |-
User: {instruction} {input}
Assistant:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -458,13 +360,10 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
@@ -474,19 +373,12 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
# soon to be DEPRECATED
max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# use efficient multi-packing with block diagonal attention and per sequence position_ids
sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
@@ -511,19 +403,12 @@ lora_modules_to_save:
lora_out_dir:
lora_fan_in_fan_out: false
# ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # number of steps per ReLoRA restart
relora_warmup_steps: # number of per-restart warmup steps
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name
wandb_entity: # a wandb Team name if using a Team
wandb_mode:
wandb_project:
wandb_watch:
wandb_run_id: # set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
wandb_run_id:
wandb_log_model: # 'checkpoint'
# where to save the finished model to
output_dir: ./completed-model
@@ -535,22 +420,16 @@ eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
save_strategy: # set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch
eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:
save_steps:
eval_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# group similarly sized data to minimize padding
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -586,8 +465,8 @@ max_grad_norm:
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
@@ -596,10 +475,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir
resume_from_checkpoint:
@@ -622,19 +497,14 @@ tokens:
fsdp:
fsdp_config:
# Deepspeed config path
# Deepspeed
deepspeed:
# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Debug mode
debug:
@@ -652,17 +522,10 @@ strict:
Run
```bash
accelerate launch scripts/finetune.py your_config.yml
accelerate launch scripts/finetune.py configs/your_config.yml
```
#### Multi-GPU
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
#### Multi-GPU Config
- llama FSDP
```yaml
@@ -675,22 +538,7 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
- llama Deepspeed
```yaml
deepspeed: deepspeed/zero3.json
```
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
### Inference
@@ -726,7 +574,7 @@ CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
## Common Errors 🧰
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
> Cuda out of memory
Please reduce any below
- `micro_batch_size`
@@ -734,12 +582,6 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
> RuntimeError: expected scalar type Float but found Half
Try set `fp16: true`
@@ -768,8 +610,6 @@ Building something cool with Axolotl? Consider adding a badge to your model card
## Community Showcase
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
@@ -780,9 +620,7 @@ PocketDoc Labs
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!

24
data/README.md Normal file
View File

@@ -0,0 +1,24 @@
## Download some datasets
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
```
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---
Using JSONL makes it easier to subset the data if you want a smaller training set, i.e get 2000 random examples.
```shell
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```

1
data/raw/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
**

View File

@@ -1,46 +0,0 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -35,7 +35,10 @@
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": "auto"
}

View File

@@ -11,14 +11,14 @@ RUN apt-get update && \
WORKDIR /workspace
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
pip install -e .[$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
pip install -e .; \
fi
# fix so that git fetch/pull from remote works

View File

@@ -31,6 +31,26 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder
WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
git checkout v2.0.1 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
cd ../xentropy && \
python3 setup.py bdist_wheel && \
cd ../rotary && \
python3 setup.py bdist_wheel && \
cd ../layer_norm && \
python3 setup.py bdist_wheel
FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
@@ -70,8 +90,13 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
RUN pip3 install wheels/deepspeed-*.whl
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -36,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -1,67 +0,0 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,69 +0,0 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,67 +0,0 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,69 +0,0 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,67 +0,0 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,69 +0,0 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,22 +0,0 @@
# Overview
This is an example of CodeLLaMA configuration for 7b, 13b and 34b.
The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
The 13b variant will fit if you change these settings to these values:
gradient_accumulation_steps: 2
micro_batch_size: 1
The 34b variant does not fit on 24GB of VRAM - you will need something with +40 gb VRAM that also supports flash attention v2 - A6000 or A100 are good choices.
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
```

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -38,7 +38,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -33,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -22,7 +22,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -57,7 +56,7 @@ weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "<pad>"
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -18,7 +18,6 @@ lora_dropout:
lora_target_modules:
lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -2,7 +2,6 @@ base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
adapter: lora
lora_model_dir:
@@ -27,7 +26,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -40,7 +38,7 @@ lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: false
@@ -50,8 +48,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -65,3 +63,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -2,7 +2,6 @@ base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -19,8 +18,7 @@ adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
@@ -29,7 +27,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -42,7 +39,7 @@ lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: false
@@ -52,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -67,3 +64,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -1,73 +0,0 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
relora_steps: 150
relora_warmup_steps: 10
relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -28,7 +28,6 @@ lora_target_modules:
- o_proj
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -35,7 +34,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -47,3 +46,4 @@ local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -17,7 +17,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -21,7 +21,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- mlp_down
lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -37,7 +37,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -1,28 +1,23 @@
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
evaluate
fire
PyYAML>=6.0
PyYAML==6.0
datasets
flash-attn>=2.0.8
accelerate>=0.19.0
sentencepiece
wandb
einops
xformers
optimum
hf_transfer
colorama
numba
numpy>=1.24.4
numpy==1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art

View File

@@ -0,0 +1,52 @@
"""Module to convert json file to jsonl"""
import os
import sys
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
from axolotl.logging_config import configure_logging
configure_logging()
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
def main(
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
writer = FileWriter(output)
json_parser = JsonParser()
jsonl_serializer = JsonlSerializer()
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(file, output)
if __name__ == "__main__":
fire.Fire(main)

View File

@@ -4,28 +4,31 @@ import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import transformers
import yaml
# add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model_config, load_tokenizer
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -35,18 +38,30 @@ sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if is_main_process():
print(ascii_art)
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
@@ -58,45 +73,7 @@ def get_multi_line_input() -> Optional[str]:
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
@@ -118,8 +95,6 @@ def do_inference(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
@@ -171,10 +146,6 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
@@ -197,7 +168,11 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
):
if Path(config).is_dir():
config = choose_config(config)
@@ -216,72 +191,186 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
normalize_config(cfg)
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg)
return cfg
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
LOG.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cli_args.debug or cfg.debug:
if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
if prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
LOG.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
LOG.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
if __name__ == "__main__":
fire.Fire(do_cli)
fire.Fire(train)

View File

@@ -7,7 +7,6 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if "flash-attn" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
@@ -26,14 +25,9 @@ setup(
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"flash-attn": [
"flash-attn==2.0.8",
],
"extras": [
"flash-attn",
"deepspeed",
],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
},
)

View File

@@ -1,43 +0,0 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
return model, tokenizer

View File

@@ -5,7 +5,7 @@ import os
from typing import List
import torch
from datasets import Dataset, IterableDataset
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset):
class TokenizedPromptDataset(IterableDataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Iterable dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
super().__init__(self.process(dataset).data, **kwargs)
self.dataset = dataset
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
def __iter__(self):
features = self.dataset.features.keys()
num_proc = os.cpu_count()
return iter(
self.dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
)
)

View File

@@ -1,42 +1,16 @@
"""
Common logging module for axolotl
"""
"""Logging configuration settings"""
import os
import sys
from logging import Formatter
from logging.config import dictConfig
from typing import Any, Dict
from colorama import Fore, Style, init
class ColorfulFormatter(Formatter):
"""
Formatter to add coloring to log messages by log type
"""
COLORS = {
"WARNING": Fore.YELLOW,
"ERROR": Fore.RED,
"CRITICAL": Fore.RED + Style.BRIGHT,
}
def format(self, record):
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},
"handlers": {
@@ -46,25 +20,14 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"filters": [],
"stream": sys.stdout,
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
"axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
},
}
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)

View File

@@ -2,47 +2,110 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -58,541 +121,8 @@ def _prepare_decoder_attention_mask(
return attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -1,140 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
"""
import logging
import warnings
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from torch import nn
try:
import xformers.ops
@@ -21,6 +21,12 @@ def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def xformers_forward(
self,
hidden_states: torch.Tensor,
@@ -75,15 +81,15 @@ def xformers_forward(
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
@@ -96,50 +102,74 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = transformers.models.llama.modeling_llama.repeat_kv(
key_states, self.num_key_value_groups
)
value_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states, self.num_key_value_groups
)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
# end x-formers vs. not x-formers if-else block
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# xformers-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
@@ -152,4 +182,103 @@ def xformers_forward(
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

View File

@@ -1,393 +0,0 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
self.resume_from_checkpoint = cfg.resume_from_checkpoint
if not os.path.exists(self.last_full_model):
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_train_begin(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.resume_from_checkpoint:
weight_path = os.path.join(self.resume_from_checkpoint, "relora")
if not os.path.exists(weight_path):
LOG.warning(
"Resuming ReLoRA from checkpoint, but no full-weight save found"
)
else:
LOG.info(f"Loading adjusted base weights from {weight_path}")
load_weight_checkpoint(model, weight_path)
return control
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"relora",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(optimizer)
if self.quantized:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantized:
if is_main_process() and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
os.makedirs(checkpoint_folder, exist_ok=True)
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
try:
os.symlink(new_path, path)
except OSError:
# probably on windows without permission to symlink
pass
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
def on_train_end(
self,
args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.quantized:
# perform final merge and save
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
args.output_dir,
reinit=False,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
# no need to save if unquantized, as finetune.py will call merge_and_unload()
return control
class ReLoRAScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
with open(index_path, "r", encoding="utf-8") as file:
data = json.load(file)
return data["weight_map"]
return {(module_name + ".weight"): model_name for module_name in module_names}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@ layer.lora_A[adapter].weight.detach().to(device),
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
for key in key_list:
try:
# pylint: disable=protected-access
_parent, target, _target_name = peft.utils._get_submodules(model.model, key)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
modules[key] = target
return modules
def update_weights(
target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
# when the module is moved from cpu to gpu. Without meddling *too* deeply in
# PEFT's innards or maintaining a duplicate of that codepath, this is good
# enough for now.
target.weight.quant_state = None
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
else:
target.weight.data = new_weight.to(device)
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
cpu_offload: bool = False,
actually_save: bool = True,
):
modules = find_lora_modules(model)
if not quantized:
for module_name, target in modules.items():
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, modules.keys())
out_shard_paths = {}
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
orig_weight = in_tensors[key]
old_dev = target.weight.device
math_dev = "cpu" if cpu_offload else old_dev
delta_weight = lora_delta_weight(target, math_dev)
new_weight = orig_weight.to(math_dev) + delta_weight
del delta_weight
if actually_save:
out_tensors[key] = new_weight.half().cpu()
update_weights(target, new_weight, reinit=reinit, device=old_dev)
if actually_save:
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
for module_name in in_tensors:
if module_name not in out_tensors:
out_tensors[module_name] = in_tensors[module_name].half()
out_shard_paths[module_name] = out_shard_name
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
del in_tensors
del out_tensors
torch.cuda.empty_cache()
if actually_save and len(unique_shards) > 1:
with open(
str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
) as file:
json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
modules = find_lora_modules(model)
shard_paths = sharded_paths(checkpoint_path, modules.keys())
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
new_weight = tensors[key]
update_weights(
target, new_weight, reinit=False, device=target.weight.device
)

View File

@@ -2,10 +2,8 @@
import importlib
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
def load(strategy, tokenizer, cfg, ds_cfg):
def load(strategy, tokenizer, cfg):
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
@@ -13,9 +11,6 @@ def load(strategy, tokenizer, cfg, ds_cfg):
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
return func(tokenizer, cfg, **load_kwargs)
return func(tokenizer, cfg)
except Exception: # pylint: disable=broad-exception-caught
return None

View File

@@ -57,8 +57,6 @@ class SystemDataPrompter(AlpacaPrompter):
Alpaca Style Prompter that uses system prompts from the dataset
"""
system_format: str = "### System:\n{system}\n\n"
def build_prompt_w_system(
self,
system: str,
@@ -94,13 +92,12 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n"
self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n"
self.system_format = "### System:\n{system}\n"
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
from axolotl.prompters import IGNORE_TOKEN_ID
@dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2]
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv

View File

@@ -1,76 +0,0 @@
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
# pylint: disable=duplicate-code
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for the Metharme models
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (prompt["prompt"], "", prompt["generation"])
def _tokenize(
self,
prompt: str,
add_eos_token: bool = True,
strip_bos_token: bool = False,
num_eos_tokens: int = 3,
):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
# If there's already an EOS token there, subtract from the number added
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
num_eos_tokens -= 1
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
for _ in range(num_eos_tokens):
if len(result["input_ids"]) < self.sequence_len:
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
class MetharmePrompter(AlpacaPrompter):
"""
Prompter for the Metharme models.
"""
system_prompt = ""
system_no_input_prompt = ""
system_format = ""
turn_format = "{instruction}"
turn_no_input_format = "{instruction}"
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
pass
def load(tokenizer, cfg):
return MetharmePromptTokenizingStrategy(
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)

View File

@@ -1,98 +0,0 @@
"""
User Defined prompts with configuration from the YML config
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)
@dataclass
class UserDefinedDatasetConfig:
"""
dataclass configuration representing a userdefined dataset type
"""
system_prompt: str = ""
field_system: str = "system"
field_instruction: str = "instruction"
field_input: str = "input"
field_output: str = "output"
format: str = "{instruction} {input} "
no_input_format: str = "{instruction} "
system_format: str = "{system}"
def __getitem__(self, item):
return getattr(self, item)
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Prompt Tokenization Strategy for user defined prompts
"""
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
if not ds_cfg:
raise ValueError("Missing dataset prompt configuration")
system_prompt = ""
if ds_cfg.system_prompt:
system_prompt = ds_cfg.system_prompt
def parse_instruction_fields(
field_instruction,
field_input,
field_output,
field_system,
system_prompt,
prompt,
) -> Tuple[str, str, str, str]:
return (
prompt[field_instruction],
prompt[field_input] if field_input in prompt else "",
prompt[field_output] if field_output in prompt else "",
prompt[field_system] if field_system in prompt else system_prompt,
)
turn_format = ds_cfg.format
turn_no_input_format = ds_cfg.no_input_format
system_format = ds_cfg.system_format
class UserDefinedPrompter(SystemDataPrompter):
"""
Prompter for user defined prompts
"""
def match_prompt_style(self):
self.turn_format = turn_format
self.turn_no_input_format = turn_no_input_format
self.system_format = system_format
prompter = UserDefinedPrompter()
strat = UserDefinedPromptTokenizationStrategy(
prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
setattr(
strat,
"parse_instruction_fields",
partial(
parse_instruction_fields,
ds_cfg.field_instruction,
ds_cfg.field_input,
ds_cfg.field_output,
ds_cfg.field_system,
system_prompt,
),
)
return strat

View File

@@ -13,7 +13,7 @@ from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
@@ -74,22 +74,15 @@ class PromptTokenizingStrategy(abc.ABC):
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -26,7 +26,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str = "{system}"
system_format: str
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
@@ -63,17 +63,13 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input)
res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res
@@ -275,11 +271,6 @@ class Conversation:
self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
@@ -316,9 +307,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
raise IndexError
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -338,7 +327,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"])
for part in conv.get_prompt():

View File

@@ -1,139 +0,0 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return model, tokenizer

View File

@@ -1,43 +0,0 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -1,20 +1,8 @@
"""Callbacks for Trainer class"""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from accelerate.state import PartialState
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
TrainerCallback,
TrainerControl,
@@ -23,20 +11,6 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
dist_state = PartialState()
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter"""
@@ -54,9 +28,7 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
kwargs["model"].save_pretrained(peft_model_path)
return control
@@ -95,221 +67,3 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
tokenizer("A", add_special_tokens=False).input_ids[0],
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
bench_split = "eval"
def transform_bench_subject(example):
# Split on ':' and trim whitespace
parts = example["subject"].split(":")
first_part = (
parts[0].strip().lower().replace("-", "_")
) # Lowercase the first part
second_part = (
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
) # Replace hyphens with underscores
# Return the transformed values
return {"name": first_part, "subject": second_part}
if trainer.args.bench_dataset == "mmlu-zs":
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "zero_shot_mmlu_val.json",
"test": "zero_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns("subject")
# MMLU Five-shot (Eval/Test only)
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
bench_dataset = load_dataset(
"openaccess-ai-collective/mmlu-evals",
data_files={
"eval": "five_shot_mmlu_val.json",
"test": "five_shot_mmlu_test.json",
},
)
# bench_dataset = bench_dataset.remove_columns('subject')
elif "/" in trainer.args.bench_dataset:
bench_ds = trainer.args.bench_dataset
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
bench_dataset = load_dataset(
bench_ds_name,
data_files={
"eval": bench_ds_data_file,
},
)
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
else:
raise ValueError(
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
)
bench_dataset = bench_dataset[trainer.args.bench_split]
if trainer.args.max_bench_samples is not None:
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
def tokenize_evals(example):
source = f"{tokenizer.bos_token}{example['input']}"
target = f"{example['output']}{tokenizer.eos_token}"
tokenized_source = tokenizer(
source,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
tokenized_target = tokenizer(
target,
max_length=2048,
truncation=True,
add_special_tokens=False,
)
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids"
]
return {
"input_ids": input_ids,
"labels": labels,
"subject": example["subject"],
}
with dist_state.main_process_first():
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
TrainerCallback that runs the MMLU evals
"""
def on_evaluate(
self,
args: AxolotlTrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output", "name"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0
for batch in tqdm(data_loader, total=len(data_loader)):
(loss, logits, labels) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
# There are two tokens, the output, and eos token.
for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [
abcd_idx.index(label) if label in abcd_idx else -1
for label in labels.tolist()
]
loss_bench += loss.item()
# Extract results by subject.
bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
dist_state.wait_for_everyone()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0
loss_bench_ranks = gather_scalar_from_all_ranks(
lambda: loss_bench, get_world_size()
)
len_data_loader_ranks = gather_scalar_from_all_ranks(
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {f"{bench_split}_bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
bench_refs = []
bench_preds = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
references=bench_refs, predictions=bench_preds
)["accuracy"]
trainer.log(results)
return BenchEvalCallback

View File

@@ -1,20 +1,14 @@
"""Module containing data utilities"""
import functools
import hashlib
import itertools
import logging
from hashlib import md5
from pathlib import Path
from typing import Tuple, Union
from typing import List, Tuple, Union
import torch
from accelerate.state import PartialState
from datasets import (
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -42,46 +36,9 @@ from axolotl.prompters import (
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
)
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
state = PartialState()
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
with state.main_process_first():
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
with state.main_process_first():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
return train_dataset, eval_dataset, total_num_steps
def load_tokenized_prepared_datasets(
@@ -135,17 +92,8 @@ def load_tokenized_prepared_datasets(
seed = 42
datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
yield dataset
# pylint: disable=invalid-name
for d in for_d_in_datasets(cfg.datasets):
for d in cfg.datasets:
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
@@ -172,15 +120,8 @@ def load_tokenized_prepared_datasets(
split=None,
)
elif local_path.is_file():
ds_type = "json"
if d.ds_type:
ds_type = d.ds_type
elif ".parquet" in d.path:
ds_type = "parquet"
elif ".arrow" in d.path:
ds_type = "arrow"
ds = load_dataset(
ds_type,
"json",
name=d.name,
data_files=d.path,
streaming=False,
@@ -217,27 +158,13 @@ def load_tokenized_prepared_datasets(
)
else:
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_base_type = d_prompt_style = None
d_type = d.type
if isinstance(d_type, str):
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
and "labels" in ds.features
):
# dataset is already tokenized, just drop it straight in
datasets.append(ds)
elif isinstance(d.type, DictDefault):
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
@@ -338,12 +265,20 @@ def load_tokenized_prepared_datasets(
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
)
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
LOG.info("tokenizing, merging, and shuffling master dataset")
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)
samples: List[int] = []
chunk_size = 1000
for d in datasets:
d_iter = iter(d)
while True:
chunk = list(itertools.islice(d_iter, chunk_size))
if not chunk:
break
samples.extend(chunk)
LOG.info("shuffle")
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
@@ -508,7 +443,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with state.main_process_first():
if is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
@@ -516,6 +451,16 @@ def load_prepare_datasets(
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
if not is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

View File

@@ -3,7 +3,9 @@ import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union
import queue
import threading
from typing import Any, Callable, List, Optional, Union
import numba
import numpy as np
@@ -78,7 +80,6 @@ def allocate(
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
yield batch[rank], tot_seqs, s, len(result) * c * n
def chunk(iterable, n):
@@ -149,15 +148,14 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
total_num_tokens: Optional[int] = None,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
lengths_series = (
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
)
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
@@ -172,11 +170,17 @@ class MultipackDistributedDataloader:
self.rank = 0
# statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
@@ -188,77 +192,83 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
alloc_iter = iter(
allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def _generate_batches_thread(self):
try:
self.generate_batches(set_stats=True)
except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
# Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
while True:
batch = self.batch_queue.get()
if batch is None: # Sentinel value received, stop iteration
break
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
# yield a no-op for cases where we don't have any data left to pack
for i in range(0, len_remaining):
yield self.collate_fn(
[
{
"input_ids": [0],
"labels": [-100],
"attention_mask": [True],
"position_ids": [0],
}
]
)
break
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
if not self.total_num_tokens:
self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -10,6 +10,3 @@ class DictDefault(Dict):
def __missing__(self, key):
return None
def __or__(self, other):
return DictDefault(super().__or__(other))

View File

@@ -1,27 +1,25 @@
"""
utility helpers for distributed checks
"""
import torch
import torch.distributed as dist
from accelerate import DistributedType
from accelerate.state import PartialState
from accelerate.utils import wait_for_everyone
from accelerate import Accelerator
accelerate = None # pylint: disable=invalid-name
state = PartialState()
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed():
"""
Check if distributed training is initialized.
"""
return state.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_CPU,
DistributedType.DEEPSPEED,
DistributedType.FSDP,
)
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
def barrier():
@@ -29,48 +27,15 @@ def barrier():
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
wait_for_everyone()
if is_distributed():
dist.barrier()
def is_main_process() -> bool:
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
return state.is_main_process
def get_world_size() -> int:
return state.num_processes
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not state.is_main_process:
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None
if not is_distributed():
return True
return dist.get_rank() == 0

View File

@@ -5,13 +5,12 @@ import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple # noqa: F401
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -22,61 +21,54 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
LOG = logging.getLogger("axolotl")
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(cfg):
def load_tokenizer(
tokenizer_config,
tokenizer_type,
cfg,
):
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if (
tokenizer.__class__.__name__
in [
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
"LlamaTokenizerFast",
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -91,28 +83,29 @@ def load_tokenizer(cfg):
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model for a given configuration and tokenizer.
Load a model from a base model and a model type.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model is True
)
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not inference:
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
LOG.info("patching with flash attention")
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
replace_llama_attn_with_flash_attn()
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -121,7 +114,9 @@ def load_model(
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
@@ -145,16 +140,20 @@ def load_model(
LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not inference
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
try:
if cfg.gptq:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -186,7 +185,7 @@ def load_model(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=cfg.torch_dtype,
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
@@ -232,20 +231,13 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,10 +269,9 @@ def load_model(
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -308,10 +299,9 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -322,10 +312,9 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -347,18 +336,6 @@ def load_model(
)
model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
for name, module in model.named_modules():
if "norm" in name:
module.to(torch.float32)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -367,20 +344,18 @@ def load_model(
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
module.to(cfg.torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module.to(torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
model, lora_config = load_adapter(model, cfg, adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -419,22 +394,19 @@ def load_model(
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_adapter(model, cfg, adapter):
# type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
return load_lora(model, cfg)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
@@ -452,7 +424,7 @@ def load_llama_adapter(model, cfg):
)
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter")
LOG.info("Loading pretained LORA")
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
@@ -466,8 +438,12 @@ def load_llama_adapter(model, cfg):
return model, peft_config
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
def find_all_linear_names(bits, model):
cls = (
bnb.nn.Linear4bit
if bits == 4
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
@@ -480,15 +456,21 @@ def find_all_linear_names(model):
return list(lora_module_names)
def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
bits = None
if cfg.load_in_4bit:
bits = 4
elif cfg.load_in_8bit:
bits = 8
linear_names = find_all_linear_names(bits, model)
LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names))
@@ -504,11 +486,10 @@ def load_lora(model, cfg, inference=False):
)
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
is_trainable=(not inference),
is_trainable=not cfg.inference,
)
else:
model = get_peft_model(model, lora_config)

View File

@@ -8,13 +8,13 @@ from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
for idx in range(5):
check_example_labels(dataset[idx], tokenizer)
def check_example_labels(example, tokenizer, text_only=False):
def check_example_labels(example, tokenizer):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
@@ -29,10 +29,8 @@ def check_example_labels(example, tokenizer, text_only=False):
decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + (
not text_only
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
colored_token = colored(decoded_input_token, color) + colored(
f"({label_id}, {mask}, {input_id})", "white"
)
colored_tokens.append(colored_token)

View File

@@ -10,30 +10,26 @@ from functools import partial
from pathlib import Path
from typing import Optional, Union
import numpy as np
import bitsandbytes as bnb
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
LOG = logging.getLogger("axolotl")
@@ -126,34 +122,9 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
metadata={"help": "the total number of tokens in the train dataset"},
)
@@ -164,10 +135,6 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -206,18 +173,6 @@ class AxolotlTrainer(Trainer):
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
@@ -231,6 +186,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
)
)
return super().get_train_dataloader()
@@ -242,7 +198,6 @@ class AxolotlTrainer(Trainer):
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
@@ -254,35 +209,11 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
@@ -322,46 +253,13 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
def add_position_ids(sample):
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
return sample
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
return len(sample["input_ids"]) <= sequence_len
@contextmanager
@@ -374,15 +272,15 @@ def disable_datasets_caching():
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
if eval_dataset:
eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
@@ -390,26 +288,13 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
@@ -417,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
(
math.floor(
0.99
* cfg.total_num_tokens
* total_num_tokens
/ cfg.sample_packing_eff_est
/ cfg.sequence_len
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
@@ -428,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
* cfg.num_epochs
)
LOG.info(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
@@ -460,7 +345,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -471,24 +355,15 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
os.environ[
"FSDP_TRANSFORMER_CLS_TO_WRAP"
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None
@@ -536,13 +411,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# deepspeed
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
# deepspeed
if (
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
and torch.cuda.device_count() > 1
):
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
else:
# make a guess here
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
@@ -557,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -568,38 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
training_arguments_kwargs["evaluation_strategy"] = "no"
elif cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else:
# we have an eval set, but no steps defined, use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
training_arguments_kwargs["save_strategy"] = (
"steps" if cfg.save_steps else "epoch"
)
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -609,6 +459,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
@@ -630,9 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs,
)
@@ -642,13 +494,68 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx")
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": training_args.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
if cfg.lr_scheduler == "one_cycle":
lr_scheduler_kwargs = (
cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
)
lr_scheduler = OneCycleLR(
optimizer,
cfg.learning_rate,
total_steps=total_num_steps,
epochs=cfg.num_epochs,
div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
**lr_scheduler_kwargs,
)
elif cfg.lr_scheduler == "log_sweep":
lr_scheduler = InterpolatingLogScheduler(
optimizer,
cfg.warmup_steps,
cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
)
else:
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
optimizer,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -666,12 +573,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
"padding": True,
}
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
@@ -695,11 +600,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
num_proc=32,
)
trainer_cls = AxolotlTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
trainer_cls = OneCycleLRSchedulerTrainer
elif cfg.relora_steps:
trainer_cls = ReLoRATrainer
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else AxolotlTrainer
)
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
@@ -710,16 +615,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
return trainer

View File

@@ -1,77 +1,12 @@
"""Module for working with config dicts"""
"""Module for validating config files"""
import logging
import os
import torch
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
# in `accelerate launch`, we need to not pass through any device map and let
# accelerate figure out which parts of the model to put on which gpu
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
if accelerate_vars:
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16 or cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
cfg.torch_dtype = torch.float16
else:
cfg.torch_dtype = torch.float32
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
@@ -126,19 +61,6 @@ def validate_config(cfg):
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if cfg.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -167,7 +89,7 @@ def validate_config(cfg):
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
if int(torch.__version__.split(".")[0]) < 2:
LOG.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
@@ -188,13 +110,6 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(

View File

@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:

View File

@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
assert cfg.random_key is None, "DictDefault should return None for missing keys"
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
def test_dict_nested_missingparentkey(self):
"""
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.

View File

@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
"""
def test_default_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer(cfg)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):