Compare commits

..

55 Commits

Author SHA1 Message Date
Wing Lian
b4d84d56d5 support for batched sharegpt tokenization to skip bad data 2023-10-06 15:03:07 -04:00
NanoCode012
669f1d052c Fix: Higher vram usage for mistral and sample_packing (#691)
* Fix: Higher vram usage for mistral and sample_packing

* chore: update comment

* chore: lint
2023-10-06 12:33:43 -04:00
Abhishek Mishra
d4a88e4eca Adding qlora config for Mistral (#675)
* Adding qlora config for Mistral

Contains fix for Mistral FA issue - ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to  call tokenizer.padding_side  = 'left' before tokenizing the input.

Fix for now is to set sample_packing: true and pad_to_sequence_len: true

* Renamed to qlora.yml
2023-10-06 21:05:56 +09:00
Wing Lian
2d60ba3a6e flash_attention + sample packing for stablelm 3b (#671)
* stablelm epoch fa patch

* is causal for fa

* working stablelm fa w packing

* chore: pre-commit linting
2023-10-05 16:03:43 -04:00
NanoCode012
eb480dfd68 Fix: ValueError when FA + Mistral when padding_side=right (#681)
* Fix: ValueError when FA + Mistral when padding_side=right

* fix: remove tokenizer class check
2023-10-06 04:12:54 +09:00
NanoCode012
133e676bcc Feat: Set WORKDIR to /workspace/axolotl (#679) 2023-10-06 04:09:14 +09:00
NanoCode012
69fac9a020 Fix: Future deprecation warning with use_auth_token (#680) 2023-10-06 03:56:18 +09:00
NanoCode012
e0b7eeabfd Fix(tokenizer): Set rstrip,lstrip,norm to False (#678) 2023-10-06 03:50:49 +09:00
NanoCode012
43856c0a39 Fix(version): Update FA to work with Mistral SWA (#673) 2023-10-04 21:32:19 +09:00
NanoCode012
e62d5901b5 chore: Clean up repetitive model kwargs (#670) 2023-10-04 20:41:26 +09:00
NanoCode012
697c50d408 Feat: Allow usage of native Mistral FA when no sample_packing (#669)
* Allow usage of native Mistral FA when no sample_packing

* fix: do not apply custom patch when sample_pack off

* chore: lint

* chore: pin transformer to v4.35.0.dev0

* fix: split sample_packing to separate test
2023-10-04 20:40:47 +09:00
NanoCode012
90e0d673f7 Feat: Add config yaml to section for reprod in bug-report.yaml (#667)
* Update bug-report.yaml

* Update bug-report.yaml

* Update bug-report.yaml
2023-10-03 23:38:42 +09:00
Wing Lian
2642caedf2 refactor to set eval_batch_size earlier if unset, so we can warn if mismatched (#662) 2023-10-02 21:08:07 -04:00
Wing Lian
f34648c8b9 remove patch fix for phi (#664) 2023-10-02 21:07:41 -04:00
Wing Lian
e50a64e85e prepared dataset caching, other misc fixes (#665)
* prepared dataset caching, other misc fixes

* also don't load from disk cache unless explicit
2023-10-02 21:07:24 -04:00
Wing Lian
f4868d733c make sure we also run CI tests when requirements.txt changes (#663) 2023-10-02 08:43:40 -04:00
Napuh
a7e56d83c2 removed duplicate on requirements.txt (#661) 2023-10-02 08:40:05 -04:00
Wing Lian
5b0bc48fbc add mistral e2e tests (#649)
* mistral e2e tests

* make sure to enable flash attention for the e2e tests

* use latest transformers full sha

* uninstall first
2023-09-29 00:22:40 -04:00
Kyle Corbitt
9ec20777ba Make dataset_processes configurable (#651)
I'm using the Axolotl script to train models on https://modal.com serverless GPUs. Unfortunately, their environment seems to have some kind of bug where if I try to run `datasets.filter` with too high a `num_proc`, it throws an error and dies.

This PR adds a new configuration option `dataset_processes`, which lets you explicitly set the number of processes used to map/filter the dataset. If not included, this defaults to the current behavior of setting that to `os.cpu_count()`.
2023-09-29 00:22:22 -04:00
ich
590d6032fd Fix bug when using pretokenized datasets (#652)
* fix pretokenized datasets readme

* check if dataset type is not set to handle pretokenized datasets
2023-09-28 22:54:10 -04:00
Wing Lian
409ca0f21c add support for defined train split (#654) 2023-09-28 20:14:14 -04:00
Wing Lian
8662e8ffe8 don't strip the prompt for check since we don't strip to tokenize anymore (#650) 2023-09-28 12:21:51 -04:00
Wing Lian
b2edaaeff6 fix for flash attn w mistral w/o sammple packing (#648) 2023-09-28 10:57:37 -04:00
Adarsh Shirawalmath
b88f51512a Update mistral/README.md (#647) 2023-09-28 10:24:56 -04:00
NanoCode012
eb41f76f92 Feat: Add example for Mistral (#644)
* Feat: Add example for Mistral

* chore: turn off flash

* chore: add is_mistral_derived_model

* chore: update following PR
2023-09-28 20:15:00 +09:00
NanoCode012
383f88d7a7 Fix(cfg): Add validation for save_strategy and eval_strategy (#633)
* Fix(cfg): Check save_strategy cfg conflict with save_steps

* Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps

* chore: add extra check for steps only
2023-09-28 10:14:41 +09:00
Wing Lian
b6ab8aad62 Mistral flash attn packing (#646)
* add mistral monkeypatch

* add arg for decoder attention masl

* fix lint for duplicate code

* make sure to update transformers too

* tweak install for e2e

* move mistral patch to conditional
2023-09-27 18:41:00 -04:00
Napuh
85b0be2ba7 Warn users to login to HuggingFace (#645)
* added warning if user is not logged in HF

* updated doc to suggest logging in to HF
2023-09-27 17:43:35 -04:00
Ethan Smith
8fe0e633d2 Fix bug in dataset loading (#284)
* Fix bug in dataset loading

This fixes a bug when loading datasets. `d.data_files` is a list, so it cannot be directly passed to `hf_hub_download`

* Check type of data_files, and load accordingly
2023-09-27 13:41:31 -04:00
Felix Yan
d1236f2c41 Correct typos in datasets.py (#639) 2023-09-27 12:12:10 -04:00
Wing Lian
895f0a0723 skip some flash attn patches unless explicitly enabled (#643)
* skip some flash attn patches if explicitly disabled

* make the other patches optional
2023-09-27 12:11:07 -04:00
Wing Lian
e7d3e2dbb6 use fastchat conversations template (#578)
* use fastchat conversations template

* require fastchat (fschat) pip install

* handle roles dynamically from conversation

* tweak fastchat conversation with a monkeypatch to get individual turns

* fix up so it works with multiple conversation styles, and don't strip the turns

* fix sharegpt fixture now that we're using a more correct tokenization

* use a new prompter and support fastchat conversation type

* use sharegpt from prompt strategies now

* update docs, add chatml template

* add a newline after im_end token

* ensure we correctly set system message

* update per PR feedback to handle deprecated sharegpt types

* don't add duplicate wandb req

* make sharegpt fields configurable from yml

* llama2 fixes

* don't fail fatally when turns are improper
2023-09-27 12:10:45 -04:00
Wing Lian
60c7c48c97 update for recent transformers updates (#636)
* update for recent transformers updates

* fix checkpoint forward kwargs

* just pass args into torch checkpoint
2023-09-27 12:10:32 -04:00
Wing Lian
e8cbf50be6 attention_mask not needed for training (#642)
* attention_mask not needed for training

* specifically don't use attention mask for phi

* use a different check for phi

* small fixes since phi removed some values from their config
2023-09-27 11:12:08 -04:00
Wing Lian
d887ad86c3 eval_table isn't quite stable enough to be in default llama configs (#637) 2023-09-26 10:13:20 -04:00
NanoCode012
19a600a8b8 Feat: Add support for upstream FA2 (#626)
* Feat: Add support for upstream FA2

* chore: add is_falcon_derived_model: true to examples

* chore: add config to readme for documentation

* feat: add extra model types

* fix: remove old falcon flash patch

* chore: pin transformers and accelerate
2023-09-26 09:53:28 -04:00
Fernando Tarin Morales
5e5296a77c Added quotes to the pip install -e command to fix an incompatibility with shells that do glob expansion like zsh (#632) 2023-09-25 11:50:14 -04:00
mhenrichsen
f3d939016a Merge pull request #629 from OpenAccess-AI-Collective/chore/-change-default-model
default model changed
2023-09-25 09:32:01 +02:00
NanoCode012
cfbce020e9 Fix: Fail bf16 check when running on cpu during merge (#631) 2023-09-25 13:48:18 +09:00
mhenrichsen
4fecbfe5e1 default model changed 2023-09-24 18:52:53 +02:00
NanoCode012
67b9888630 Feat(doc): Add eval_sample_packing to doc (#625) 2023-09-23 13:11:27 +09:00
Maxime
923eb91304 tweak: improve base builder for smaller layers (#500) 2023-09-22 16:17:50 -04:00
Wing Lian
a363604dcf better handling and logging of empty sharegpt turns (#603) 2023-09-22 16:13:42 -04:00
Wing Lian
501958bb6f create a model card with axolotl badge (#624) 2023-09-22 16:13:26 -04:00
Wing Lian
c25ba7939b update README w deepspeed info (#605) 2023-09-22 00:15:52 -04:00
NanoCode012
d5f8589021 chore(callback): Remove old peft saving code (#510) 2023-09-22 12:31:33 +09:00
Wing Lian
03e59077a0 misc fixes to add gptq tests (#621)
* misc fixes to add gptq tests

* set bf16 needed for fa2
2023-09-21 21:52:12 -04:00
Wing Lian
97d3776ce6 split completion text to sequence_len (#616) 2023-09-21 21:51:25 -04:00
Wing Lian
2844eb22b6 run eval on the first step to get a baseline (#617)
* run eval on the first step to get a baseline

* wandb kleeps getting moved around by pre-commit ...
2023-09-21 21:51:09 -04:00
Wing Lian
e85d2eb06b let MAX_JOBS use the default since we're not resource constrained on our self-hosted runners (#427) 2023-09-21 20:36:30 -04:00
Wing Lian
196ff1181e skip the gpu memory checks if the device is set to 'auto' (#609)
* skip the gpu memory checks if the device is set to 'auto'

* skip gpu mem logging if cpu too

* don't worry about log_gpu_memory_usage since it calls another annotated fn

* rename decorator internal
2023-09-21 15:20:31 -04:00
Wing Lian
92512c390b ignore wandb to resolve isort headaches (#619) 2023-09-21 11:50:09 -04:00
Maxime
2fe95cdcc1 fix distributed devices (#612)
* fix distributed devices

* Update distributed.py

* Update distributed.py
2023-09-21 09:11:34 -04:00
Maxime
c1382e79b6 Create multi-node.md (#613)
* Create multi-node.md

* Update multi-node.md

* Update multi-node.md
2023-09-20 22:02:16 -04:00
Maxime
5d931cc042 Only run tests when a change to python files is made (#614)
* Update tests.yml

* Update .github/workflows/tests.yml

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-09-20 22:02:04 -04:00
80 changed files with 2953 additions and 720 deletions

View File

@@ -53,6 +53,13 @@ body:
validations:
required: true
- type: textarea
id: config
attributes:
label: Config yaml
description: |
Please attach the config yaml!
- type: textarea
id: possible-solution
attributes:

View File

@@ -4,7 +4,13 @@ on:
push:
branches:
- "main"
paths:
- '**.py'
- 'requirements.txt'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
workflow_dispatch:
jobs:
@@ -40,7 +46,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install -e .
pip3 install -U -e .
pip3 install -r requirements-tests.txt
- name: Run tests
@@ -65,8 +71,8 @@ jobs:
- name: Install dependencies
run: |
pip3 install -e .
pip3 install flash-attn
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt
- name: Run e2e tests

4
.gitignore vendored
View File

@@ -161,3 +161,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# WandB
# wandb creates a folder to store logs for training runs
wandb

View File

@@ -1,2 +1,3 @@
[settings]
profile=black
known_third_party=wandb

View File

@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-nested-blocks,

View File

@@ -31,6 +31,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Training w/ Deepspeed](#training-with-deepspeed)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
@@ -86,7 +87,7 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e .[flash-attn]
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora
@@ -121,8 +122,13 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
3. Install axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install -e .[flash-attn]
pip3 install -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens
- LambdaLabs
<details>
@@ -157,7 +163,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
cd axolotl
pip3 install packaging
pip3 install -e .[flash-attn]
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```
@@ -179,7 +185,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
- `sharegpt`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -244,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"article": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_v2`: in context question answering (alternate)
```json
{"context": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json
{"article": "...", "unanswerable_question": "..."}
@@ -268,11 +278,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -307,7 +317,7 @@ Using file:
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
### Config
@@ -350,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
```
- loading
@@ -407,6 +423,11 @@ tokenizer_legacy:
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_mistral_derived_model:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
@@ -438,6 +459,7 @@ datasets:
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# custom user prompt
- path: repo
@@ -465,6 +487,9 @@ datasets:
dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
@@ -492,6 +517,8 @@ pad_to_sequence_len:
max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
@@ -544,7 +571,7 @@ torch_compile_backend: # Optional[str]
# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size: 2
eval_batch_size:
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
@@ -628,6 +655,8 @@ flash_optimum:
xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
@@ -715,11 +744,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
- llama Deepspeed
```yaml
deepspeed: deepspeed/zero3.json
```
##### Weights & Biases Logging
- wandb options
@@ -732,6 +756,24 @@ wandb_run_id:
wandb_log_model:
```
### Training with Deepspeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
or
```yaml
deepspeed: deepspeed/zero1.json
```
### Inference
Pass the appropriate flag to the train command:

View File

@@ -1,39 +1,41 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -1,43 +1,45 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -45,9 +45,11 @@
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -12,17 +12,18 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
fi
# fix so that git fetch/pull from remote works
RUN cd axolotl && \
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli

View File

@@ -13,16 +13,14 @@ ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION
RUN apt-get update
RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
RUN wget \
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh
RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
@@ -46,6 +44,8 @@ FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
ARG MAX_JOBS="-1"
ENV MAX_JOBS=$MAX_JOBS
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
@@ -60,8 +60,7 @@ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes

45
docs/multi-node.md Normal file
View File

@@ -0,0 +1,45 @@
# Multi Node
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
~/.cache/huggingface/accelerate/default_config.yaml
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
main_process_ip: 10.0.0.4 # Set to main machine's IP
main_process_port: 5000
main_training_function: main
mixed_precision: bf16
num_machines: 2 # Change to the number of machines
num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8)
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP with for example:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.

View File

@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false
@@ -11,7 +12,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
lora_model_dir:

View File

@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
@@ -17,7 +18,7 @@ datasets:
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora

View File

@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false
@@ -11,7 +12,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter:
lora_model_dir:

View File

@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -15,7 +15,7 @@ hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
lora_model_dir:

View File

@@ -1,5 +1,5 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:

View File

@@ -1,5 +1,5 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out
@@ -58,7 +58,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_size:
save_steps:
debug:
deepspeed:

View File

@@ -1,5 +1,5 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./relora-out

View File

@@ -12,7 +12,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_size:
save_steps:
debug:
deepspeed:

View File

@@ -0,0 +1,12 @@
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
Fine Tune:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
```
If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
```

View File

@@ -0,0 +1,62 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out
sequence_len: 8192
sample_packing:
pad_to_sequence_len:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,79 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: True
pad_to_sequence_len: True
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter: lora
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out

View File

@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out

View File

@@ -10,7 +10,7 @@ device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
adapter:
lora_model_dir:

View File

@@ -4,7 +4,7 @@ load_in_8bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:

View File

@@ -7,7 +7,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -5,7 +5,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:

View File

@@ -16,7 +16,7 @@ datasets:
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora

View File

@@ -4,15 +4,15 @@ torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed
addict
evaluate
fire
PyYAML>=6.0
datasets
flash-attn>=2.2.1
flash-attn>=2.3.0
sentencepiece
wandb
einops
@@ -30,3 +30,4 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.29

View File

@@ -7,6 +7,7 @@ import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True

View File

@@ -13,7 +13,12 @@ def parse_requirements():
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
elif (
"flash-attn" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
):
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
@@ -35,7 +40,7 @@ setup(
"flash-attn": [
"flash-attn>=2.2.1",
],
"extras": [
"deepspeed": [
"deepspeed",
],
},

View File

@@ -14,6 +14,8 @@ import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
@@ -49,7 +51,7 @@ def print_axolotl_text_art(suffix=None):
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
@@ -247,3 +249,16 @@ def check_accelerate_default_config():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

View File

@@ -8,6 +8,7 @@ import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
@@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True

View File

@@ -5,7 +5,7 @@ import os
from typing import List
import torch
from datasets import Dataset, IterableDataset
from datasets import Dataset, IterableDataset, Sequence, Value
from .prompt_tokenizers import PromptTokenizingStrategy
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
@@ -38,10 +38,19 @@ class TokenizedPromptDataset(Dataset):
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return (
dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
**map_kwargs,
)
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
)
@@ -50,7 +59,7 @@ class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for proccessing the data.
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""

View File

@@ -711,12 +711,8 @@ class ParallelBlock(nn.Module):
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
mlp_cls = mlp.pop("mlp_cls")
if mlp_cls == "fused_mlp":
self.mlp = FusedMLP(config=config, **mlp)
else:
self.mlp = MLP(config=config, **mlp)
self.mixer = MHA(config, layer_idx=block_idx)
self.mlp = MLP(config)
def forward(
self,

View File

@@ -6,8 +6,8 @@ import importlib
import logging
from typing import Optional, Tuple
import accelerate
import torch
from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
@@ -18,8 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
with accelerate.init_empty_weights():
AutoModelForCausalLM(model_config)
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)

View File

@@ -1,101 +0,0 @@
"""
Flash Attention monkey patch for Falcon
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
"""
from typing import Optional, Tuple
import torch
import transformers
from flash_attn import flash_attn_func
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor, # pylint: disable=unused-argument
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
use_cache: bool = False,
output_attentions: bool = False, # pylint: disable=unused-argument
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = (
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(
query_layer,
key_layer,
value_layer,
) = self._split_heads( # pylint: disable=protected-access
fused_qkv
)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, query_length, self.head_dim
)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads, query_length, self.head_dim
)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
# unused
# _, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
# unused
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
key_layer_ = (
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
value_layer_ = (
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
# below output will have shape (batch_size, seqlen, nheads, headdim)
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
attn_output = attn_output.reshape(
batch_size, query_length, self.num_heads * self.head_dim
)
output_tensor = self.dense(attn_output)
return output_tensor, present
def replace_falcon_attn_with_flash_attn():
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward

View File

@@ -0,0 +1,174 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
yield "", system_prompt
else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message:
yield role + " ", message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt

View File

@@ -38,7 +38,11 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
llama_model_forward
)
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
try:
from flash_attn.ops.rms_norm import RMSNorm
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -99,6 +107,7 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -476,6 +485,13 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
@@ -510,7 +526,9 @@ def llama_model_forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return module(
*inputs,
)
return custom_forward
@@ -519,9 +537,10 @@ def llama_model_forward(
hidden_states,
attention_mask,
position_ids,
None,
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
@@ -533,6 +552,7 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
@@ -579,6 +599,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
@@ -611,6 +632,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

View File

@@ -0,0 +1,541 @@
"""Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code
import logging
from typing import List, Optional, Tuple, Union
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,
):
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
)
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
mistral_model_forward
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralDecoderLayer(OriginalMistralDecoderLayer):
"""
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -0,0 +1,415 @@
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
""" PyTorch StableLM Epoch model. """
import importlib
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from accelerate import init_empty_weights
from einops import rearrange
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_varlen_qkvpacked_func,
)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
logger = logging.get_logger(__name__)
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_stablelm_epoch to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
)
modeling_stablelm = importlib.import_module(module_name)
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
flashattn_attn
)
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
stablelm_model_forward
)
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
decoder_layer_forward
)
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
# pylint: disable=invalid-name
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
# pylint: disable=invalid-name
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flashattn_attn(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
query_rot = query_states[..., : self.rotary_ndims]
query_pass = query_states[..., self.rotary_ndims :]
key_rot = key_states[..., : self.rotary_ndims]
key_pass = key_states[..., self.rotary_ndims :]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_rot, key_rot, cos, sin, position_ids
)
# [batch_size, num_heads, seq_len, head_dim]
query_states = torch.cat((query_states, query_pass), dim=-1)
key_states = torch.cat((key_states, key_pass), dim=-1)
if past_key_value is not None:
# Reuse k, v, self_attention
key_states = torch.cat((past_key_value[0], key_states), dim=2)
value_states = torch.cat((past_key_value[1], value_states), dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
softmax_scale = None
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
)
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# Upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
# Merge heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# Final linear projection
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def decoder_layer_forward(
self,
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def stablelm_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

View File

@@ -1,10 +1,81 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple
from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Completion prompts.
"""
_field: str = "text"
def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length is not None:
self.max_length = max_length
@property
def supports_batched(self):
return True
@property
def field(self) -> str:
return self._field
@field.setter
def field(self, new_field: str):
self._field = new_field
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)
def tokenize_prompt(self, prompt):
res = defaultdict(lambda: [])
feature_names = list(prompt.keys())
for row in zip(*prompt.values()):
prompt_row = dict(zip(feature_names, row))
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt_row)
full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
for key, val in tokenized_full_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
return dict(res)
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class CompletionPrompter:
"""
Prompter for completion
"""
def build_prompt(
self,
instruction: str,
input=None, # pylint: disable=redefined-builtin, unused-argument
output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]

View File

@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
)
def load_v2(tokenizer, cfg):
return ContextQaV2PromptTokenizingStrategy(
ContextV2Prompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaContextPrompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
@@ -50,6 +59,38 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
)
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenization Strategy to combine in-context article with a question and answer
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
"Context: "
+ prompt["context"]
+ "\nQuestion: "
+ prompt["question"]
+ "\n",
"",
"Answer: " + prompt["answer"],
)
class ContextV2Prompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
"""
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
# pylint: disable=duplicate-code
self.turn_format = "{instruction}\n{input}"
self.turn_no_input_format = "{instruction}"
self.system_format = "{system}"
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):

View File

@@ -0,0 +1,119 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strat = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and ds_cfg["skip"]:
strat.skip_invalid = True
return strat
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return NousShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy used by nous/ldj for input/output keyed data
"""
def get_conversation_thread(self):
return "conversation"
def map_conversation_thread(self, conversation):
turns = []
for turn in conversation:
turns.append({"from": "human", "value": turn["input"]})
turns.append({"from": "gpt", "value": turn["output"]})
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def map_conversation_thread(self, conversation):
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [
{"from": turn["role"], "value": turn["value"]} for turn in conversation
]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def map_conversation_thread(self, conversation):
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[turn["role"]], "value": turn["text"]}
for turn in conversation
]
return turns

View File

@@ -1,11 +1,11 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -1,67 +0,0 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns

View File

@@ -4,10 +4,15 @@ import abc
import copy
import functools
import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
@@ -18,6 +23,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception):
"""
@@ -41,11 +48,16 @@ class PromptTokenizingStrategy(abc.ABC):
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
self.sequence_len = sequence_len
self.max_length = sequence_len
@abc.abstractmethod
def tokenize_prompt(self, prompt):
pass
@property
def supports_batched(self):
return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
@@ -70,14 +82,14 @@ class PromptTokenizingStrategy(abc.ABC):
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
result: BatchEncoding
if not prompt.strip():
if not prompt:
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
@@ -86,7 +98,7 @@ class PromptTokenizingStrategy(abc.ABC):
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and len(result["input_ids"]) < self.max_length
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
@@ -247,46 +259,6 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
)
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Completion prompts.
"""
_field: str = "text"
@property
def field(self) -> str:
return self._field
@field.setter
def field(self, new_field: str):
self._field = new_field
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)
def tokenize_prompt(self, prompt):
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Reflection prompts.
@@ -380,86 +352,136 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
_skip_invalid = False
@property
def supports_batched(self):
return True
@property
def skip_invalid(self):
return self._skip_invalid
@skip_invalid.setter
def skip_invalid(self, value):
self._skip_invalid = value
def get_conversation_thread(self):
return "conversations"
def map_conversation_thread(self, conversation):
return conversation
def tokenize_prompt(self, prompt):
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == "USER:":
part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
res = self._tokenize(
part.strip(),
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token
res = self._tokenize(
part.strip(),
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
tokenized_res = defaultdict(lambda: [])
conv_field = self.get_conversation_thread()
for prmpt in prompt[conv_field]:
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
err_msg = f"user turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
err_msg = f"assistant turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
err_msg = f"unhandled role: {part[0]}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
for key, val in sorted(result.items(), key=lambda x: x[0]):
tokenized_res[key].append(val)
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
except ValueError as err:
LOG.warning("skipping prompt: %s", str(err))
return tokenized_res
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -1,9 +1,10 @@
"""Module containing prompters"""
import dataclasses
import logging
from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union
from enum import Enum
from typing import Generator, Optional, Union
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -135,20 +136,6 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
"""
Prompter for completion
"""
def build_prompt(
self,
instruction: str,
input=None, # pylint: disable=redefined-builtin, unused-argument
output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
class GPTeacherPrompter(AlpacaPrompter):
"""
Prompter for GPTeacher
@@ -228,53 +215,6 @@ class ReflectAlpacaPrompter:
yield res
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: Optional[str] = None
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role, message):
self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
@@ -285,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT
"""
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
system: str = (
system_prompt
if system_prompt
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
role_key_human = "human"
role_key_model = "gpt"
def __init__(
self,
prompt_style=None, # pylint: disable=unused-argument
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")
if role_key_human:
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2:
@@ -320,17 +259,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.system = source[0]["value"]
conv.set_system_message(source[0]["value"])
source.pop(0)
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
try:
# Apply prompt templates
if (
source[0]["from"] not in roles
or roles[source[0]["from"]] != conv.roles[0]
):
if source[0]["from"] not in roles:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as err:
@@ -340,8 +276,29 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if role != conv.roles[j % 2]:
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
for part in conv.get_prompt():
for part in conv.get_turns():
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
)

View File

@@ -9,8 +9,7 @@ from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
@@ -59,7 +58,9 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
if (
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
) or cfg.resume_from_checkpoint is True:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
@@ -72,7 +73,9 @@ def train(
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
resume_from_checkpoint = (
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
)
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
@@ -103,6 +106,9 @@ def train(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
@@ -138,4 +144,7 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
return model, tokenizer

View File

@@ -1,14 +1,44 @@
"""Benchmarking and measurement utilities"""
import functools
import pynvml
import torch
from pynvml.nvml import NVMLError
def check_cuda_device(default_value):
"""
wraps a function and returns the default value instead of running the
wrapped function if cuda isn't available or the device is auto
:param default_value:
:return:
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)
if (
not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
return default_value
return func(*args, **kwargs)
return wrapper
return deco
@check_cuda_device(0.0)
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -16,6 +46,7 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved)
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
@@ -31,9 +62,6 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available() or device == "auto":
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:

View File

@@ -43,26 +43,26 @@ LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter"""
class EvalFirstStepCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""
Callback to trigger evals on the first step
"""
def on_save(
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and args.eval_steps < 1.0
and state.global_step == 1
):
control.should_evaluate = True
return control

View File

@@ -49,6 +49,8 @@ def normalize_config(cfg):
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
@@ -75,6 +77,8 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
@@ -82,10 +86,39 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
# figure out if the model is falcon
cfg.is_falcon_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
)
cfg.is_mistral_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"mistral",
]
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower()
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -94,7 +127,7 @@ def validate_config(cfg):
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if cfg.bf16 or cfg.bfloat16:
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
@@ -126,6 +159,11 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.eval_batch_size != cfg.micro_batch_size:
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated")
@@ -262,6 +300,45 @@ def validate_config(cfg):
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if not ds_cfg.type:
continue
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -25,7 +25,6 @@ from axolotl.prompt_tokenizers import (
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
@@ -35,7 +34,6 @@ from axolotl.prompters import (
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
@@ -76,7 +74,7 @@ def prepare_dataset(cfg, tokenizer):
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
cfg, train_dataset, eval_dataset, tokenizer
)
if cfg.max_steps:
total_num_steps = min(
@@ -116,7 +114,7 @@ def load_tokenized_prepared_datasets(
if cfg.push_dataset_to_hub:
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -124,7 +122,7 @@ def load_tokenized_prepared_datasets(
if dataset:
...
elif any(prepared_ds_path.glob("*")):
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
@@ -157,24 +155,26 @@ def load_tokenized_prepared_datasets(
d.path,
name=d.name,
streaming=True,
use_auth_token=use_auth_token,
token=use_auth_token,
)
ds_from_hub = True
except FileNotFoundError:
except (FileNotFoundError, ValueError):
pass
# prefer local dataset, even if hub exists
local_path = Path(d.path)
if local_path.exists():
if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
if not d.type:
ds = load_from_disk(d.path)
else:
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = "json"
if d.ds_type:
@@ -204,14 +204,29 @@ def load_tokenized_prepared_datasets(
name=d.name,
streaming=False,
data_files=d.data_files,
use_auth_token=use_auth_token,
token=use_auth_token,
)
else:
fp = hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
if isinstance(d.data_files, str):
fp = hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
elif isinstance(d.data_files, list):
fp = []
for file in d.data_files:
fp.append(
hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=file,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
ds = load_dataset(
"json", name=d.name, data_files=fp, streaming=False, split=None
)
@@ -234,6 +249,16 @@ def load_tokenized_prepared_datasets(
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
elif (
isinstance(ds, DatasetDict)
and d.train_on_split
and d.train_on_split in ds
):
ds = ds[d.train_on_split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
)
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
@@ -320,15 +345,6 @@ def load_tokenized_prepared_datasets(
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
@@ -343,7 +359,7 @@ def load_tokenized_prepared_datasets(
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0:
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
@@ -403,7 +419,7 @@ def load_prepare_datasets(
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -411,7 +427,7 @@ def load_prepare_datasets(
if dataset:
...
elif any(prepared_ds_path.glob("*")):
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)

View File

@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
value_scalar = fn()
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
"""
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
else:
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device()
) # Placeholder tensor
# Broadcast the tensor to all processes.
barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
# Placeholder tensor for gathering results
if is_main_process():

View File

@@ -1,5 +1,4 @@
"""Module for models and model loading"""
import importlib
import logging
import math
import os
@@ -10,7 +9,9 @@ import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -80,11 +81,22 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: val})
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in cfg.tokens
]
)
return tokenizer
@@ -113,26 +125,29 @@ def load_model(
replace_btlm_attn_with_flash_attn(cfg.base_model)
if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
replace_falcon_attn_with_flash_attn,
if (
hasattr(model_config, "model_type")
and model_config.model_type == "stablelm_epoch"
):
if cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
)
replace_falcon_attn_with_flash_attn()
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
LOG.info("patching with flash attention")
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -157,6 +172,14 @@ def load_model(
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
)
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -175,21 +198,11 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
if (
"MixFormerSequentialConfig" in model_config.__class__.__name__
and cfg.model_type == "AutoModelForCausalLM"
):
module_name = model_config.__class__.__module__.replace(
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
)
modeling_phi = importlib.import_module(module_name)
# pylint:disable=protected-access
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
"ParallelBlock"
]
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
@@ -212,6 +225,15 @@ def load_model(
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
# sample packing uses custom FA2 patch
if cfg.flash_attention and not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
):
model_kwargs["use_flash_attention_2"] = True
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
@@ -226,10 +248,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -263,28 +283,22 @@ def load_model(
model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -309,16 +323,22 @@ def load_model(
):
config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -326,10 +346,8 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -364,7 +382,7 @@ def load_model(
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if "lm_head" in name or "embed_tokens" in name:
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if hasattr(module, "weight"):
module.to(torch.float32)
@@ -466,10 +484,10 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

View File

@@ -31,7 +31,9 @@ def check_example_labels(example, tokenizer, text_only=False):
)
colored_tokens.append(colored_token)
LOG.info(" ".join(colored_tokens))
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens)

View File

@@ -28,9 +28,9 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
@@ -397,23 +397,36 @@ def disable_datasets_caching():
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(
drop_long, num_proc=cfg.dataset_processes
)
if cfg.group_by_length:
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
train_dataset = train_dataset.map(
add_length, num_proc=cfg.dataset_processes
)
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
train_dataset = train_dataset.map(
add_position_ids, num_proc=cfg.dataset_processes
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids, num_proc=os.cpu_count()
add_position_ids, num_proc=cfg.dataset_processes
)
# Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
return train_dataset, eval_dataset
@@ -597,26 +610,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.eval_steps and cfg.evaluation_strategy:
# assume if the user set both, they know what they're doing
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
if cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.evaluation_strategy:
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
# if explicitly set for epoch, just set, and eval steps don't matter
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.eval_steps:
# steps isn't used w/ epochs
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
# save_steps implies save_strategy of steps
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
@@ -662,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
else cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
@@ -675,6 +679,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
@@ -704,16 +709,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
callbacks.append(EvalFirstStepCallback)
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)

View File

@@ -6,6 +6,7 @@ import logging
import os
import tempfile
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
def test_lora(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -62,9 +64,11 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -105,3 +109,53 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_gptq(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()

116
tests/e2e/test_mistral.py Normal file
View File

@@ -0,0 +1,116 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_lora(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ft(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

View File

@@ -0,0 +1,118 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ft_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

View File

@@ -31,9 +31,9 @@ class TestPhi(unittest.TestCase):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sequence_len": 512,
"sample_packing": False,
"load_in_8bit": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -55,8 +55,9 @@ class TestPhi(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)
@@ -74,9 +75,9 @@ class TestPhi(unittest.TestCase):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sequence_len": 512,
"sample_packing": True,
"load_in_8bit": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -98,8 +99,9 @@ class TestPhi(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)

View File

@@ -0,0 +1,52 @@
{
"conversations": [
{
"from": "human",
"value": "how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?"
},
{
"from": "gpt",
"value": ""
},
{
"from": "human",
"value": "could you give me the gist of how it could though?"
},
{
"from": "gpt",
"value": "Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer."
},
{
"from": "human",
"value": "give me a specific example"
},
{
"from": "gpt",
"value": "As a specific example, suppose we have a Minesweeper game board with the following configuration:\n```\n 1 1 1 1 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 1 1 1 1\n```\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\n\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \"shape\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\n\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\n```\n 1 1 1 1 1\n 1 M ? ? 1\n 1 ? ? ? 1\n 1 ? ? M 1\n 1 1 1 1 1\n```\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper."
},
{
"from": "human",
"value": "assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?"
},
{
"from": "gpt",
"value": "Sure. Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper."
},
{
"from": "human",
"value": "how would i do it with persistent homology or any of the other topological invariants?"
},
{
"from": "gpt",
"value": "To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\n\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper."
},
{
"from": "human",
"value": "how could i implement this in python?"
},
{
"from": "gpt",
"value": "To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\n\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \"holes\" in the space.\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\n\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project."
}
]
}

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,9 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from transformers import AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
@@ -19,7 +21,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
LOG = logging.getLogger("axolotl")
@@ -29,6 +31,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -52,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
@@ -64,6 +72,24 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
def test_sharegpt_warnings_integration(self):
with open(
Path(__file__).parent / "fixtures/conversation.missingturns.json",
encoding="utf-8",
) as fin:
data = fin.read()
conversation = json.loads(data)
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
strat.tokenize_prompt(conversation)
assert "assistant turn has empty text" in self._caplog.records[1].message
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts

View File

@@ -351,3 +351,217 @@ class ValidationTest(unittest.TestCase):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_merge_lora_no_bf16_fail(self):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.
"""
cfg = DictDefault(
{
"bf16": True,
}
)
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
validate_config(cfg)
cfg = DictDefault(
{
"bf16": True,
"merge_lora": True,
}
)
validate_config(cfg)
def test_sharegpt_deprecation(self):
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt:chat` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt"
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt_simple` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt:load_role"
def test_no_conflict_save_strategy(self):
cfg = DictDefault(
{
"save_strategy": "epoch",
"save_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*save_strategy and save_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "no",
"save_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*save_strategy and save_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "steps",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "steps",
"save_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "no",
}
)
validate_config(cfg)
def test_no_conflict_eval_strategy(self):
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"eval_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "no",
"eval_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "steps",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "steps",
"eval_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "no",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"val_set_size": 0,
}
)
with pytest.raises(
ValueError,
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
"val_set_size": 0,
}
)
with pytest.raises(
ValueError,
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"val_set_size": 0,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
"val_set_size": 0.01,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"val_set_size": 0.01,
}
)
validate_config(cfg)