Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
f6721baf10 tweak to make it work when we have no explicit test split
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-07-11 22:40:21 -04:00
Wing Lian
33814cc94e make sure we eval for openorca 2023-07-02 17:59:10 -04:00
Wing Lian
50254a7ccc handle orca splits 2023-07-01 07:20:23 -04:00
68 changed files with 549 additions and 3023 deletions

13
.github/FUNDING.yml vendored
View File

@@ -1,13 +0,0 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -18,13 +18,23 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.1 pytorch: 2.0.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" axolotl_extras:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" axolotl_extras:
- cuda: "117"
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
@@ -48,9 +58,11 @@ jobs:
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: | build-args: |
CUDA_VERSION=${{ matrix.cuda_version }} CUDA_VERSION=${{ matrix.cuda_version }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }} PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }} AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}

View File

@@ -17,18 +17,23 @@ jobs:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
runs-on: self-hosted runs-on: self-hosted
steps: steps:
- name: Checkout - name: Checkout
@@ -50,11 +55,13 @@ jobs:
with: with:
context: . context: .
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-axolotl-runpod: build-axolotl-runpod:
needs: build-axolotl needs: build-axolotl
if: github.repository_owner == 'OpenAccess-AI-Collective' if: github.repository_owner == 'OpenAccess-AI-Collective'
@@ -62,21 +69,26 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: axolotl_extras:
- cuda: 118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: axolotl_extras:
- cuda: 118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.1 pytorch: 2.0.0
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
runs-on: self-hosted runs-on: self-hosted
steps: steps:
- name: Checkout - name: Checkout
@@ -98,9 +110,10 @@ jobs:
with: with:
context: . context: .
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-runpod file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

202
LICENSE
View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

138
README.md
View File

@@ -24,12 +24,11 @@
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ | | mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ | | falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ | | gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ❓ | ✅
## Quickstart ⚡ ## Quickstart ⚡
**Requirements**: Python >=3.9 and Pytorch >=2.0. **Requirements**: Python 3.9 and Pytorch 2.0.
```bash ```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
@@ -37,6 +36,8 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
pip3 install -e . pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config
# finetune lora # finetune lora
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
@@ -51,10 +52,11 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
- Docker - Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
``` ```
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod - `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq - `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date)
Or run on the current files for development: Or run on the current files for development:
@@ -106,7 +108,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
3. Install torch 3. Install torch
```bash ```bash
pip3 install -U torch --index-url https://download.pytorch.org/whl/cu118 pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
``` ```
4. Axolotl 4. Axolotl
@@ -136,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt` - `sharegpt:chat`: conversations
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
@@ -225,10 +227,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
``` ```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny - `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json ```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
@@ -239,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example. 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`. 2. Use your custom file name as the dataset type.
Optionally, download some datasets, see [data/README.md](data/README.md) Optionally, download some datasets, see [data/README.md](data/README.md)
@@ -247,7 +245,7 @@ Optionally, download some datasets, see [data/README.md](data/README.md)
### Config ### Config
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are: See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
- model - model
```yaml ```yaml
@@ -257,24 +255,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset - dataset
```yaml ```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets: datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4 # local or huggingface repo
type: alpaca # format from earlier
# huggingface repo with specific configuration/subset
datasets:
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
# local
datasets:
- path: json
data_files: data.jsonl # or json
type: alpaca # format from earlier type: alpaca # format from earlier
sequence_len: 2048 # max token length / prompt
``` ```
- loading - loading
@@ -313,8 +297,6 @@ base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files, # if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model # you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# you can specify to choose a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
tokenizer_config: tokenizer_config:
@@ -326,9 +308,6 @@ tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True # use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast: tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of N
# multiples of 32 are reported to improve training speed on some models
resize_token_embeddings_multiple:
# whether you are training a 4-bit GPTQ quantized model # whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -349,13 +328,12 @@ tf32: true # require >=ampere
# a list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files # this can be either a hf dataset, or relative path
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format OR format:prompt_style (chat/instruct)
data_files: # path to source data files data_files: # path to source data files
shards: # number of shards to split data into shards: # number of shards to split data into
name: # name of dataset configuration to load
# axolotl attempts to save the dataset as an arrow after packing the data together so # axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
@@ -363,10 +341,7 @@ dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub # push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # repo path to push finetuned model push_to_hub_model_id: # repo path
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
@@ -382,14 +357,7 @@ dataset_shard_idx:
sequence_len: 2048 sequence_len: 2048
# max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
@@ -415,12 +383,11 @@ lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode:
wandb_project: # your wandb project name wandb_project:
wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # set the name of your wandb run wandb_run_id:
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # 'checkpoint'
# where to save the finished model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
@@ -435,17 +402,10 @@ learning_rate: 0.00003
logging_steps: logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# group similarly sized data to minimize padding # don't use this, leads to wonky training (according to someone on the internet)
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
@@ -491,10 +451,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -526,9 +482,6 @@ torchdistx_path:
# Set padding for data collator to 'longest' # Set padding for data collator to 'longest'
collator_pad_to_longest: collator_pad_to_longest:
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
pretraining_dataset:
# Debug mode # Debug mode
debug: debug:
@@ -541,6 +494,17 @@ strict:
</details> </details>
### Accelerate
Configure accelerate
```bash
accelerate config
# Edit manually
# nano ~/.cache/huggingface/accelerate/default_config.yaml
```
### Train ### Train
Run Run
@@ -548,40 +512,6 @@ Run
accelerate launch scripts/finetune.py configs/your_config.yml accelerate launch scripts/finetune.py configs/your_config.yml
``` ```
#### Multi-GPU
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
- llama FSDP
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:
@@ -632,10 +562,6 @@ Try set `fp16: true`
Try to turn off xformers. Try to turn off xformers.
> accelerate config missing
It's safe to ignore it.
## Need help? 🙋♂️ ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you

View File

@@ -3,15 +3,16 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y vim curl apt-get install -y vim curl
WORKDIR /workspace WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"transformers @ git+https://github.com/huggingface/transformers.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \
@@ -21,10 +22,5 @@ RUN cd axolotl && \
pip install -e .; \ pip install -e .; \
fi fi
# fix so that git fetch/pull from remote works
RUN cd axolotl && \
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli # helper for huggingface-login cli
RUN git config --global credential.helper store RUN git config --global credential.helper store

View File

@@ -8,7 +8,7 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9" ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1" ARG PYTORCH="2.0.0"
ARG CUDA="118" ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -29,18 +29,17 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder FROM base-builder AS flash-attn-builder
WORKDIR /workspace WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
cd flash-attention && \ cd flash-attention && \
git checkout v2.0.4 && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \ cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \
@@ -53,7 +52,7 @@ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
FROM base-builder AS deepspeed-builder FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
WORKDIR /workspace WORKDIR /workspace
@@ -74,9 +73,6 @@ RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
FROM base-builder FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex # recompile apex
RUN python3 -m pip uninstall -y apex RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex RUN git clone https://github.com/NVIDIA/apex
@@ -101,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo RUN git lfs install --skip-repo
RUN pip3 install awscli && \ RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic

View File

@@ -1,10 +1,6 @@
ARG BASE_TAG=main ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
RUN apt install --yes --no-install-recommends openssh-server tmux && \ RUN apt install --yes --no-install-recommends openssh-server tmux && \

View File

@@ -37,18 +37,18 @@
"lr": "auto", "lr": "auto",
"betas": [ "betas": [
0.9, 0.9,
0.95 0.999
], ],
"eps": 1e-8, "eps": 1e-8,
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": { "scheduler": {
"type": "WarmupLR", "type": "OneCycle",
"params": { "params": {
"warmup_min_lr": "auto", "cycle_min_lr": 0.00001,
"warmup_max_lr": "auto", "cycle_max_lr": 0.00003,
"warmup_num_steps": "auto" "cycle_first_step_size": 120
} }
}, },
"train_batch_size": "auto", "train_batch_size": "auto",

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -36,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -38,7 +38,6 @@ lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -24,7 +24,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -33,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0001 learning_rate: 0.0001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -22,7 +22,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4 wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -18,7 +18,6 @@ lora_dropout:
lora_target_modules: lora_target_modules:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -1,20 +0,0 @@
# Overview
This is an example of a llama-2 configuration for 7b and 13b. The yaml file contains configuration for the 7b variant, but you can just aswell use the same settings for 13b.
The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
The 13b variant will fit if you change these settings to these values:
gradient_accumulation_steps: 2
micro_batch_size: 1
```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
```

View File

@@ -1,66 +0,0 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,68 +0,0 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -28,7 +28,6 @@ lora_target_modules:
- o_proj - o_proj
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -22,7 +22,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -35,7 +34,7 @@ torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: true
bf16: true bf16: true
fp16: false fp16: false
tf32: true tf32: true

View File

@@ -23,7 +23,6 @@ lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -17,7 +17,6 @@ lora_target_modules:
lora_target_linear: lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -21,7 +21,6 @@ lora_target_modules:
- v_proj - v_proj
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -20,7 +20,6 @@ lora_target_modules:
- mlp_down - mlp_down
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:

View File

@@ -1,91 +0,0 @@
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
# on Tim Dettmer's Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: timdettmers/openassistant-guanaco
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 8192
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 1
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.00002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
eos_token: "<|endoftext|>"
bos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

View File

@@ -1,7 +1,7 @@
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate
addict addict
fire fire
PyYAML==6.0 PyYAML==6.0
@@ -12,13 +12,9 @@ wandb
einops einops
xformers xformers
optimum optimum
hf_transfer
numba
numpy==1.24.4
# qlora things # qlora things
bert-score==0.3.13 bert-score==0.3.13
evaluate==0.4.0 evaluate==0.4.0
rouge-score==0.1.2 rouge-score==0.1.2
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml

View File

@@ -15,9 +15,6 @@ from axolotl.convert import (
JsonToJsonlConverter, JsonToJsonlConverter,
StdoutWriter, StdoutWriter,
) )
from axolotl.logging_config import configure_logging
configure_logging()
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

View File

@@ -17,30 +17,42 @@ import yaml
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import ( from axolotl.utils.trainer import setup_trainer
calculate_total_num_steps, from axolotl.utils.validation import validate_config
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]: def get_multi_line_input() -> Optional[str]:
@@ -172,13 +184,36 @@ def train(
validate_config(cfg) validate_config(cfg)
normalize_config(cfg) # setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first # load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = load_tokenizer(cfg) logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if ( if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
@@ -192,33 +227,14 @@ def train(
cfg.pretraining_dataset, cfg.pretraining_dataset,
tokenizer, tokenizer,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
seed=cfg.seed or 42, seed=cfg.seed,
) )
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...") logging.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
@@ -227,31 +243,32 @@ def train(
) )
if prepare_ds_only: if prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...") logging.info("Finished preparing dataset. Exiting...")
return return
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...") logging.info("loading model and peft_config...")
model, peft_config = load_model(cfg, tokenizer) model, peft_config = load_model(
cfg.base_model,
safe_serialization = cfg.save_safetensors is True cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model") logging.info("running merge of LoRA with base model")
model = model.merge_and_unload() model = model.merge_and_unload()
model.to(dtype=torch.float16) model.to(dtype=torch.float16)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info("saving merged model") logging.info("saving merged model")
model.save_pretrained( model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if cfg.inference: if cfg.inference:
LOG.info("calling do_inference function") logging.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter" prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs: if "prompter" in kwargs:
if kwargs["prompter"] == "None": if kwargs["prompter"] == "None":
@@ -262,22 +279,20 @@ def train(
return return
if "shard" in kwargs: if "shard" in kwargs:
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
return return
trainer = setup_trainer( trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32": if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model") logging.info("Compiling torch model")
model = torch.compile(model) model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir) peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
@@ -286,16 +301,16 @@ def train(
def terminate_handler(_, __, model): def terminate_handler(_, __, model):
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
sys.exit(0) sys.exit(0)
signal.signal( signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
) )
LOG.info("Starting trainer...") logging.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") logging.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [ possible_checkpoints = [
@@ -307,13 +322,12 @@ def train(
key=lambda path: int(path.split("-")[-1]), key=lambda path: int(path.split("-")[-1]),
) )
resume_from_checkpoint = sorted_paths[-1] resume_from_checkpoint = sorted_paths[-1]
LOG.info( logging.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
) )
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -322,16 +336,16 @@ def train(
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.local_rank == 0:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__": if __name__ == "__main__":

19
scripts/runpod-entrypoint.sh Executable file → Normal file
View File

@@ -1,21 +1,10 @@
#!/bin/bash #!/bin/bash
# Export specific ENV variables to /etc/rp_environment echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
echo "Exporting environment variables..." chmod 700 -R ~/.ssh
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
if [[ $PUBLIC_KEY ]] # Start the SSH service in the background
then service ssh start
mkdir -p ~/.ssh
chmod 700 ~/.ssh
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Start the SSH service in the background
service ssh start
else
echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon"
fi
# Execute the passed arguments (CMD) # Execute the passed arguments (CMD)
exec "$@" exec "$@"

View File

@@ -1,13 +1,12 @@
"""Module containing Dataset functionality""" """Module containing Dataset functionality"""
import logging import logging
import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded # We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example # lets use the concept of middlewares to wrap each dataset, for example
@@ -15,12 +14,10 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use # let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets # the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset):
class TokenizedPromptDataset(Dataset):
""" """
Dataset that returns tokenized prompts from a stream of text files. Iterable dataset that returns tokenized prompts from a stream of text files.
Args: Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files. dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +27,22 @@ class TokenizedPromptDataset(Dataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
**kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
super().__init__(self.process(dataset).data, **kwargs) self.dataset = dataset
def process(self, dataset): def __iter__(self):
features = dataset.features.keys() iterator = iter(self.dataset)
num_proc = min(64, os.cpu_count()) count = 0
return dataset.map( # Loop through the entire dataset
self.prompt_tokenizer.tokenize_prompt, for example in iterator:
num_proc=num_proc, try:
remove_columns=features, yield self.prompt_tokenizer.tokenize_prompt(example)
) count += 1
except InvalidDataException:
pass
if count == 0:
raise RuntimeError("Expected at least one datapoint in dataset.")
# TODO this isn't the best since it can't interleave datasets # TODO this isn't the best since it can't interleave datasets
@@ -76,21 +76,14 @@ class ConstantLengthDataset(IterableDataset):
self.tokens_dtype = torch.int64 self.tokens_dtype = torch.int64
def __iter__(self): def __iter__(self):
buffer = { buffer = {"input_ids": [], "attention_mask": [], "labels": []}
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0 buffer_len = 0
for dataset in self.datasets: for dataset in self.datasets:
idx = 0
iterator = iter(dataset) iterator = iter(dataset)
more_examples = True more_examples = True
while more_examples: while more_examples:
try: try:
example = next(iterator) example = next(iterator)
idx += 1
except StopIteration: except StopIteration:
more_examples = False more_examples = False
example = None example = None
@@ -112,9 +105,6 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length : self.seq_length
] ]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and ( if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size() attention_mask.size() == input_ids.size()
@@ -123,20 +113,17 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"position_ids": position_ids,
} }
else: else:
LOG.warning( logging.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
) )
buffer = { buffer = {
"input_ids": [], "input_ids": [],
"attention_mask": [], "attention_mask": [],
"labels": [], "labels": [],
"position_ids": [],
} }
buffer_len = 0 buffer_len = 0
idx = 1
if example: if example:
# FIXME # FIXME
@@ -145,6 +132,11 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"] input_ids = example["input_ids"]
attention_mask = example["attention_mask"] attention_mask = example["attention_mask"]
labels = example["labels"] labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token: if add_concat_token:
input_ids.append(self.concat_token_id) input_ids.append(self.concat_token_id)
@@ -155,17 +147,13 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype input_ids, dtype=self.tokens_dtype
) )
attention_mask_with_concat = torch.tensor( attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16 attention_mask, dtype=self.tokens_dtype
) )
labels_with_concat = torch.tensor( labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype labels, dtype=self.tokens_dtype
) )
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat) buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat) buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat) buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids) buffer_len += len(input_ids)

View File

@@ -8,18 +8,9 @@ import torch
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
def forward( def forward(
self, self,
@@ -88,17 +79,7 @@ def forward(
dtype=torch.int32, dtype=torch.int32,
device=qkv.device, device=qkv.device,
) )
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif attention_mask.shape[0] == 1:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -114,7 +95,7 @@ def forward(
three=3, three=3,
h=nheads, h=nheads,
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, x_unpad,
cu_q_lens, cu_q_lens,
max_s, max_s,
@@ -132,7 +113,6 @@ def forward(
"b s (h d) -> b s h d", "b s (h d) -> b s h d",
h=nheads, h=nheads,
) )
return ( return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")), self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None, None,

View File

@@ -1,33 +0,0 @@
"""Logging configuration settings"""
import os
import sys
from logging.config import dictConfig
from typing import Any, Dict
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
},
}
def configure_logging():
"""Configure with default logging"""
dictConfig(DEFAULT_LOGGING_CONFIG)

View File

@@ -7,7 +7,6 @@ import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
from torch import nn from torch import nn
@@ -39,48 +38,21 @@ def xformers_forward(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"): query_states = (
self.pretraining_tp = 1 self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
if self.pretraining_tp > 1: .transpose(1, 2)
key_value_slicing = ( )
self.num_key_value_heads * self.head_dim key_states = (
) // self.pretraining_tp self.k_proj(hidden_states)
query_slices = self.q_proj.weight.split( .view(bsz, q_len, self.num_heads, self.head_dim)
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 .transpose(1, 2)
) )
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_states = (
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
query_states = [ .transpose(1, 2)
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) )
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
@@ -101,14 +73,6 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = transformers.models.llama.modeling_llama.repeat_kv(
key_states, self.num_key_value_groups
)
value_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states, self.num_key_value_groups
)
# We only apply xformers optimizations if we don't need to output the whole attention matrix # We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
@@ -128,7 +92,6 @@ def xformers_forward(
query_states, query_states,
key_states, key_states,
value_states, value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(), attn_bias=xformers.ops.LowerTriangularMask(),
) )
attn_weights = None attn_weights = None
@@ -165,23 +128,10 @@ def xformers_forward(
f" {attn_output.size()}" f" {attn_output.size()}"
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2)
# end x-formers vs. not x-formers if-else block
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
@@ -234,15 +184,14 @@ def sdp_attention_forward(
# We only apply sdp attention if we don't need to output the whole attention matrix # We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
with torch.backends.cuda.sdp_kernel(): attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output = torch.nn.functional.scaled_dot_product_attention( query_states,
query_states, key_states,
key_states, value_states,
value_states, attn_mask=attention_mask,
attn_mask=attention_mask, is_causal=False,
is_causal=False, )
) attn_weights = None
attn_weights = None
else: else:
attn_weights = torch.matmul( attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3) query_states, key_states.transpose(2, 3)

View File

@@ -1,52 +0,0 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
_expand_mask
)

View File

@@ -53,7 +53,7 @@ from transformers.utils import (
replace_return_docstrings, replace_return_docstrings,
) )
LOG = logging.getLogger("axolotl") logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig" _CONFIG_FOR_DOC = "LlamaConfig"
@@ -862,7 +862,7 @@ class LlamaModel(LlamaPreTrainedModel):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
LOG.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False

View File

@@ -1,103 +0,0 @@
"""
Shared utils for the monkeypatches
"""
import torch
def get_cu_seqlens(attn_mask):
"""generate a cumulative sequence length mask for flash attention using attn mask"""
if len(attn_mask.shape) == 1:
attn_mask = attn_mask.unsqueeze(0)
device = attn_mask.device
results = []
max_seq_lens = []
for row in attn_mask:
# Exclude zeros to avoid adding their positions to the mask
t_non_zeros = row[row != 0]
# Find where the sequence number changes (including the first position)
seq_change = torch.cat(
[
torch.tensor([1], dtype=torch.int32, device=device),
t_non_zeros[1:] != t_non_zeros[:-1],
]
)
# Get the indices where the sequence changes
change_indices = torch.cat(
[
(seq_change == 1).nonzero(as_tuple=True)[0],
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = len(row) - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
max_seq_lens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
(seq_starts).nonzero(as_tuple=True)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
# Append the padding length to the cumulative sequence lengths
if padding_length:
cu_seqlens = torch.cat(
[cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)

View File

@@ -66,46 +66,15 @@ class SystemDataPrompter(AlpacaPrompter):
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
formatted_sys_prompt = (
self.system_format.format(system=system)
if system and self.system_format
else ""
)
if input: if input:
res = formatted_sys_prompt + self.turn_format.format( res = system + self.turn_format.format(instruction=instruction, input=input)
instruction=instruction, input=input
)
else: else:
res = formatted_sys_prompt + self.turn_no_input_format.format( res = system + self.turn_no_input_format.format(instruction=instruction)
instruction=instruction
)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res yield res
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
"""
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
""" """
Tokenizing strategy for OpenOrca datasets Tokenizing strategy for OpenOrca datasets
@@ -144,16 +113,7 @@ def load_chat(tokenizer, cfg):
def load_open_orca(tokenizer, cfg): def load_open_orca(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy( return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value), SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca_chatml(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -1,205 +0,0 @@
"""
Prompt Strategy for finetuning Llama2 chat models
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation.
This implementation is based on the Vicuna PR and the fastchat repo, see also:
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
Use dataset type: "llama2_chat" in conig.yml to use this prompt style.
E.g. in the config.yml:
```
datasets:
- path: llama_finetune_train.jsonl
type: llama2_chat
```
The dataset itself should look like this:
```
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
```
in a jsonl file. The first message should be from the human, the second from gpt.
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""
import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
@dataclass
class Llama2ChatConversation:
"""A class that manages prompt templates and keeps all conversation history.
copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""
name: str = "llama2"
# The system prompt
system: str = (
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
)
roles: Sequence[str] = ("[INST]", "[/INST]")
messages: List[List[str]] = field(default_factory=list)
offset: int = 0
sep = " "
sep2 = " </s><s>"
stop_token_ids = [2]
def get_prompt(self) -> str:
"""Get the prompt for generation."""
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if (i == len(self.messages) - 1) and (role == self.roles[0]):
# last message is from user (due to length),
# return prompt without it for training
return ret
if i == 0:
ret += self.system + message.strip()
else:
ret += role + " " + message.strip() + seps[i % 2]
return ret
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sequence_len = 4096
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt):
conv = next(self.prompter.build_prompt(prompt))
conversation_str = conv.get_prompt()
# Tokenize conversations
input_ids = self.tokenizer(
conversation_str,
return_tensors="pt",
padding="max_length",
max_length=self.sequence_len,
truncation=True,
).input_ids[0]
target = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.roles[1]
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
turns = conversation_str.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for turn in turns:
if turn == "":
break
turn_len = len(self.tokenizer(turn).input_ids)
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1
# Ignore the user instructions
target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len + 2 # due to length of role token
target[cur_len:] = IGNORE_TOKEN_ID
if cur_len < self.sequence_len:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
logging.warning(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
input_ids = input_ids.tolist()
target = target.tolist()
# this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
# follows the original llama implementation
for i in range(2, total_len - 2):
if input_ids[i] == 29961:
input_ids[i] = 518
if target[i] == 29961:
target[i] = 518
return {
"input_ids": input_ids,
"labels": target,
"attention_mask": attention_mask,
}
class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for Llama2 models.
"""
system_prompt = (
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
)
def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
# see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
source = source["conversations"] # fix data structure for datasets
# if system prompt provided, use it
if source[0]["from"] == "system":
system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
source = source[1:]
else:
system = self.system_prompt
conv = Llama2ChatConversation(system=system)
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv
def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
return LLama2ChatTokenizingStrategy(
Llama2ChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -1,46 +0,0 @@
"""
Prompt Strategy for finetuning Orca Mini (v2) models
see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
Use dataset type: orcamini in conig.yml to use this prompt style.
Compared to the alpaca_w_system.open_orca dataset type,
this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments.
"""
from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
class OrcaMiniPrompter(AlpacaPrompter):
"""Adjusted Prompter for Orca Mini (v2) datasets"""
def match_prompt_style(self):
self.turn_no_input_format = (
"### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
)
def build_prompt_w_system(
self,
system: str,
instruction: str,
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
res = self.turn_no_input_format.format(system=system, instruction=instruction)
if output:
res = f"{res}{output}"
yield res
def load(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OrcaMiniPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -11,8 +11,6 @@ from axolotl.prompt_tokenizers import (
tokenize_prompt_default, tokenize_prompt_default,
) )
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -66,7 +64,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
*copy.deepcopy(res["input_ids"]) *copy.deepcopy(res["input_ids"])
][len(self.bot_prefix_token_ids) :] ][len(self.bot_prefix_token_ids) :]
else: else:
LOG.warning(f"unknown role in conversation: {role}") logging.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: []) res = defaultdict(lambda: [])
# pylint: disable=duplicate-code # pylint: disable=duplicate-code

View File

@@ -10,8 +10,6 @@ from transformers import PreTrainedTokenizer
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
IGNORE_INDEX = -100 IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
@@ -48,22 +46,16 @@ class PromptTokenizingStrategy(abc.ABC):
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_user_token(self): def _get_user_token(self):
try: id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") if isinstance(id_or_ids, (int,)):
if isinstance(id_or_ids, (int,)): return id_or_ids
return id_or_ids
except KeyError:
pass
return False return False
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_assistant_token(self): def _get_assistant_token(self):
try: id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") if isinstance(id_or_ids, (int,)):
if isinstance(id_or_ids, (int,)): return id_or_ids
return id_or_ids
except KeyError:
pass
return False return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
@@ -392,7 +384,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else: else:
LOG.warning(f"unhandled role: {part[0]}") logging.warning(f"unhandled role: {part[0]}")
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(

View File

@@ -5,7 +5,6 @@ import logging
from enum import Enum, auto from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union from typing import Generator, List, Optional, Tuple, Union
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -16,7 +15,6 @@ class PromptStyle(Enum):
INSTRUCT = "instruct" INSTRUCT = "instruct"
CHAT = "chat" CHAT = "chat"
CHATML = "chatml"
class AlpacaPrompter: class AlpacaPrompter:
@@ -26,7 +24,6 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str
turn_format: str turn_format: str
turn_no_input_format: str turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None prompt_style: Optional[PromptStyle] = None
@@ -36,23 +33,14 @@ class AlpacaPrompter:
self.match_prompt_style() self.match_prompt_style()
def match_prompt_style(self): def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value: if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = ( self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n" "### Instruction:\n{instruction}\n\n### Response:\n"
) )
self.system_format = "### System:\n{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def build_prompt( def build_prompt(
self, self,
@@ -253,7 +241,7 @@ class Conversation:
if message: if message:
yield (role + ":", " " + message) yield (role + ":", " " + message)
else: else:
LOG.warning(f"role with empty message: {role}") logging.warning(f"role with empty message: {role}")
yield (role + ":", "") yield (role + ":", "")
def copy(self): def copy(self):
@@ -271,11 +259,6 @@ class Conversation:
self.messages.append([role, message]) self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
@@ -312,9 +295,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
raise IndexError( raise IndexError
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = self._conversation.copy() conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -334,7 +315,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_prompt():

View File

@@ -1,43 +0,0 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -1,6 +1,5 @@
"""Callbacks for Trainer class""" """Callbacks for Trainer class"""
import logging
import os import os
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
@@ -12,10 +11,6 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl.callbacks")
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter""" """Callback to save the PEFT adapter"""
@@ -72,25 +67,3 @@ class SaveBetterTransformerModelCallback(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False control.should_save = False
return control return control
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control

View File

@@ -1,121 +0,0 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
@dataclass
class DataCollatorForSeq2Seq:
"""
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
prepare the *decoder_input_ids*
This is useful when using *label_smoothing* to avoid calculating loss twice.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (`int`, *optional*, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
]:
feat = (
[feature[feature_name] for feature in features]
if feature_name in features[0].keys()
else None
)
labels = feat if feat and feature_name == "labels" else labels
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
# same length to return tensors.
if feat is not None:
max_feature_length = max(len(l) for l in feat) # noqa: E741
if self.pad_to_multiple_of is not None:
max_feature_length = (
(max_feature_length + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
* self.pad_to_multiple_of
)
padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [pad_token_id] * (
max_feature_length - len(feature[feature_name])
)
if isinstance(feature[feature_name], list):
feature[feature_name] = (
feature[feature_name] + remainder
if padding_side == "right"
else remainder + feature[feature_name]
)
elif padding_side == "right":
feature[feature_name] = np.concatenate(
[feature[feature_name], remainder]
).astype(np.int64)
else:
feature[feature_name] = np.concatenate(
[remainder, feature[feature_name]]
).astype(np.int64)
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
# prepare decoder_input_ids
if (
labels is not None
and self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=features["labels"]
)
features["decoder_input_ids"] = decoder_input_ids
return features

View File

@@ -1,199 +0,0 @@
"""Module for working with config dicts"""
import logging
import os
import torch
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
# in `accelerate launch`, we need to not pass through any device map and let
# accelerate figure out which parts of the model to put on which gpu
accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
if accelerate_vars:
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
)
if cfg.max_packed_sequence_len:
LOG.warning(
str(
PendingDeprecationWarning(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
)
)
)
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
if cfg.adapter == "qlora":
if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if cfg.gptq:
raise ValueError("Can't merge qlora if gptq")
if cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if cfg.gptq:
raise ValueError("Can't load qlora if gptq")
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
LOG.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
if cfg.push_to_hub_model_id:
raise ValueError(
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -1,19 +1,12 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools import functools
import hashlib
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import List, Tuple, Union
import torch import torch
from datasets import ( from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
Dataset,
DatasetDict,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -41,13 +34,10 @@ from axolotl.prompters import (
ShareGPTPrompter, ShareGPTPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl")
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path split, tokenizer, cfg, default_dataset_prepared_path
) -> DatasetDict: ) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
@@ -59,6 +49,8 @@ def load_tokenized_prepared_datasets(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
) )
+ "|" + "|"
+ split
+ "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
@@ -76,24 +68,24 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
if dataset: if dataset:
... ...
elif any(prepared_ds_path.glob("*")): elif any(prepared_ds_path.glob("*")):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path)) dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...") logging.info("Prepared dataset loaded from disk...")
else: else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...") logging.info("Loading raw datasets...")
if cfg.seed: if cfg.seed:
seed = cfg.seed seed = cfg.seed
else: else:
LOG.info("No seed provided, using default seed of 42") logging.info("No seed provided, using default seed of 42")
seed = 42 seed = 42
datasets = [] datasets = []
@@ -104,7 +96,6 @@ def load_tokenized_prepared_datasets(
try: try:
load_dataset( load_dataset(
d.path, d.path,
name=d.name,
streaming=True, streaming=True,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
@@ -113,52 +104,40 @@ def load_tokenized_prepared_datasets(
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(d.path) if Path(d.path).exists():
if local_path.exists(): ds = load_dataset(
if local_path.is_dir(): "json",
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk` data_files=d.path,
streaming=False,
split=None,
)
elif ds_from_hub:
if d.data_files:
ds = load_dataset( ds = load_dataset(
d.path, d.path,
name=d.name, streaming=False,
data_files=d.data_files, data_files=d.data_files,
streaming=False, use_auth_token=use_auth_token,
split=None,
)
elif local_path.is_file():
ds = load_dataset(
"json",
name=d.name,
data_files=d.path,
streaming=False,
split=None,
) )
else: else:
raise ValueError( ds = load_dataset(
"unhandled dataset load: local path exists, but is neither a directory or a file" d.path,
streaming=False,
use_auth_token=use_auth_token,
) )
elif ds_from_hub:
ds = load_dataset(
d.path,
name=d.name,
streaming=False,
data_files=d.data_files,
use_auth_token=use_auth_token,
)
else: else:
fp = hf_hub_download( fp = hf_hub_download(
repo_id=d.path, repo_id=d.path,
repo_type="dataset", repo_type="dataset",
filename=d.data_files, filename=d.data_files,
) )
ds = load_dataset( ds = load_dataset("json", data_files=fp, streaming=False, split=None)
"json", name=d.name, data_files=fp, streaming=False, split=None
)
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
if "train" in ds: if split in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)[split].shard(
num_shards=d.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
@@ -167,8 +146,8 @@ def load_tokenized_prepared_datasets(
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds: if split in ds:
ds = ds["train"] ds = ds[split]
if ds_strategy := load(d.type, tokenizer, cfg): if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
@@ -266,21 +245,25 @@ def load_tokenized_prepared_datasets(
suffix = "" suffix = ""
if ":load_" in d.type: if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?" suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}") logging.error(
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
)
raise ValueError( raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}" f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
LOG.info("merging datasets") logging.info("tokenizing, merging, and shuffling master dataset")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1: samples: List[int] = []
LOG.info("shuffle merged datasets") for d in datasets:
dataset = dataset.shuffle(seed=seed) samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") logging.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( logging.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub( dataset.push_to_hub(
@@ -331,68 +314,90 @@ def load_prepare_datasets(
use_auth_token = cfg.hf_use_auth_token use_auth_token = cfg.hf_use_auth_token
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( logging.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset = load_dataset( dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
if dataset: if dataset:
... ...
elif any(prepared_ds_path.glob("*")): elif any(prepared_ds_path.glob("*")):
LOG.info( logging.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..." f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
) )
dataset = load_from_disk(str(prepared_ds_path)) dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared packed dataset loaded from disk...") logging.info("Prepared packed dataset loaded from disk...")
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( logging.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub( dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
else: else:
dataset = load_tokenized_prepared_datasets( dataset_train = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path "train", tokenizer, cfg, default_dataset_prepared_path
) )
dataset_test = load_tokenized_prepared_datasets(
"test", tokenizer, cfg, default_dataset_prepared_path
)
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.seed: if cfg.seed:
dataset = dataset.shuffle(seed=cfg.seed) dataset = dataset.shuffle(seed=cfg.seed)
constant_len_dataset = ConstantLengthDataset( constant_len_dataset_train = ConstantLengthDataset(
tokenizer, tokenizer,
[dataset], [dataset["train"]],
seq_length=max_packed_sequence_len, seq_length=max_packed_sequence_len,
) )
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") constant_len_dataset_test = ConstantLengthDataset(
dataset = Dataset.from_list(list(constant_len_dataset)) tokenizer,
[dataset["test"]],
seq_length=max_packed_sequence_len,
)
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset_train = Dataset.from_list(list(constant_len_dataset_train))
dataset_test = Dataset.from_list(list(constant_len_dataset_test))
# filter out bad data # filter out bad data
# TODO convert to dataset.filter(...) dataset_train = Dataset.from_list(
dataset = Dataset.from_list(
[ [
d d
for d in dataset for d in dataset_train
if len(d["input_ids"]) <= cfg.sequence_len if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0 and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"]) and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"]) and len(d["input_ids"]) == len(d["labels"])
] ]
) )
# filter out bad data
dataset_test = Dataset.from_list(
[
d
for d in dataset_test
if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info( logging.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}" f"Saving packed prepared dataset to disk... {prepared_ds_path}"
) )
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( logging.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub( dataset.push_to_hub(
@@ -400,12 +405,17 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
# dataset_train = load_tokenized_prepared_datasets(
dataset = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path "train", tokenizer, cfg, default_dataset_prepared_path
) )
# dataset_test = load_tokenized_prepared_datasets(
# "test", tokenizer, cfg, default_dataset_prepared_path
# )
# dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
LOG.info( logging.info(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
) )
dataset = dataset.shard( dataset = dataset.shard(
@@ -414,51 +424,10 @@ def load_prepare_datasets(
) )
if cfg.val_set_size: if cfg.val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
to_hash_train = ( train_dataset = dataset["train"]
dataset._fingerprint # pylint: disable=protected-access eval_dataset = dataset["test"]
+ "|" elif "train" in dataset:
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = hashlib.md5(
to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
if is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
if not is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]
else: else:
@@ -570,7 +539,7 @@ def encode_pretraining(tokenizer, max_tokens, examples):
"attention_mask": [seq.tolist() for seq in new_attention_mask], "attention_mask": [seq.tolist() for seq in new_attention_mask],
} }
LOG.debug(len(ret["input_ids"])) logging.debug(len(ret["input_ids"]))
return ret return ret

View File

@@ -1,288 +0,0 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n):
"""
Chunk data into tuples of length n
"""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(itertools.islice(it, n)):
yield batch
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_replicas = 1
self.rank = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
indices = [idx for idx in self.sampler]
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.seq_max_length
// self.batch_size
)
- 1
)
def __len__(self):
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -10,6 +10,3 @@ class DictDefault(Dict):
def __missing__(self, key): def __missing__(self, key):
return None return None
def __or__(self, other):
return DictDefault(super().__or__(other))

View File

@@ -1,41 +0,0 @@
"""
utility helpers for distributed checks
"""
import torch.distributed as dist
from accelerate import Accelerator
accelerate = None # pylint: disable=invalid-name
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed():
"""
Check if distributed training is initialized.
"""
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
def barrier():
"""
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
if is_distributed():
dist.barrier()
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
if not is_distributed():
return True
return dist.get_rank() == 0

View File

@@ -22,9 +22,6 @@ from transformers import ( # noqa: F401
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
LOG = logging.getLogger("axolotl")
if TYPE_CHECKING: if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401 from peft import PeftConfig # noqa: F401
@@ -32,66 +29,31 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401 from axolotl.utils.dict import DictDefault # noqa: F401
def smart_tokenizer_and_embedding_resize( def load_tokenizer(
tokenizer: transformers.PreTrainedTokenizer, tokenizer_config,
model: transformers.PreTrainedModel, tokenizer_type,
resize_token_embeddings_multiple: Optional[int] = None, cfg,
): ):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None: if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None: if tokenizer_type:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
)
tokenizer_cls = AutoTokenizer logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
if cfg.tokenizer_type: logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if tokenizer.__class__.__name__ in [ if tokenizer.__class__.__name__ in [
"LlamaTokenizer", "LlamaTokenizer",
@@ -99,11 +61,6 @@ def load_tokenizer(cfg):
]: ]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -118,44 +75,38 @@ def load_tokenizer(cfg):
def load_model( def load_model(
cfg, tokenizer base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]] ):
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model for a given configuration and tokenizer. Load a model from a base model and a model type.
""" """
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = "llama" in base_model or (
"llama" in base_model cfg.model_type and "llama" in cfg.model_type.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model
) )
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference: if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.flash_attn import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn,
)
LOG.info("patching with flash attention") logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn() replace_llama_attn_with_flash_attn()
elif cfg.is_llama_derived_model and cfg.xformers_attention: elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention, hijack_llama_attention,
) )
LOG.info("patching with xformers attention") logging.info("patching with xformers attention")
hijack_llama_attention() hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention: elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention, hijack_llama_sdp_attention,
) )
LOG.info("patching with sdp attention") logging.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention: elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import ( from axolotl.monkeypatch.llama_landmark_attn import (
@@ -163,7 +114,7 @@ def load_model(
patch_llama_with_landmark_attn, patch_llama_with_landmark_attn,
) )
LOG.info("patching with landmark attention") logging.info("patching with landmark attention")
patch_llama_with_landmark_attn() patch_llama_with_landmark_attn()
# Note: This might overwrite previous additional_special_tokens # Note: This might overwrite previous additional_special_tokens
@@ -174,17 +125,9 @@ def load_model(
replace_llama_rope_with_xpos_rope, replace_llama_rope_with_xpos_rope,
) )
LOG.info("patching with xpos rope") logging.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope() replace_llama_rope_with_xpos_rope()
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
if cfg.bf16 or cfg.bfloat16: if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -199,24 +142,18 @@ def load_model(
replace_peft_model_with_int4_lora_model() replace_peft_model_with_int4_lora_model()
except Exception as err: except Exception as err:
LOG.exception(err) logging.exception(err)
raise err raise err
if not cfg.gptq and ( try:
(cfg.adapter == "lora" and load_in_8bit) from peft import prepare_model_for_kbit_training
or (cfg.adapter == "qlora" and cfg.load_in_4bit) except ImportError:
): # For backward compatibility
try: from peft import (
from peft import prepare_model_for_kbit_training prepare_model_for_int8_training as prepare_model_for_kbit_training,
except ImportError: )
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {} model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.adapter == "qlora" and cfg.load_in_4bit: if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True, load_in_4bit=True,
@@ -248,7 +185,7 @@ def load_model(
if len(files) > 0: if len(files) > 0:
model_path = str(files[0]) model_path = str(files[0])
else: else:
LOG.warning( logging.warning(
"unable to find a cached model file, this will likely fail..." "unable to find a cached model file, this will likely fail..."
) )
model_path = str(cache_model_path) model_path = str(cache_model_path)
@@ -265,23 +202,17 @@ def load_model(
else True, else True,
) )
load_in_8bit = False load_in_8bit = False
elif cfg.is_llama_derived_model and not cfg.trust_remote_code: elif cfg.is_llama_derived_model:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {} config = LlamaConfig.from_pretrained(base_model_config)
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
**model_kwargs, **model_kwargs,
) )
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -310,13 +241,13 @@ def load_model(
# device=cfg.device, # device=cfg.device,
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type and not cfg.trust_remote_code: elif model_type:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -333,85 +264,69 @@ def load_model(
and cfg.sequence_len > config.max_seq_len and cfg.sequence_len > config.max_seq_len
): ):
config.max_seq_len = cfg.sequence_len config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") logging.warning(f"increasing context length to {cfg.sequence_len}")
elif ( elif (
hasattr(config, "max_sequence_length") hasattr(config, "max_sequence_length")
and config.max_sequence_length and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length and cfg.sequence_len > config.max_sequence_length
): ):
config.max_sequence_length = cfg.sequence_len config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
LOG.error( logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM" "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
) )
LOG.exception(err) logging.exception(err)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
smart_tokenizer_and_embedding_resize( embeddings_len = math.ceil(len(tokenizer) / 32) * 32
tokenizer, model.resize_token_embeddings(embeddings_len)
model,
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
)
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings and cfg.sequence_len >= model.config.max_position_embeddings
): ):
LOG.warning( logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing model, use_gradient_checkpointing=cfg.gradient_checkpointing
) )
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to model, lora_config = load_adapter(model, cfg, adapter)
# convert them back to fp16/bf16 for flash-attn compatibility.
if cfg.flash_attention and cfg.is_llama_derived_model:
for name, module in model.named_modules():
if "norm" in name:
module.to(torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
if cfg.gptq: if cfg.gptq:
# Scales to half # Scales to half
LOG.info("Fitting 4bit scales and zeros to half") logging.info("Fitting 4bit scales and zeros to half")
for _, module in model.named_modules(): for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str( if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(module) type(module)
@@ -437,15 +352,12 @@ def load_model(
if param.requires_grad: if param.requires_grad:
requires_grad.append(f"{name}: {param.requires_grad}") requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0: if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates") logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False model.config.use_cache = False
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.transform(model) model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config
@@ -455,8 +367,6 @@ def load_adapter(model, cfg, adapter):
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg) return load_lora(model, cfg)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -476,7 +386,7 @@ def load_llama_adapter(model, cfg):
) )
if cfg.lora_model_dir: if cfg.lora_model_dir:
LOG.info("Loading pretained LORA") logging.info("Loading pretained LORA")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
@@ -523,7 +433,7 @@ def load_lora(model, cfg):
bits = 8 bits = 8
linear_names = find_all_linear_names(bits, model) linear_names = find_all_linear_names(bits, model)
LOG.info(f"found linear modules: {repr(linear_names)}") logging.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names)) lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config = LoraConfig( lora_config = LoraConfig(

View File

@@ -1,9 +1,6 @@
"""Module for custom LRScheduler class""" """Module for custom LRScheduler class"""
import math
from functools import partial
from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class InterpolatingLogScheduler(LRScheduler): class InterpolatingLogScheduler(LRScheduler):
@@ -45,58 +42,3 @@ class InterpolatingLogScheduler(LRScheduler):
lrs = [self.max_lr for base_lr in self.base_lrs] lrs = [self.max_lr for base_lr in self.base_lrs]
return lrs return lrs
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
)
def get_cosine_schedule_with_quadratic_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_quadratic_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

View File

@@ -5,8 +5,6 @@ import logging
from termcolor import colored from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer): def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements # the dataset is already shuffled, so let's just check the first 5 elements
@@ -34,7 +32,7 @@ def check_example_labels(example, tokenizer):
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)
LOG.info(" ".join(colored_tokens)) logging.info(" ".join(colored_tokens))
LOG.info("\n\n\n") logging.info("\n\n\n")
return " ".join(colored_tokens) return " ".join(colored_tokens)

View File

@@ -1,226 +1,29 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging import logging
import math import math
import os import os
import sys import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional
import bitsandbytes as bnb import bitsandbytes as bnb
import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from transformers import EarlyStoppingCallback, Trainer
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
GPUStatsCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
LOG = logging.getLogger("axolotl")
@torch.jit.script class OneCycleLRSchedulerTrainer(Trainer):
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
# Flatten the logits, labels, and weights tensors
logits = logits.view(
-1, logits.size(-1)
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
# Compute the unweighted cross entropy loss
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
# Apply the weights to the losses and compute their sum
return (weights * losses).sum()
@torch.jit.script
def create_weighted_mask(labels: torch.Tensor):
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
if len(labels.shape) == 1:
labels = labels.unsqueeze(0)
weights = torch.zeros_like(labels).float()
for i in range(labels.shape[0]):
mask = labels[i] != -100
# Create a tensor to track group ids
group_ids = torch.zeros_like(labels[i]).int()
curr_group_id = 0
for j in range(1, len(labels[i])):
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[j] = (
curr_group_id if mask[j] else 0
) # assign group id if unmasked label
# Count only unmasked labels in each group
group_counts = torch.bincount(group_ids[mask])
mask_weights = torch.zeros_like(labels[i]).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
weights[i] = mask_weights
return weights.squeeze() # squeeze the output to match the input dimension
def trainer_weighted_loss(model_output, labels, shift_labels=True):
logits = (
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
)
if shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
weights = create_weighted_mask(labels)
return weighted_cross_entropy(logits, labels, weights)
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
""" """
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
""" """
@@ -249,121 +52,10 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
return self.lr_scheduler return self.lr_scheduler
def add_position_ids(sample): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
sample["position_ids"] = torch.arange(len(sample["input_ids"])) total_num_steps = int(
return sample math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.sample_packing:
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
(
math.floor(
0.99
* cfg.total_num_tokens
/ cfg.sample_packing_eff_est
/ cfg.sequence_len
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)
* cfg.num_epochs
)
LOG.info(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding="longest",
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.floor(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
// cfg.batch_size
)
)
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
LOG.info(f"total_num_steps: {total_num_steps}")
return total_num_steps
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
warmup_steps = ( warmup_steps = (
cfg.warmup_steps cfg.warmup_steps
if cfg.warmup_steps is not None if cfg.warmup_steps is not None
@@ -411,9 +103,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.fsdp_config: if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
# deepspeed # deepspeed
if ( if (
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -435,34 +124,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.max_grad_norm: if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id: if cfg.push_to_hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id training_arguments_kwargs["push_to_hub_model_id"] = cfg.push_to_hub_model_id
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy: training_args = transformers.TrainingArguments(
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
evaluation_strategy = "no"
elif cfg.eval_steps < 1:
# eval every epoch
evaluation_strategy = "epoch"
else:
# eval every eval_steps steps
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None if cfg.eval_batch_size is not None
@@ -471,12 +137,12 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs, num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate, learning_rate=cfg.learning_rate,
evaluation_strategy=evaluation_strategy, evaluation_strategy="steps",
save_strategy="steps" if cfg.save_steps else "epoch", save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None, eval_steps=cfg.eval_steps,
save_steps=cfg.save_steps, save_steps=cfg.save_steps,
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4, save_total_limit=3,
load_best_model_at_end=( load_best_model_at_end=(
cfg.load_best_model_at_end is not False cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0 and cfg.val_set_size > 0
@@ -494,8 +160,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine", else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
**training_arguments_kwargs, **training_arguments_kwargs,
) )
@@ -567,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = [] callbacks = []
callbacks.append(GPUStatsCallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(
@@ -590,11 +253,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.collator_pad_to_longest: if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest" data_collator_kwargs["padding"] = "longest"
else: else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check data_collator_kwargs["pad_to_multiple_of"] = 8
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention: if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import ( from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens, add_mem_tokens,
get_mem_id, get_mem_id,
@@ -603,7 +266,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
set_model_mem_id(model, tokenizer) set_model_mem_id(model, tokenizer)
LOG.info("Adding landmark attention tokens to dataset") logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]: for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map( dataset = dataset.map(
@@ -615,14 +278,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else AxolotlTrainer else transformers.Trainer
) )
trainer = trainer_cls( trainer = trainer_cls(
model=model, model=model,
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
args=training_args, args=training_args,
data_collator=DataCollatorForSeq2Seq( data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,

View File

@@ -0,0 +1,105 @@
"""Module for validating config files"""
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
logging.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
if cfg.adapter == "qlora":
if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if cfg.gptq:
raise ValueError("Can't merge qlora if gptq")
if cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if cfg.gptq:
raise ValueError("Can't load qlora if gptq")
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora":
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.trust_remote_code:
logging.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
logging.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
logging.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -9,8 +9,6 @@ def setup_wandb_env_vars(cfg):
elif cfg.wandb_project and len(cfg.wandb_project) > 0: elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:

File diff suppressed because one or more lines are too long

View File

@@ -1,30 +0,0 @@
"""
Unit tests for the monkeypatch utils
"""
import unittest
import torch
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
class TestMonkeyPatchUtils(unittest.TestCase):
"""
Unit test class for monkeypatch utils
"""
def test_get_cu_seqlens_1d(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
def test_get_cu_seqlens_from_pos_ids_1d(self):
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
if __name__ == "__main__":
unittest.main()

View File

@@ -72,13 +72,6 @@ class DictDefaultTest(unittest.TestCase):
assert cfg.random_key is None, "DictDefault should return None for missing keys" assert cfg.random_key is None, "DictDefault should return None for missing keys"
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
def test_dict_nested_missingparentkey(self): def test_dict_nested_missingparentkey(self):
""" """
Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist. Due to subclassing Dict, DictDefault will error if we try to access a nested key whose parent key does not exist.

View File

@@ -1,44 +0,0 @@
"""
Unit tests for the monkey patch for expand mask to handle packed sequences
"""
import unittest
import torch
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
class TestExpandMask(unittest.TestCase):
"""
Test class for attention mask expansion for packed sequences
"""
def test_output(self):
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
dtype = torch.float32
expected_output = torch.tensor(
[
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
]
],
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
]
)
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
if __name__ == "__main__":
unittest.main()

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
} }
) )
def test_increments_attention(self): def test_resets_attention(self):
prompter = AlpacaPrompter("chat") prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy( strat = AlpacaPromptTokenizingStrategy(
prompter, prompter,
@@ -55,14 +55,10 @@ class TestPacking(unittest.TestCase):
# first example doesn't have mask reset # first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1 assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does # but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2 assert example["attention_mask"][next_bos_index] == 0
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -4,24 +4,20 @@ import logging
import unittest import unittest
from pathlib import Path from pathlib import Path
from transformers import AutoTokenizer, LlamaTokenizer from transformers import AutoTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import ( from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy, InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter, SystemDataPrompter,
) )
from axolotl.prompt_strategies.llama2_chat import (
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy,
) )
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
LOG = logging.getLogger("axolotl") logging.basicConfig(level="INFO")
class TestPromptTokenizationStrategies(unittest.TestCase): class TestPromptTokenizationStrategies(unittest.TestCase):
@@ -134,95 +130,8 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
"output": "Hi! How can I help?", "output": "Hi! How can I help?",
} }
example = strat.tokenize_prompt(sample) example = strat.tokenize_prompt(sample)
assert example["input_ids"][0:5] == [ assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
1, assert example["input_ids"][3] == 11889 # USER
28962,
1254,
12665,
29901,
] # "<s>SYSTEM:"
assert example["input_ids"][5:7] == [671, 20118] # " use cot"
assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest(unittest.TestCase):
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# woraround because official Meta repos are not open
def test_llama2_chat_integration(self):
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
data = fin.read()
conversation = json.loads(data)
with open(
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
encoding="utf-8",
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy(
prompter,
self.tokenizer,
False,
4096,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
def compare_with_transformers_integration(self):
# this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation
# from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
# broken as of 23/7/20
# see https://github.com/huggingface/transformers/pull/24935
# pylint: disable=C0103
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
data = fin.read()
conversation = json.loads(data)
with open(
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
encoding="utf-8",
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
user_input = []
answers = []
for msg in conversation["conversations"]:
if msg["from"] == "human":
user_input.append(msg["value"])
else:
answers.append(msg["value"])
hf_conf = Conversation(
text=user_input[-1],
past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]]
+ user_input[1:-1],
generated_responses=answers,
)
# pylint: disable=W0212
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
self.assertEqual(
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
) )
) )
assert "use cot" in res assert "use cot" in res
assert res.startswith("SYSTEM:") assert res.startswith("use cot")
assert "### Instruction:" not in res assert "### Instruction:" not in res
assert "### Input:" not in res assert "### Input:" not in res
assert "alpacas" in res assert "alpacas" in res

View File

@@ -13,22 +13,17 @@ class TestTokenizers(unittest.TestCase):
""" """
def test_default_use_fast(self): def test_default_use_fast(self):
cfg = DictDefault( cfg = DictDefault({})
{ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__ assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self): def test_dont_use_fast(self):
cfg = DictDefault( cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False, "tokenizer_use_fast": False,
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):
@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": None, "optimizer": None,
"adam_epsilon": 0.0001, "adamw_epsilon": 0.0001,
} }
) )
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": "adafactor", "optimizer": "adafactor",
"adam_beta1": 0.0001, "adamw_beta1": 0.0001,
} }
) )
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": "adamw_bnb_8bit", "optimizer": "adamw_bnb_8bit",
"adam_beta1": 0.9, "adamw_beta1": 0.0001,
"adam_beta2": 0.99, "adamw_beta2": 0.0001,
"adam_epsilon": 0.0001, "adamw_epsilon": 0.0001,
} }
) )
@@ -313,27 +313,3 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_packing(self):
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)