Compare commits

..

36 Commits

Author SHA1 Message Date
Wing Lian
d0b534292f Add e2e test for ia3 ft 2023-10-19 09:27:55 -04:00
Wing Lian
0bd89b38c6 migrate lora_ to peft_ 2023-10-18 22:22:54 -04:00
Wing Lian
481ef187a5 update README for IA3 peft 2023-10-18 22:18:39 -04:00
Wing Lian
d645b19fcf include task type for ia3 config 2023-10-18 22:18:39 -04:00
Wing Lian
203369411e consolidate as peft_model_dir 2023-10-18 22:18:37 -04:00
Wing Lian
ba85308720 Update src/axolotl/utils/models.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-10-18 22:17:38 -04:00
Wing Lian
998763bade ia3 keeps casting to float32, handle it here for now 2023-10-18 22:17:38 -04:00
Wing Lian
c8e42a0f4f fix load_in_8bit check 2023-10-18 22:17:38 -04:00
Wing Lian
1da328eb9a prepare ia3 for 8bit 2023-10-18 22:17:38 -04:00
Wing Lian
2d7cccfc8e add ia3 peft support 2023-10-18 22:17:38 -04:00
NanoCode012
440c3ab527 Fix(model): Linear detected and added to target module with rope linear (#738)
* Fix(model): Linear detected and added to target module with rope linear

* fix: exclude layer instead
2023-10-18 22:13:20 -04:00
Napuh
992d57f20a catch ConnectionError when checking dataset from HuggingFace (#743) 2023-10-18 22:11:54 -04:00
mhenrichsen
91a016f410 badge (#739)
* badge

* fixed text
2023-10-18 10:21:34 -04:00
Casper
a045db0214 Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-16 15:13:46 -04:00
Casper
e1b214c62b Clarify custom format example (#729)
* Clarify custom prompt format

* Simplify format
2023-10-14 09:28:12 -04:00
Wing Lian
3553172e3c fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728) 2023-10-14 09:27:07 -04:00
Wing Lian
7f2027d93f tweak for xformers install w pytorch 2.1.0 (#727) 2023-10-13 15:21:17 -04:00
Wing Lian
8d288a2ad4 workaround for installing xformers w torch 2.1.0 (#725) 2023-10-13 11:19:30 -04:00
Wing Lian
f30afe4544 misc sharegpt fixes (#723)
* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset

* invalid role is actually not possible

* update tokenized fixture for corrected labels
2023-10-13 11:04:39 -04:00
Wing Lian
bfbdba8614 pin xformers >= 0.0.22 (#724) 2023-10-13 10:27:56 -04:00
Maxime
3bd9528390 add noisy embedding (#721)
* add noisy embedding

* fix format

* Update README.md

* Update README.md

* linter issues

* caseus fixes

---------

Co-authored-by: Maxime <maxime@nope.no>
2023-10-13 10:00:42 -04:00
Wing Lian
2aa1f71464 fix pytorch 2.1.0 build, add multipack docs (#722) 2023-10-13 08:57:28 -04:00
Wing Lian
1c412c7e9d improve handling of the prepared ds path and other cfg defaults (#701) 2023-10-13 07:46:07 -04:00
Jan Philipp Harries
490923fb78 Save Axolotl config as WandB artifact (#716) 2023-10-11 07:28:12 -04:00
NanoCode012
5855dded3d fix(doc): update default doc according to arg (#714) 2023-10-10 21:51:56 +09:00
atgctg
ace70b33c6 Fix: lowercase True values in config (#713)
* Fix: lowercase `True` values in config

* Fix: lowercase `True` values in config
2023-10-10 21:32:20 +09:00
NanoCode012
11c48c5e03 fix(doc): Add note on inference w sample packing (#712) 2023-10-10 21:08:17 +09:00
lukemarsden
295b2662e1 Get qlora mistral-7b fine tuning working on a single 4090 (#708) 2023-10-10 15:14:23 +09:00
seungduk.kim.2304
77c84e02fd Update README with some explanations (#700)
* Update README with some explanations

* revert commit-hook change

* add more explanation about batch size and gradient accum

* not use latex foromat

* decorate

* git hook again

* Attach a link that explains about LoRA hyperparameters

* update table of content

* Explanation about lora_modules_to_save
2023-10-08 13:37:54 -04:00
mhenrichsen
f91db198f3 fix unneeded space (#699) 2023-10-07 14:19:25 -04:00
Wing Lian
7f2618b5f4 add docker images for pytorch 2.10 (#697) 2023-10-07 12:23:31 -04:00
Wing Lian
aca0398315 apex not needed as amp is part of pytorch (#696) 2023-10-07 12:20:45 -04:00
mhenrichsen
29b8f46aed Merge pull request #693 from OpenAccess-AI-Collective/update-mistral-example
update mistral lr, sample pack
2023-10-07 11:04:58 +02:00
mhenrichsen
83a950bb87 lint 2023-10-07 11:04:35 +02:00
Wing Lian
de87ea68f6 fix multiline for docker (#694) 2023-10-06 22:38:15 -04:00
mhenrichsen
4c8ddf2c6f new lr, sample pack 2023-10-06 22:58:13 +02:00
65 changed files with 1104 additions and 393 deletions

View File

@@ -25,6 +25,11 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -23,6 +23,11 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
@@ -46,6 +51,7 @@ jobs:
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
@@ -68,6 +74,11 @@ jobs:
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -12,4 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-nested-blocks, too-many-boolean-expressions,

316
README.md
View File

@@ -23,9 +23,10 @@ Features:
- [Supported Features](#axolotl-supports) - [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-) - [Quickstart](#quickstart-)
- [Installation](#installation) - [Installation](#installation)
- [Docker Installation](#environment) - [Docker](#docker)
- [Conda/Pip venv Installation](#condapip-venv) - [Conda/Pip venv](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs) - [LambdaLabs](#lambdalabs)
- [Windows](#windows)
- [Dataset](#dataset) - [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts) - [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
@@ -50,7 +51,7 @@ Features:
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b> <b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p> </p>
<p> <p>
Go ahead and axolotl questions!! Go ahead and Axolotl questions!!
</p> </p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit"> <img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main"> <img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
@@ -95,14 +96,14 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --peft_model_dir="./lora-out"
``` ```
## Installation ## Installation
### Environment ### Environment
- Docker #### Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
@@ -114,12 +115,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
docker compose up -d docker compose up -d
``` ```
- Conda/Pip venv #### Conda/Pip venv
1. Install python >=**3.9** 1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install axolotl along with python dependencies 3. Install Axolotl along with python dependencies
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
@@ -130,7 +131,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
Get the token at huggingface.co/settings/tokens Get the token at huggingface.co/settings/tokens
- LambdaLabs #### LambdaLabs
<details> <details>
<summary>Click to Expand</summary> <summary>Click to Expand</summary>
@@ -174,7 +175,8 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
</details> </details>
- Windows: Please use WSL or Docker! #### Windows
Please use WSL or Docker!
### Dataset ### Dataset
@@ -295,25 +297,24 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
Using yaml. Example: For a dataset that is preprocessed for instruction purposes:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
```yaml ```yaml
datasets: datasets:
- path: repo - path: repo
type: type:
system_prompt: "" system_prompt: ""
no_input_format: |- field_system: system
User: {instruction}<|end_of_turn|> format: "[INST] {instruction} [/INST]"
Assistant: no_input_format: "[INST] {instruction} [/INST]"
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
``` ```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
#### How to use your custom pretokenized dataset #### How to use your custom pretokenized dataset
- Do not pass a `type:` - Do not pass a `type:`
@@ -383,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora - lora
```yaml ```yaml
adapter: lora # qlora or leave blank for full finetune adapter: lora # qlora or leave blank for full finetune
lora_r: 8 peft_r: 8
lora_alpha: 16 peft_alpha: 16
lora_dropout: 0.05 peft_dropout: 0.05
lora_target_modules: peft_target_modules:
- q_proj - q_proj
- v_proj - v_proj
``` ```
@@ -396,15 +397,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
<summary>All yaml options</summary> <summary>All yaml options</summary>
```yaml ```yaml
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# this can also be a relative path to a model on disk # This can also be a relative path to a model on disk
base_model: ./llama-7b-hf base_model: ./llama-7b-hf
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns: base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files, # If the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model # You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# you can specify to choose a specific model revision from huggingface hub # You can specify to choose a specific model revision from huggingface hub
model_revision: model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
@@ -419,23 +420,24 @@ trust_remote_code:
tokenizer_use_fast: tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True # Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy: tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32 # Resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models # This is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# used to identify which the model is based on # Used to identify which the model is based on
is_falcon_derived_model: is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
# whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2 gptq_model_v1: false # v1 or v2
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true load_in_8bit: true
# use bitsandbytes 4 bit # Use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:
# Use CUDA bf16 # Use CUDA bf16
@@ -449,9 +451,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere bfloat16: true # require >=ampere
float16: true float16: true
# a list of one or more datasets to finetune the model with # A list of one or more datasets to finetune the model with
datasets: datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -461,17 +463,18 @@ datasets:
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# custom user prompt # Custom user prompt
- path: repo - path: repo
type: type:
# the below are defaults. only set what's needed. # The below are defaults. only set what's needed.
system_prompt: "" system_prompt: ""
system_format: "{system}"
field_system: system field_system: system
field_instruction: instruction field_instruction: instruction
field_output: input field_input: input
field_output: output
# customizable to be single line or multi-line # Customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input} # 'format' can include {input}
format: |- format: |-
User: {instruction} {input} User: {instruction} {input}
@@ -479,13 +482,13 @@ datasets:
# 'no_input_format' cannot include {input} # 'no_input_format' cannot include {input}
no_input_format: "{instruction} " no_input_format: "{instruction} "
# for completions datsets, uses the provided field if not `text` # For `completion` datsets only, uses the provided field instead of `text` column
field: field:
# axolotl attempts to save the dataset as an arrow after packing the data together so # Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub # Push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set. # if not set.
@@ -495,8 +498,8 @@ hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub # how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy: hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # Required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04 val_set_size: 0.04
@@ -505,34 +508,38 @@ dataset_shard_num:
# Index of shard to use for whole dataset # Index of shard to use for whole dataset
dataset_shard_idx: dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048 # The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# pad inputs so each step uses constant sized buffers # Pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:
# max sequence length to concatenate training samples together up to # Max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# set to 'false' if getting errors during eval with sample_packing on. # Set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing: eval_sample_packing:
# you can set these packing optimizations AFTER starting a training at least once. # You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values. # The trainer will provide recommended values for these values.
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# if you already have a lora model trained that you want to load, put that here # If you already have a lora model trained that you want to load, put that here.
# lora hyperparameters # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
lora_model_dir: peft_model_dir:
lora_r: 8
lora_alpha: 16 # LoRA hyperparameters
lora_dropout: 0.05 # For more details about the following options, see:
lora_target_modules: # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
peft_r: 8
peft_alpha: 16
peft_dropout: 0.05
peft_target_modules:
- q_proj - q_proj
- v_proj - v_proj
# - k_proj # - k_proj
@@ -540,36 +547,49 @@ lora_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
lora_target_linear: # if true, will target all linear layers peft_target_linear: # if true, will target all linear layers
lora_modules_to_save:
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
peft_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
lora_fan_in_fan_out: false peft_fan_in_fan_out: false
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
# ReLoRA configuration # ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # number of steps per ReLoRA restart relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # number of per-restart warmup steps relora_warmup_steps: # Number of per-restart warmup steps
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name wandb_project: # Your wandb project name
wandb_entity: # a wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # set the name of your wandb run wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# where to save the finished model to # Where to save the full-finetuned model to
output_dir: ./completed-model output_dir: ./completed-model
# whether to use torch.compile and which backend to use # Whether to use torch.compile and which backend to use
torch_compile: # bool torch_compile: # bool
torch_compile_backend: # Optional[str] torch_compile_backend: # Optional[str]
# training hyperparameters # Training hyperparameters
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 3 num_epochs: 3
@@ -577,44 +597,47 @@ warmup_steps: 100
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
save_strategy: # set to `no` to skip checkpoint saves save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch save_steps: # Leave empty to save at each epoch
eval_steps: # leave empty to eval at each epoch eval_steps: # Leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps: max_steps:
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# save model as safetensors (require safetensors package) # Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# whether to mask out or include the human's prompt from the training labels # Whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# group similarly sized data to minimize padding # Group similarly sized data to minimize padding.
# may be slower to start, as it must download and sort the entire dataset # May be slower to start, as it must download and sort the entire dataset.
# note that training loss may have an oscillating pattern with this enabled # Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row # Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3 early_stopping_patience: 3
# specify a scheduler and kwargs to use with the optimizer # Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
# for one_cycle optim # For one_cycle optim
lr_div_factor: # learning rate div factor lr_div_factor: # Learning rate div factor
# for log_sweep optim # For log_sweep optim
log_sweep_min_lr: log_sweep_min_lr:
log_sweep_max_lr: log_sweep_max_lr:
# specify optimizer # Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see: # Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
@@ -640,7 +663,7 @@ log_sweep_max_lr:
# - paged_lion_32bit # - paged_lion_32bit
# - paged_lion_8bit # - paged_lion_8bit
optimizer: optimizer:
# specify weight decay # Specify weight decay
weight_decay: weight_decay:
# adamw hyperparams # adamw hyperparams
adam_beta1: adam_beta1:
@@ -649,49 +672,56 @@ adam_epsilon:
# Gradient clipping max norm # Gradient clipping max norm
max_grad_norm: max_grad_norm:
# whether to bettertransformers # Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha:
# Whether to bettertransformers
flash_optimum: flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers: # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention: flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# whether to use scaled-dot-product attention # Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama) # Landmark attention (only llama)
landmark_attention: landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # LLaMA only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling: rope_scaling:
type: # linear | dynamic type: # linear | dynamic
factor: # float factor: # float
# resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# if resume_from_checkpoint isn't set and you simply want it to start where it left off # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# be careful with this being turned on between different models # Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false auto_resume_from_checkpoints: false
# don't mess with this, it's here for accelerate and torchrun # Don't mess with this, it's here for accelerate and torchrun
local_rank: local_rank:
# add or change special tokens # Add or change special tokens.
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens: special_tokens:
# bos_token: "<s>" # bos_token: "<s>"
# eos_token: "</s>" # eos_token: "</s>"
# unk_token: "<unk>" # unk_token: "<unk>"
# add extra tokens
# Add extra tokens.
tokens: tokens:
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:
# Deepspeed config path # Deepspeed config path. e.g., deepspeed/zero3.json
deepspeed: deepspeed:
# Advanced DDP Arguments # Advanced DDP Arguments
@@ -717,6 +747,66 @@ strict:
</details> </details>
<details>
<summary> Understanding of batch size and gradient accumulation steps </summary>
<br/>
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
**Example 1:**
Micro batch size: 3
Gradient accumulation steps: 2
Number of GPUs: 3
Total batch size = 3 * 2 * 3 = 18
```
| GPU 1 | GPU 2 | GPU 3 |
|----------------|----------------|----------------|
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|----------------|----------------|----------------|
| → (accumulate) | → (accumulate) | → (accumulate) |
|----------------|----------------|----------------|
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|----------------|----------------|----------------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
Weight update for w1:
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
```
**Example 2:**
Micro batch size: 2
Gradient accumulation steps: 1
Number of GPUs: 3
Total batch size = 2 * 1 * 3 = 6
```
| GPU 1 | GPU 2 | GPU 3 |
|-----------|-----------|-----------|
| S1, S2 | S3, S4 | S5, S6 |
| e1, e2 | e3, e4 | e5, e6 |
|-----------|-----------|-----------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
Weight update for w1:
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
```
</details>
### Train ### Train
Run Run
@@ -780,7 +870,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir" python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
@@ -792,12 +882,16 @@ Pass the appropriate flag to the train command:
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base ### Merge LORA to base
Add below flag to train command above Add below flag to train command above
```bash ```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with

View File

@@ -5,6 +5,9 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y vim curl apt-get install -y vim curl
@@ -16,6 +19,7 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \

View File

@@ -14,7 +14,7 @@ ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \
@@ -57,11 +57,6 @@ FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes

51
docs/multipack.md Normal file
View File

@@ -0,0 +1,51 @@
# Multipack
4k context, bsz =4,
each character represents 256 tokens
X represents a padding token
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B ]
C C C C C C C ]
D D D D ]]
[[ E E E E E E E E ]
[ F F F F ]
[ G G G ]
[ H H H H ]]
[[ I I I ]
[ J J J ]
[ K K K K K]
[ L L L ]]
```
after padding to longest input in each step
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B X X X X X X ]
C C C C C C C X X X X ]
D D D D X X X X X X X ]]
[[ E E E E E E E E ]
[ F F F F X X X X ]
[ G G G X X X X X ]
[ H H H H X X X X ]]
[[ I I I X X ]
[ J J J X X ]
[ K K K K K ]
[ L L L X X ]]
```
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
B C C C C C C C D D D D E E E E
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```

View File

@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 16 lora_r: 16

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 16 lora_r: 16

View File

@@ -22,7 +22,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 64 lora_r: 64

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 512 sequence_len: 512
max_packed_sequence_len: max_packed_sequence_len:
lora_r: lora_r:

View File

@@ -18,7 +18,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: sample_packing:
lora_r: 8 lora_r: 8

72
examples/llama-2/ia3.yml Normal file
View File

@@ -0,0 +1,72 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sequence_len: 4096
sample_packing: true sample_packing: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,8 +16,8 @@ val_set_size: 0.01
output_dir: ./out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: sample_packing: true
pad_to_sequence_len: pad_to_sequence_len: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:
@@ -30,7 +30,7 @@ micro_batch_size: 2
num_epochs: 3 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.000005
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false

View File

@@ -19,8 +19,8 @@ adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 8192 sequence_len: 8192
sample_packing: True sample_packing: true
pad_to_sequence_len: True pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -43,7 +43,7 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 2
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: lora_r:

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -22,7 +22,7 @@ sample_packing: true
pad_to_sequence_len: pad_to_sequence_len:
adapter: adapter:
lora_model_dir: peft_model_dir:
lora_r: lora_r:
lora_alpha: lora_alpha:
lora_dropout: lora_dropout:

View File

@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len: pad_to_sequence_len:
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
lora_r: 64 lora_r: 64
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -13,7 +13,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 64 lora_r: 64

View File

@@ -7,7 +7,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 512 sequence_len: 512
lora_r: 16 lora_r: 16
lora_alpha: 32 lora_alpha: 32
@@ -28,8 +28,8 @@ num_epochs: 3
learning_rate: 0.00001 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: True bf16: true
tf32: True tf32: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -8,7 +8,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -20,7 +20,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 8192 sequence_len: 8192
max_packed_sequence_len: max_packed_sequence_len:

BIN
image/sticker_fixed.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 370 KiB

View File

@@ -16,7 +16,7 @@ flash-attn>=2.3.0
sentencepiece sentencepiece
wandb wandb
einops einops
xformers xformers>=0.0.22
optimum optimum
hf_transfer hf_transfer
colorama colorama

View File

@@ -21,6 +21,14 @@ def parse_requirements():
): ):
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
return _install_requires, _dependency_links return _install_requires, _dependency_links
@@ -38,7 +46,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn>=2.2.1", "flash-attn>=2.3.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",

View File

@@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file # load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()

View File

@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True

View File

@@ -1,10 +1,12 @@
""" """
CLI to run training on a model CLI to run training on a model
""" """
import logging
from pathlib import Path from pathlib import Path
import fire import fire
import transformers import transformers
from colorama import Fore
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -14,8 +16,11 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only: if parsed_cli_args.prepare_ds_only:

View File

@@ -0,0 +1,5 @@
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset, Sequence, Value from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -42,15 +42,11 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100 map_kwargs["batch_size"] = 100
return ( return dataset.map(
dataset.map( self.prompt_tokenizer.tokenize_prompt,
self.prompt_tokenizer.tokenize_prompt, num_proc=num_proc,
num_proc=num_proc, remove_columns=features,
remove_columns=features, **map_kwargs,
**map_kwargs,
)
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
) )

View File

@@ -116,6 +116,8 @@ def flashattn_forward(
attention_mask: [bsz, q_len] attention_mask: [bsz, q_len]
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"): if not hasattr(self, "pretraining_tp"):
@@ -151,6 +153,13 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
query_states = query_states.view( query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2) ).transpose(1, 2)
@@ -309,6 +318,10 @@ def flashattn_forward(
else: else:
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value return attn_output, None, past_key_value
@@ -502,6 +515,7 @@ def llama_model_forward(
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
@@ -559,6 +573,10 @@ def llama_model_forward(
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache: if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

View File

@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
transformers.models.llama.modeling_llama.LlamaModel.post_init
)

View File

@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
) )
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
sliding_window, sliding_window,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask return attention_mask
def flashattn_forward( def flashattn_forward(
self, self: OriginalMistralAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None: if past_key_value is not None:
# reuse k, v, self_attention # Activate slicing cache only if the config has a value `sliding_windows` attribute
key_states = torch.cat([past_key_value[0], key_states], dim=2) if (
value_states = torch.cat([past_key_value[1], value_states], dim=2) hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
@@ -120,7 +210,13 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv,
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
@@ -157,6 +254,7 @@ def flashattn_forward(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
else: else:
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)

View File

@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.mistral.modeling_mistral
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
transformers.models.mistral.modeling_mistral.MistralModel.post_init
)

View File

@@ -1,6 +1,6 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class""" """Module for Alpaca prompt strategy classes"""
from typing import Tuple from typing import Any, Dict, Optional, Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaPrompter(prompt_style),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
) )
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strat = ShareGPTPromptTokenizingStrategy( return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2( ShareGPTPrompterV2(
conversation=conversation, conversation=conversation,
role_key_model=field_model, role_key_model=field_model,
@@ -34,9 +34,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
if ds_cfg and ds_cfg["skip"]:
strat.skip_invalid = True
return strat
def load_role(tokenizer, cfg): def load_role(tokenizer, cfg):
@@ -57,38 +54,13 @@ def load_guanaco(tokenizer, cfg):
) )
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return NousShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
""" """
basic sharegpt strategy used by nous/ldj for input/output keyed data basic sharegpt strategy to grab conversations from the sample row
""" """
def get_conversation_thread(self): def get_conversation_thread(self, prompt):
return "conversation" return prompt["conversations"]
def map_conversation_thread(self, conversation):
turns = []
for turn in conversation:
turns.append({"from": "human", "value": turn["input"]})
turns.append({"from": "gpt", "value": turn["output"]})
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -96,11 +68,10 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
def map_conversation_thread(self, conversation): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [ turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
{"from": turn["role"], "value": turn["value"]} for turn in conversation
]
return turns return turns
@@ -109,11 +80,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
sharegpt strategy that remaps oasst data to sharegpt format sharegpt strategy that remaps oasst data to sharegpt format
""" """
def map_conversation_thread(self, conversation): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"} role_map = {"prompter": "human", "assistant": "gpt"}
turns = [ turns = [
{"from": role_map[turn["role"]], "value": turn["text"]} {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
for turn in conversation
] ]
return turns return turns

View File

@@ -2,9 +2,7 @@
import abc import abc
import copy import copy
import functools
import logging import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
@@ -58,26 +56,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self): def supports_batched(self):
return False return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize( def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding: ) -> BatchEncoding:
@@ -352,109 +330,99 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for ShareGPT prompts. Tokenizing strategy for ShareGPT prompts.
""" """
_skip_invalid = False def get_conversation_thread(self, prompt):
return prompt["conversations"]
@property
def supports_batched(self):
return True
@property
def skip_invalid(self):
return self._skip_invalid
@skip_invalid.setter
def skip_invalid(self, value):
self._skip_invalid = value
def get_conversation_thread(self):
return "conversations"
def map_conversation_thread(self, conversation):
return conversation
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
tokenized_res = defaultdict(lambda: []) result, current_len = tokenize_prompt_default()
conv_field = self.get_conversation_thread() conversation: Conversation = (
for prmpt in prompt[conv_field]: self.prompter._conversation.copy() # pylint: disable=protected-access
result, current_len = tokenize_prompt_default() )
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
err_msg = f"user turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
err_msg = f"assistant turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
err_msg = f"unhandled role: {part[0]}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
continue
# pylint: disable=duplicate-code # support for custom roles from the dataset, only useful for vicuna style prompts/roles
result, current_len = parse_tokenized_to_result( role_remap = []
result, if (
current_len, conversation.name == "vicuna_v1.1"
res, and "roles" in prompt
labels, and len(prompt["roles"]) >= 2
pad_token_id=self.tokenizer.pad_token_id, ):
) role_remap = [
for key, val in sorted(result.items(), key=lambda x: x[0]): {"from": conversation.roles[0], "to": prompt["roles"][0]},
tokenized_res[key].append(val) {"from": conversation.roles[1], "to": prompt["roles"][1]},
except (KeyError, AssertionError, IndexError) as err: ]
raise InvalidDataException(str(err)) from err
except ValueError as err: try:
LOG.warning("skipping prompt: %s", str(err)) for _, part in enumerate(
return tokenized_res self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
role = (
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip(): if not prompt.strip():

View File

@@ -274,9 +274,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
raise err raise err
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for _, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
if role != conv.roles[j % 2]: if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -58,9 +58,7 @@ def train(
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
if ( if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
) or cfg.resume_from_checkpoint is True:
possible_checkpoints = [ possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
] ]
@@ -73,9 +71,7 @@ def train(
LOG.info( LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
) )
resume_from_checkpoint = ( resume_from_checkpoint = cfg.resume_from_checkpoint
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
)
trainer = setup_trainer( trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps

View File

@@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
return control return control
return LogPredictionCallback return LogPredictionCallback
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control

View File

@@ -121,6 +121,18 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.adapter is not None:
for key in list(cfg.keys()):
if key.startswith("lora_"):
new_key = key.replace("lora_", "peft_")
LOG.warning(
PendingDeprecationWarning(
f"{key} soon to be deprecated. please use {new_key}"
)
)
cfg[new_key] = cfg[key]
del cfg[key]
def validate_config(cfg): def validate_config(cfg):
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -190,7 +202,10 @@ def validate_config(cfg):
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
if not cfg.load_in_8bit and cfg.adapter == "ia3":
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"): if cfg.adapter not in ("lora", "qlora"):

View File

@@ -16,6 +16,7 @@ from datasets import (
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
@@ -44,7 +45,6 @@ from axolotl.utils.trainer import (
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str: def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -158,23 +158,21 @@ def load_tokenized_prepared_datasets(
token=use_auth_token, token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ValueError): except (FileNotFoundError, ConnectionError):
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(d.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
if not d.type: # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_from_disk(d.path) ds = load_dataset(
else: d.path,
ds = load_dataset( name=d.name,
d.path, data_files=d.data_files,
name=d.name, streaming=False,
data_files=d.data_files, split=None,
streaming=False, )
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds_type = "json" ds_type = "json"
if d.ds_type: if d.ds_type:
@@ -359,7 +357,7 @@ def load_tokenized_prepared_datasets(
if len(datasets) > 1: if len(datasets) > 1:
LOG.info("shuffle merged datasets") LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed) dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0 and cfg.dataset_prepared_path: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:

View File

@@ -180,6 +180,26 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_llama_derived_model and cfg.xpos_rope: if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope, replace_llama_rope_with_xpos_rope,
@@ -382,25 +402,25 @@ def load_model(
if model_config.model_type == "btlm": if model_config.model_type == "btlm":
# don't upcast lm_head for btlm # don't upcast lm_head for btlm
continue continue
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp require_peft: bool = False
if (cfg.adapter == "lora" and load_in_8bit) or ( if cfg.adapter in ["lora", "qlora", "ia3"]:
cfg.adapter == "qlora" and cfg.load_in_4bit require_peft = True
):
if require_peft:
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing model, use_gradient_checkpointing=cfg.gradient_checkpointing
) )
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model): if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
@@ -409,7 +429,7 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(cfg.torch_dtype) module.to(cfg.torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, peft_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -440,7 +460,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device) log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, peft_config
def load_adapter(model, cfg, adapter, inference=False): def load_adapter(model, cfg, adapter, inference=False):
@@ -450,6 +470,8 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
if adapter == "ia3":
return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -468,11 +490,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.lora_model_dir: if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter") LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.peft_model_dir,
torch_dtype=torch.float16, torch_dtype=torch.float16,
) )
else: else:
@@ -485,16 +507,20 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() peft_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls) or "Linear" in module.__class__.__name__: if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".") names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) peft_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit if "lm_head" in peft_module_names: # needed for 16-bit
lora_module_names.remove("lm_head") peft_module_names.remove("lm_head")
return list(lora_module_names) return list(peft_module_names)
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
@@ -502,34 +528,68 @@ def load_lora(model, cfg, inference=False):
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or []) peft_target_modules = list(cfg.peft_target_modules or [])
if cfg.lora_target_linear: if cfg.peft_target_linear:
linear_names = find_all_linear_names(model) linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names)) peft_target_modules = list(set(peft_target_modules + linear_names))
lora_config = LoraConfig( peft_config = LoraConfig(
r=cfg.lora_r, r=cfg.peft_r,
lora_alpha=cfg.lora_alpha, lora_alpha=cfg.peft_alpha,
target_modules=lora_target_modules, target_modules=peft_target_modules,
lora_dropout=cfg.lora_dropout, lora_dropout=cfg.peft_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out, fan_in_fan_out=cfg.peft_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
bias="none", bias="none",
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.lora_model_dir: if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - LoRA") LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.peft_model_dir,
is_trainable=(not inference), is_trainable=(not inference),
) )
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, peft_config)
model.print_trainable_parameters() model.print_trainable_parameters()
return model, lora_config return model, peft_config
def load_ia3(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import IA3Config, PeftModel, get_peft_model
peft_config_kwargs = {}
if cfg.peft_init_ia3_weights is not None:
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
if cfg.peft_fan_in_fan_out is not None:
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
peft_config = IA3Config(
target_modules=cfg.peft_target_modules,
feedforward_modules=cfg.peft_feedforward_modules,
modules_to_save=cfg.peft_modules_to_save,
task_type="CAUSAL_LM",
**peft_config_kwargs,
)
if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - IA3")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model, peft_config

View File

@@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -422,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
) )
# Phi doesn't want the attention_mask feature when training # Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__: if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
@@ -775,6 +778,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg)) trainer.add_callback(LogPredictionCallback(cfg))
if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
if cfg.do_bench_eval: if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

View File

@@ -24,6 +24,10 @@ class TestLoraLlama(unittest.TestCase):
""" """
def test_lora(self): def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
@@ -66,6 +70,101 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ia3_peft(self):
"""
support for IA3 peft
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "ia3",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
"peft_feedforward_modules": ["down_proj"],
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self): def test_lora_packing(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,48 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)
assert cfg.peft_r == 128
assert cfg.peft_alpha == 64

View File

@@ -90,6 +90,73 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
strat.tokenize_prompt(conversation) strat.tokenize_prompt(conversation)
assert "assistant turn has empty text" in self._caplog.records[1].message assert "assistant turn has empty text" in self._caplog.records[1].message
def test_sharegpt_warnings_turns(self):
conversation = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
]
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
strat.tokenize_prompt(conversation)
assert (
"Role did not alternate between turns (gpt and human)"
in self._caplog.records[0].message
)
def test_sharegpt_changes_roles(self):
conversation = {
"roles": ["USER", "CHARACTER"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
def test_sharegpt_assistant_label_ignore(self):
conversation = {
"roles": ["user", "assistant"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self): def test_no_sys_prompt(self):
""" """
tests the interface between the user and assistant parts tests the interface between the user and assistant parts