Compare commits

..

56 Commits

Author SHA1 Message Date
Wing Lian
64af21bcb2 set env vars trainer needs for FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-11 08:46:26 -04:00
Wing Lian
6b5cf8b5ea optimize length reducer from 9m -> <5sec 2023-08-11 08:30:30 -04:00
Wing Lian
79500f358a need to pass total num tokens to trainer too 2023-08-10 19:08:23 -04:00
Wing Lian
7e977a9b68 optimization if total_num_tokens is already known 2023-08-10 19:02:28 -04:00
Wing Lian
ac4b700daa optimization if total_num_tokens is already known 2023-08-10 19:01:17 -04:00
Wing Lian
2565c2f259 async batching for multipack 2023-08-10 18:28:15 -04:00
Wing Lian
a07f432d9c calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier 2023-08-10 17:16:01 -04:00
Wing Lian
57d9bf711c let's not cleanup the cached datasets 2023-08-08 21:27:55 -04:00
Wing Lian
26983a1974 fix sampler to prevent overfit w new epochs 2023-08-08 15:34:18 -04:00
Wing Lian
1b8747e319 use custom distributed checks 2023-08-08 13:35:04 -04:00
Wing Lian
035b3c760c add numba to requirements. 2023-08-08 10:55:29 -04:00
Wing Lian
17abbd59e1 previous accelerate is still most performant 2023-08-08 09:46:01 -04:00
Wing Lian
6ec76ddb4c fix steps calculation 2023-08-08 05:13:21 -04:00
Wing Lian
21d307b15b fix counts by accounting for num devices 2023-08-08 04:13:10 -04:00
Wing Lian
58e9dee204 fixes and go back to distributed sampler since batch sampler won't work 2023-08-08 03:49:29 -04:00
Wing Lian
4f7c04bae0 more fixes and optimizations 2023-08-08 03:16:00 -04:00
Wing Lian
1162b93b6b filter w multiple cpus 2023-08-08 00:50:56 -04:00
Wing Lian
21f445d763 more packing and dataset optimizations and fixes 2023-08-08 00:45:24 -04:00
Wing Lian
229b9165aa fix test and pylint checks 2023-08-07 09:38:05 -04:00
Wing Lian
394a65f11f add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test 2023-08-07 09:38:04 -04:00
Wing Lian
c70dae63cc add chatml 2023-08-07 09:38:04 -04:00
Wing Lian
7712955b35 fix chatml system prompt for openorca, legacy tokenizer opts 2023-08-07 09:38:04 -04:00
Wing Lian
f93f0017cd fix flash-attn, xformers, packing, support chatml 2023-08-07 09:38:04 -04:00
Wing Lian
0b01da0713 properly calculate max len 2023-08-07 09:38:04 -04:00
Wing Lian
b2f7bc7ccd use cumulative seq len with var len flash attn v2 w packing 2023-08-07 09:38:04 -04:00
Wing Lian
b8905e2a91 sample_packing_seq_len_multiplier config 2023-08-07 09:38:04 -04:00
Wing Lian
7e1edc662a make sure the chunk size is an int 2023-08-07 09:38:04 -04:00
Wing Lian
98c9bc69de seq_len_multiple for packing 2023-08-07 09:38:04 -04:00
Wing Lian
8378335dc9 limit packing to sequences of max seq len 2023-08-07 09:38:04 -04:00
Wing Lian
bdd34c7400 weighted CEL fixes 2023-08-07 09:38:04 -04:00
Wing Lian
c6cc54c7d9 weighted CE losses 2023-08-07 09:38:04 -04:00
Wing Lian
83f7362480 don't split batches when packing 2023-08-07 09:38:04 -04:00
Wing Lian
958d423e7c only process eval dataset for packing if not None 2023-08-07 09:38:04 -04:00
Wing Lian
e74eab6e73 add a test for the mask expansion for sequence packing 2023-08-07 09:38:04 -04:00
Wing Lian
487abfc769 pass sample packing efficiency to training args 2023-08-07 09:38:04 -04:00
Wing Lian
2bee646e85 fix step calc for packing 2023-08-07 09:38:04 -04:00
Wing Lian
945f2e5029 better handling so that all devices have the same dataloader len 2023-08-07 09:38:04 -04:00
Wing Lian
daed942fe9 fix rounding of len of batches to int 2023-08-07 09:38:04 -04:00
Wing Lian
df3eb645da better handling of variance in multipack dataloader length and trainer hanging when it runs out of data 2023-08-07 09:38:04 -04:00
Wing Lian
32fed7039d optimized expand mask fn 2023-08-07 09:38:04 -04:00
Wing Lian
7d7b5ebd71 more fixes for 4k and optimizations 2023-08-07 09:38:03 -04:00
Wing Lian
4b7ad9927f validation for sample packing and doc 2023-08-07 09:38:03 -04:00
Wing Lian
fedcf5a089 Update src/axolotl/utils/dataloader.py 2023-08-07 09:38:03 -04:00
Wing Lian
2f2974196d fix for position_ids w packing 2023-08-07 09:38:03 -04:00
Wing Lian
2e295c9f94 use accelerator prepare for dataloader 2023-08-07 09:38:03 -04:00
Wing Lian
4ab9ab79fd use distributed sampler, avoid accelerate prepare 2023-08-07 09:38:03 -04:00
Wing Lian
b02484a83e more fixes for sample packing 2023-08-07 09:38:03 -04:00
Wing Lian
58045f0816 more fixes, position_ids seems broken 2023-08-07 09:38:03 -04:00
Wing Lian
66774011c4 est total tokens, fix field loop 2023-08-07 09:38:03 -04:00
Wing Lian
41d4992029 more fixes for dataloader integration 2023-08-07 09:38:03 -04:00
Wing Lian
762f1b08db add position_ids back 2023-08-07 09:38:03 -04:00
Wing Lian
3aba4c5d7c use multi pack dataloader w random sampler 2023-08-07 09:38:03 -04:00
Wing Lian
ffd96839cf don't move masks to cpu 2023-08-07 09:38:03 -04:00
Wing Lian
ef9bf7ad73 fix expand mask for multiple batch items, make sure we pad position_ids 2023-08-07 09:38:03 -04:00
Wing Lian
4964b0d345 set position ids and use block diagonal attn mask 2023-08-07 09:38:03 -04:00
Wing Lian
36b0e30a9d fix attetion mask with packing 2023-08-07 09:38:03 -04:00
55 changed files with 334 additions and 1192 deletions

View File

@@ -1,129 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement on Discord
at https://discord.gg/QYF8QrtEUm
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@@ -1,76 +0,0 @@
# Contributing to axolotl
First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.
## Table of Contents
- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [How to Contribute](#how-to-contribute)
- [Reporting Bugs](#reporting-bugs)
- [Suggesting Enhancements](#suggesting-enhancements)
- [Submitting Pull Requests](#submitting-pull-requests)
- [Style Guidelines](#style-guidelines)
- [Code Style](#code-style)
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conductcode
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.
## Getting Started
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!
1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
```
## How to Contribute
### Reporting Bugs
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
### Suggesting Enhancements
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
### Submitting Pull Requests
1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.
2. Make your changes, following the [Style Guidelines](#style-guidelines) below.
3. Test your changes and ensure that they don't introduce new issues or break existing functionality.
4. Commit your changes, following the [commit message guidelines](#commit-messages).
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
## Style Guidelines
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".
## Additional Resources
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

13
.github/FUNDING.yml vendored
View File

@@ -1,13 +0,0 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -1,105 +0,0 @@
name: Bug Report
description: File a bug report
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
## Before you start
Please **make sure you are on the latest version.**
If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "Please check that this issue hasn't been reported before."
description: "The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
required: true
- type: textarea
id: expected
attributes:
label: Expected Behavior
description: Tell us what **should** happen.
validations:
required: true
- type: textarea
id: what-happened
attributes:
label: Current behaviour
description: |
Tell us what happens instead of the expected behavior.
Provide stacktrace and/or screenshots.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: |
Which exact steps can a developer take to reproduce the issue?
The more detail you provide, the easier it will be to narrow down and fix the bug.
Please paste in tasks and/or queries **as text, not screenshots**.
placeholder: |
Example of the level of detail needed to reproduce any bugs efficiently and reliably.
1. Go to the '...' page.
2. Click on the '...' button.
3. Scroll down to '...'.
4. Observe the error.
validations:
required: true
- type: textarea
id: possible-solution
attributes:
label: Possible solution
description: |
Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.
- type: checkboxes
id: operating-systems
attributes:
label: Which Operating Systems are you using?
description: You may select more than one.
options:
- label: Linux
- label: macOS
- label: Windows
- type: input
id: Python-version
attributes:
label: Python Version
description: Which {Programming} version are you using?
placeholder: 3.10 / please change accordingly
validations:
required: true
- type: input
id: axolotl-branch-commit
attributes:
label: axolotl branch-commit
description: On which branch/commit are you?
placeholder: main/4d6490b
validations:
required: true
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this bug has not been reported yet.'
required: true
- label: 'I am using the latest version of axolotl.'
required: true
- label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'
required: true

View File

@@ -1,7 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
about: Ask questions and discuss with other community members
- name: Discuss the Project in Discord
url: https://discord.gg/HhrNrHJPRb

View File

@@ -1,46 +0,0 @@
name: Documentation Improvement / Clarity
description: Make a suggestion to improve the project documentation.
labels: ['needs triage', 'docs']
body:
- type: markdown
attributes:
value: '## :book: Documentation :book:'
- type: markdown
attributes:
value: |
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
- type: textarea
attributes:
label: What piece of documentation is affected?
description: Please link to the article you'd like to see updated.
validations:
required: true
- type: textarea
attributes:
label: What part(s) of the article would you like to see updated?
description: |
- Give as much detail as you can to help us understand the change you want to see.
- Why should the docs be changed? What use cases does it support?
- What is the expected outcome?
validations:
required: true
- type: textarea
attributes:
label: Additional Information
description: Add any other context or screenshots about the feature request here.
validations:
required: false
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -1,63 +0,0 @@
name: Feature Request / Enhancement
description: Suggest a new feature or feature enhancement for the project
labels: ["enhancement", "needs triage"]
body:
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "⚠️ Please check that this feature request hasn't been suggested before."
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
required: true
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
required: true
- type: textarea
id: feature-description
validations:
required: true
attributes:
label: "🔖 Feature description"
description: "A clear and concise description of what the feature request is."
placeholder: "You should add ..."
- type: textarea
id: solution
validations:
required: true
attributes:
label: "✔️ Solution"
description: "A clear and concise description of what you want to happen, and why."
placeholder: "In my use-case, ..."
- type: textarea
id: alternatives
validations:
required: false
attributes:
label: "❓ Alternatives"
description: "A clear and concise description of any alternative solutions or features you've considered."
placeholder: "I have considered ..."
- type: textarea
id: additional-context
validations:
required: false
attributes:
label: "📝 Additional Context"
description: "Add any other context or screenshots about the feature request here."
placeholder: "..."
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -1,22 +0,0 @@
<!--- Provide a general summary of your changes in the Title above -->
# Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How has this been tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## Screenshots (if appropriate)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->

9
.github/SECURITY.md vendored
View File

@@ -1,9 +0,0 @@
# Security Policy
## Supported Versions
Due to the nature of the fast development that is happening in this project, only the latest released version can be supported.
## Reporting a Vulnerability
If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue to allow us some time to fix it before it is a known vulnerability to others.

10
.github/SUPPORT.md vendored
View File

@@ -1,10 +0,0 @@
# Support
If you need help with this project or have questions, please:
1. Check the documentation.
2. Search the existing issues and pull requests.
3. Create a new issue if your question is not answered or your problem is not solved.
4. Have a look in the [Discord server](https://discord.gg/HhrNrHJPRb)
Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.

View File

@@ -4,6 +4,7 @@ on:
push:
branches:
- "main"
- "dev"
jobs:
build-axolotl:
@@ -71,7 +72,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
@@ -102,7 +102,5 @@ jobs:
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

146
README.md
View File

@@ -1,39 +1,10 @@
# Axolotl
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
<table>
<tr>
<td>
## Table of Contents
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Docker Installation](#environment)
- [Conda/Pip venv Installation](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)
</td>
<td>
<div align="center">
<img src="image/axolotl.png" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
<b>One repo to finetune them all! </b>
</p>
<p>
Go ahead and axolotl questions!!
@@ -43,27 +14,21 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
</div>
</div>
</td>
</tr>
</table>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ❓ | ✅
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
```bash
@@ -165,14 +130,13 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
- `sharegpt:chat`: conversations
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -261,10 +225,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -362,8 +322,6 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -402,13 +360,10 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
@@ -420,14 +375,10 @@ dataset_shard_idx:
sequence_len: 2048
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
# soon to be DEPRECATED
max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# use efficient multi-packing with block diagonal attention and per sequence position_ids
sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
@@ -453,12 +404,11 @@ lora_out_dir:
lora_fan_in_fan_out: false
# wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name
wandb_entity: # a wandb Team name if using a Team
wandb_mode:
wandb_project:
wandb_watch:
wandb_run_id: # set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
wandb_run_id:
wandb_log_model: # 'checkpoint'
# where to save the finished model to
output_dir: ./completed-model
@@ -470,21 +420,16 @@ eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
save_steps: # leave empty to save at each epoch
save_steps:
eval_steps:
save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# group similarly sized data to minimize padding
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -520,8 +465,8 @@ max_grad_norm:
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
@@ -530,10 +475,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir
resume_from_checkpoint:
@@ -556,7 +497,7 @@ tokens:
fsdp:
fsdp_config:
# Deepspeed config path
# Deepspeed
deepspeed:
# Path to torch distx for optim 'adamw_anyprecision'
@@ -565,9 +506,6 @@ torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Debug mode
debug:
@@ -587,14 +525,7 @@ Run
accelerate launch scripts/finetune.py configs/your_config.yml
```
#### Multi-GPU
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
#### Multi-GPU Config
- llama FSDP
```yaml
@@ -607,22 +538,7 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
- llama Deepspeed
```yaml
deepspeed: deepspeed/zero3.json
```
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
### Inference
@@ -658,7 +574,7 @@ CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
## Common Errors 🧰
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
> Cuda out of memory
Please reduce any below
- `micro_batch_size`
@@ -666,10 +582,6 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
> `failed (exitcode: -9)` usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
> RuntimeError: expected scalar type Float but found Half
Try set `fp16: true`
@@ -698,8 +610,6 @@ Building something cool with Axolotl? Consider adding a badge to your model card
## Community Showcase
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
@@ -710,9 +620,7 @@ PocketDoc Labs
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug) else create a new Issue.
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
git checkout v2.0.4 && \
git checkout v2.0.1 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -36,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -38,7 +38,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -33,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -22,7 +22,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -18,7 +18,6 @@ lora_dropout:
lora_target_modules:
lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -2,7 +2,6 @@ base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
adapter: lora
lora_model_dir:
@@ -27,7 +26,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -40,7 +38,7 @@ lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: false
@@ -50,8 +48,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -65,3 +63,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -2,7 +2,6 @@ base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -19,8 +18,7 @@ adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
@@ -29,7 +27,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -42,7 +39,7 @@ lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: false
@@ -52,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -67,3 +64,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -28,7 +28,6 @@ lora_target_modules:
- o_proj
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -35,7 +34,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
group_by_length: true
bf16: true
fp16: false
tf32: true

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -17,7 +17,6 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -21,7 +21,6 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- mlp_down
lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -37,7 +37,6 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

View File

@@ -1,6 +1,6 @@
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
fire
@@ -21,4 +21,3 @@ evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml

View File

@@ -18,13 +18,17 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -34,21 +38,30 @@ sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art():
ascii_art = """
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
"""
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if is_main_process():
print(ascii_art)
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
@@ -160,7 +173,6 @@ def train(
prepare_ds_only: bool = False,
**kwargs,
):
print_axolotl_text_art()
if Path(config).is_dir():
config = choose_config(config)
@@ -181,18 +193,67 @@ def train(
validate_config(cfg)
normalize_config(cfg)
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
LOG.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
@@ -208,10 +269,15 @@ def train(
return
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer)
safe_serialization = cfg.save_safetensors is True
LOG.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
@@ -220,11 +286,7 @@ def train(
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
@@ -239,7 +301,7 @@ def train(
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(
@@ -263,7 +325,7 @@ def train(
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
@@ -290,7 +352,6 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -308,7 +369,7 @@ def train(
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
model.save_pretrained(cfg.output_dir)
if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@ import os
from typing import List
import torch
from datasets import Dataset, IterableDataset
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset):
class TokenizedPromptDataset(IterableDataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Iterable dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
@@ -30,18 +30,19 @@ class TokenizedPromptDataset(Dataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
super().__init__(self.process(dataset).data, **kwargs)
self.dataset = dataset
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
def __iter__(self):
features = self.dataset.features.keys()
num_proc = os.cpu_count()
return iter(
self.dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
)
)

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -92,8 +91,7 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif attention_mask.shape[0] == 1:
# special handling using sample packing
else:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
@@ -102,36 +100,6 @@ def forward(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),

View File

@@ -1,49 +1,8 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
import logging
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
LOG = logging.getLogger("axolotl.prompt_strategies.alpaca_instruct")
class LatentSpaceAlpacaPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
"""
Overrides the tokenization to include additional padding tokens as
latent space on the inputs
"""
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
# pylint: disable=duplicate-code
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
# latent space
if add_eos_token and not strip_bos_token:
result["input_ids"].extend([self.tokenizer.pad_token_id] * 100)
result["labels"] = result["input_ids"].copy()
return result
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
@@ -61,12 +20,3 @@ def load_no_prompt(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_latent_space(tokenizer, cfg):
return LatentSpaceAlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -92,13 +92,12 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n"
self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n"
self.system_format = "### System:\n{system}\n"
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
from axolotl.prompters import IGNORE_TOKEN_ID
@dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2]
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv

View File

@@ -31,52 +31,6 @@ def load_guanaco(tokenizer, cfg):
)
def load_latent_space(tokenizer, cfg):
return LatentSpaceShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class LatentSpaceShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
latent space padded sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
# pylint: disable=duplicate-code
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
# latent space
if add_eos_token and not strip_bos_token:
result["input_ids"].extend([self.tokenizer.pad_token_id] * 100)
result["labels"] = result["input_ids"].copy()
return result
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row

View File

@@ -74,11 +74,8 @@ class PromptTokenizingStrategy(abc.ABC):
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):

View File

@@ -271,11 +271,6 @@ class Conversation:
self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
@@ -312,9 +307,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
raise IndexError
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -334,7 +327,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"])
for part in conv.get_prompt():

View File

@@ -1,43 +0,0 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -1,6 +1,5 @@
"""Callbacks for Trainer class"""
import logging
import os
from optimum.bettertransformer import BetterTransformer
@@ -12,10 +11,6 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl.callbacks")
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter"""
@@ -72,25 +67,3 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,19 +1,14 @@
"""Module containing data utilities"""
import functools
import hashlib
import itertools
import logging
from hashlib import md5
from pathlib import Path
from typing import Tuple, Union
from typing import List, Tuple, Union
import torch
from datasets import (
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -41,44 +36,9 @@ from axolotl.prompters import (
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
)
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
return train_dataset, eval_dataset, total_num_steps
def load_tokenized_prepared_datasets(
@@ -305,12 +265,20 @@ def load_tokenized_prepared_datasets(
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
)
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
LOG.info("tokenizing, merging, and shuffling master dataset")
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)
samples: List[int] = []
chunk_size = 1000
for d in datasets:
d_iter = iter(d)
while True:
chunk = list(itertools.islice(d_iter, chunk_size))
if not chunk:
break
samples.extend(chunk)
LOG.info("shuffle")
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
@@ -475,7 +443,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with zero_first(is_main_process()):
if is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
@@ -483,6 +451,16 @@ def load_prepare_datasets(
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
if not is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

View File

@@ -3,7 +3,9 @@ import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union
import queue
import threading
from typing import Any, Callable, List, Optional, Union
import numba
import numpy as np
@@ -78,7 +80,6 @@ def allocate(
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
yield batch[rank], tot_seqs, s, len(result) * c * n
def chunk(iterable, n):
@@ -149,15 +148,14 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
total_num_tokens: Optional[int] = None,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
lengths_series = (
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
)
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
@@ -172,11 +170,17 @@ class MultipackDistributedDataloader:
self.rank = 0
# statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
@@ -188,65 +192,83 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
alloc_iter = iter(
allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def _generate_batches_thread(self):
try:
self.generate_batches(set_stats=True)
except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
# Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
while True:
batch = self.batch_queue.get()
if batch is None: # Sentinel value received, stop iteration
break
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
break
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
if not self.total_num_tokens:
self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -10,6 +10,3 @@ class DictDefault(Dict):
def __missing__(self, key):
return None
def __or__(self, other):
return DictDefault(super().__or__(other))

View File

@@ -1,8 +1,6 @@
"""
utility helpers for distributed checks
"""
from contextlib import contextmanager
import torch.distributed as dist
from accelerate import Accelerator
@@ -41,15 +39,3 @@ def is_main_process():
if not is_distributed():
return True
return dist.get_rank() == 0
@contextmanager
def zero_first(is_main):
"""
runs the wrapped context so that rank 0 runs first before other ranks
"""
if not is_main: # other ranks wait first
barrier()
yield
if is_main: # then rank 0 waits after it has run the context
barrier()

View File

@@ -22,7 +22,6 @@ from transformers import ( # noqa: F401
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
@@ -32,27 +31,37 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(cfg):
def load_tokenizer(
tokenizer_config,
tokenizer_type,
cfg,
):
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
@@ -60,11 +69,6 @@ def load_tokenizer(cfg):
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -79,21 +83,19 @@ def load_tokenizer(cfg):
def load_model(
cfg, tokenizer
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model for a given configuration and tokenizer.
Load a model from a base model and a model type.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
or cfg.is_llama_derived_model is True
)
if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -138,10 +140,8 @@ def load_model(
LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not cfg.inference
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -231,17 +231,10 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -276,7 +269,6 @@ def load_model(
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -307,7 +299,6 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -321,7 +312,6 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -346,9 +336,6 @@ def load_model(
)
model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -368,7 +355,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
model, lora_config = load_adapter(model, cfg, adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -407,9 +394,6 @@ def load_model(
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, lora_config

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Optional, Union
import bitsandbytes as bnb
import numpy as np
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
@@ -22,7 +21,6 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
@@ -124,6 +122,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "the total number of tokens in the train dataset"},
)
class AxolotlTrainer(Trainer):
@@ -184,6 +186,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
)
)
return super().get_train_dataloader()
@@ -206,6 +209,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
)
)
return super().get_eval_dataloader(eval_dataset)
@@ -284,16 +288,13 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
@@ -301,9 +302,9 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
(
math.floor(
0.99
* cfg.total_num_tokens
* total_num_tokens
/ cfg.sample_packing_eff_est
/ cfg.sequence_len
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
@@ -312,7 +313,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
* cfg.num_epochs
)
LOG.info(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
@@ -344,7 +345,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -364,9 +364,6 @@ def setup_fsdp_envs(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None
@@ -414,13 +411,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# deepspeed
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
# deepspeed
if (
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
and torch.cuda.device_count() > 1
):
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
else:
# make a guess here
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
@@ -435,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -446,17 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
training_arguments_kwargs["evaluation_strategy"] = "no"
elif cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else:
# we have an eval set, but no steps defined, use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -466,7 +459,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
@@ -488,7 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs,
)
@@ -560,7 +556,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(

View File

@@ -1,70 +1,12 @@
"""Module for working with config dicts"""
"""Module for validating config files"""
import logging
import os
import torch
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
# in `accelerate launch`, we need to not pass through any device map and let
# accelerate figure out which parts of the model to put on which gpu
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
if accelerate_vars:
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
@@ -147,7 +89,7 @@ def validate_config(cfg):
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
if int(torch.__version__.split(".")[0]) < 2:
LOG.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
@@ -168,13 +110,6 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(

View File

@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:

File diff suppressed because one or more lines are too long

View File

@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
assert cfg.random_key is None, "DictDefault should return None for missing keys"
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
def test_dict_nested_missingparentkey(self):
"""
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.

View File

@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
"""
def test_default_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer(cfg)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):