Compare commits


18 Commits

Author SHA1 Message Date
Wing Lian
3ce9b0760b fix the lora yaml for l3 2024-04-19 07:28:07 -04:00
Wing Lian
c10563c444 fix broken linting (#1541)
* chore: lint

* include examples in yaml check

* mistral decided to gate their models...

* more mistral models that were gated
2024-04-19 01:03:04 -04:00
Monk (looking for PhD Fall’24)
37c037c69d Adding Llama-3 qlora (#1536)
* Create qlora.yml

* Update qlora.yml
2024-04-18 21:27:32 +02:00
Wing Lian
15f7910d33 llama-3 examples (#1537) 2024-04-18 14:28:03 -04:00
NanoCode012
d28ba2e405 feat(doc): Add example for pad_token (#1535) 2024-04-19 02:20:20 +09:00
Atlas
0eadfc8c86 Create mixtral_22.yml (#1514) [skip ci]
Code sourced from here:

https://twitter.com/mattshumer_/status/1778135774887567712
2024-04-17 01:16:00 -04:00
Atlas
bcaa92325d Update Readme to include support for Mixtral8X22B (#1518) [skip ci] 2024-04-17 01:15:30 -04:00
YTING
7d9bafcb88 Update README.md (#1521) [skip ci] 2024-04-17 01:15:05 -04:00
Wing Lian
e07dcb288c add docs around pre-processing (#1529) 2024-04-16 19:45:46 -04:00
Wing Lian
6319da1f9b Unsloth gradient checkpointing offload (#1528)
* unsloth gradient checkpointing

* fix validation too

* fixes to make it work with mistral

* monkeypatch the checkpoint fn earlier
2024-04-16 14:53:57 -04:00
Wing Lian
132eb740f0 DBRX Model Support (#1462)
* wip for dbrx finetuning

* add fastcore for parallel loading of sharded weights

* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback

* update to use v2 of the converted model

* more fixes for dbrx loras

* make sure to enable fsdp activation checkpointing

* fix support for 8bit loras too for dbrx

* apply z3 leaf moe fix for DBRX with deepspeed

* don't raise value error since child module searches could fail and be ok

* revert a previous change to fix fsdp

* update mistral/mistral qlora+fsdp yamls

* fix qlora+fsdp quant storage type

* more edge cases for qlora-fsdp

* fixes for fsdp+qlora w optimizer in 8bit

* add bigstral z3 config and make sure to use full_state_dict for fsdp
2024-04-12 09:02:36 -04:00
Thomas Capelle
5ed29393e3 Update SaveAxolotlConfigtoWandBCallback to use artifact instead of save (#1483)
* deprecated wandb.save

* also use wandb.save for axolotl yaml

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 18:58:38 -04:00
Wing Lian
da9b1a3196 use locale agnostic seperator to make large nums easier to read (#1503) 2024-04-09 17:28:43 -04:00
DavidFarago
057fa44191 WIP: Support table logging for mlflow, too (#1506)
* WIP: Support table logging for mlflow, too

Create a `LogPredictionCallback` for both "wandb" and "mlflow" if
specified.

In `log_prediction_callback_factory`, create a generic table and make it
specific only if the newly added `logger` argument is set to "wandb"
resp. "mlflow".

See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505

* chore: lint

* add additional clause for mlflow as it's optional

* Fix circular imports

---------

Co-authored-by: Dave Farago <dfarago@innoopract.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 17:28:27 -04:00
Scott Fleming
8fa0785f74 Correctly handle splits for datasets.arrow_dataset.Dataset objects (#1504)
* Correctly handle splits for datasets.arrow_dataset.Dataset objects

The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from local path that always checks if a split is in the dataset. The problem is, if the dataset is loaded using `load_from_disk` and it is an Arrow-based dataset, *there is no* split information. Instead what happens is, by calling `split in ds`, it presumably searches through all the rows and columns of the arrow dataset object to find e.g., 'train' assuming `split == 'train'`. This causes the program to hang.

See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context.

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-04-09 16:40:26 -04:00
Wing Lian
4313b1a6a0 Print versions (#1496)
* print out dependency versions for easier debugging

* improve readability
2024-04-09 11:05:15 -04:00
Maziyar Panahi
7f17eff81a Fix the wrong adapter in qwen2-moe-qlora example (#1501) [skip ci]
It should be `qlora` instead of `lora`
2024-04-09 10:57:24 -04:00
Wing Lian
ff01c45127 add field to sft dataset pydantic for completion support (#1497) 2024-04-08 21:37:54 -04:00
45 changed files with 1435 additions and 99 deletions


@@ -7,6 +7,7 @@ on:
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.md"
- "examples/**/*.y[a]?ml"
workflow_dispatch:
jobs:


@@ -44,6 +44,7 @@ Features:
- Advanced Topics
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -81,6 +82,7 @@ Features:
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
@@ -425,7 +427,7 @@ deepspeed: deepspeed_configs/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
##### FSDP


@@ -1,4 +1,6 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {


@@ -1,4 +1,6 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_param": {


@@ -412,6 +412,7 @@ special_tokens:
# bos_token: "<s>"
# eos_token: "</s>"
# unk_token: "<unk>"
# pad_token: "[PAD]"
# Add extra tokens.
tokens:


@@ -0,0 +1,35 @@
---
title: Dataset Preprocessing
description: How datasets are processed
---
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](../dataset-formats/) and prompt strategies to:

- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
- shuffle and merge multiple datasets together if using more than one

The processing of the datasets can happen in one of two ways:

1. Before kicking off training, by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
2. When training is started

What are the benefits of pre-processing? When training interactively or running sweeps
(i.e. you are restarting the trainer often), processing the datasets each time can be frustratingly
slow. Pre-processing caches the tokenized/formatted datasets under a hash of the dependent
training parameters, so the trainer can intelligently pull from its cache when possible.

The path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example
YAMLs, as this is the more robust default: it prevents unexpectedly reusing stale cached data.
If `dataset_prepared_path:` is left empty, the processed dataset is cached during training under
the default path `./last_run_prepared/`, but anything already cached there is ignored. By explicitly
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will reuse whatever pre-processed
data is in the cache.

What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, the calculated hash for
the pre-processed dataset does not change. If you have `dataset_prepared_path: ...` set and then
change your prompt templating logic, the trainer may not pick up the changes and you will be
training on the old prompts.
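
As a concrete illustration, here is a minimal config fragment that opts into reusing the cache. The dataset path and type are just the example values used elsewhere in this repo:

```yaml
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
# explicitly point at the cache so subsequent runs reuse the
# pre-processed data instead of ignoring it
dataset_prepared_path: ./last_run_prepared
```

Run `python -m axolotl.cli.preprocess /path/to/your.yaml --debug` once, and later training runs with the same hash-relevant settings will load from `./last_run_prepared/`.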


@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true


@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

examples/dbrx/README.md

@@ -0,0 +1,26 @@
# DBRX MoE
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj`, and `layer` Linear layers are trainable.
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the experts are fused as an `nn.Parameter` rather than an `nn.Linear` layer. However, the implementation
is still a bit buggy, and attempting to train a LoRA adapter over the `w1`, `w2`, and `v1` layers
results in the trainer hanging.
### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
- 16-bit LoRA w/ FSDP
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
### Deepspeed
WIP


@@ -0,0 +1,56 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
unfrozen_parameters:
- transformer.blocks.[0-7].
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json


@@ -65,12 +65,14 @@ deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:


@@ -0,0 +1,13 @@
# Llama-3
https://llama.meta.com/llama3/
[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [Full Fine Tune](./fft-8b.yaml)
- Single GPU @ 48GB VRAM
- [LoRA](./lora-8b.yml)
- Single GPU @ 11GB VRAM
[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [QLORA+FSDP](./qlora-fsdp-70b.yaml)
- Dual GPU @ 21GB VRAM


@@ -0,0 +1,58 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>


@@ -0,0 +1,67 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>


@@ -0,0 +1,80 @@
base_model: casperhansen/llama-3-70b-fp16
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out/qlora-llama3-70b
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>


@@ -0,0 +1,67 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: aaditya/alpaca_subset_1
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"


@@ -0,0 +1,63 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"


@@ -0,0 +1,82 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:


@@ -0,0 +1,81 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:


@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
@@ -47,7 +47,7 @@ train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
@@ -69,6 +69,17 @@ debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_sharding_strategy: FULL_SHARD
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:


@@ -0,0 +1,61 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./out
sequence_len: 8000
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"


@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
sample_packing: false
pad_to_sequence_len: false
adapter: lora
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16


@@ -41,3 +41,4 @@ gcsfs
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
zstandard==0.22.0
fastcore


@@ -24,6 +24,7 @@ from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -62,6 +63,20 @@ def print_axolotl_text_art(suffix=None):
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
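
The aligned printout above relies on a nested f-string width. A standalone sketch of the same formatting pattern, with an illustrative package list and hard-coded versions rather than the real `_is_package_available` lookup:

```python
def format_versions(versions: dict) -> list:
    """Right-align package names to a common width so the colons line up."""
    max_len = max(len(pkg) for pkg in versions)
    # " >{max_len}" right-pads the name; " <15" left-aligns the version
    return [f"{pkg: >{max_len}}: {ver: <15}" for pkg, ver in versions.items()]

rows = format_versions({"torch": "2.2.0", "transformers": "4.40.0"})
for row in rows:
    print(row)
```

Because every name is padded to the longest package name, the colon lands in the same column on every row, which is what makes the version block readable in the training logs.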
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file


@@ -36,6 +36,7 @@ from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -71,10 +72,6 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -921,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
):
callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
@@ -943,7 +936,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))


@@ -516,24 +516,18 @@ def mistral_model_forward(
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
layer_outputs = (
self._gradient_checkpointing_func( # pylint: disable=protected-access
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
)
else:
layer_outputs = decoder_layer(


@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import Dataset
from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
# we wait until the last possible moment to set up the Accelerator
Accelerator()
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True


@@ -0,0 +1,8 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
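
The new `axolotl.utils` helper uses the standard `find_spec` idiom for detecting optional dependencies; a generic sketch of the same pattern:

```python
import importlib.util

def is_available(module_name: str) -> bool:
    # find_spec returns None when the top-level module cannot be found,
    # and does so without importing it (cheap, no side effects)
    return importlib.util.find_spec(module_name) is not None

print(is_available("json"))  # stdlib module, always present
```

This is preferable to a bare `try: import mlflow` at module scope, since it defers the (potentially slow) import until the dependency is actually used.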


@@ -6,7 +6,7 @@ import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List
import evaluate
import numpy as np
@@ -27,7 +27,9 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
@@ -540,7 +542,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -597,15 +599,13 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
return ranges
def log_table_from_dataloader(name: str, table_dataloader):
table = wandb.Table( # type: ignore[attr-defined]
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
table_data: Dict[str, List[Any]] = {
"id": [],
"Prompt": [],
"Correct Completion": [],
"Predicted Completion (model.generate)": [],
"Predicted Completion (trainer.prediction_step)": [],
}
row_index = 0
for batch in tqdm(table_dataloader):
@@ -709,16 +709,29 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
table_data["id"].append(row_index)
table_data["Prompt"].append(prompt_text)
table_data["Correct Completion"].append(completion_text)
table_data["Predicted Completion (model.generate)"].append(
prediction_text
)
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow" and is_mlflow_available():
import mlflow
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
tracking_uri = AxolotlInputConfig(
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
@@ -748,6 +761,11 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."

View File

@@ -98,6 +98,7 @@ class SFTDataset(BaseModel):
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -258,6 +259,7 @@ class ModelInputConfig(BaseModel):
base_model: str
base_model_config: Optional[str] = None
cls_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
@@ -477,6 +479,7 @@ class AxolotlInputConfig(
eval_causal_lm_metrics: Optional[List[str]] = None
do_bench_eval: Optional[bool] = None
bench_dataset: Optional[str] = None
bench_split: Optional[str] = None
metric_for_best_model: Optional[str] = None
greater_is_better: Optional[bool] = None
@@ -492,7 +495,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[bool] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
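
The widened `gradient_checkpointing` type accepts either a bool or the literal string `"unsloth"`. Outside pydantic, the same acceptance rule can be phrased with `typing` introspection (the `is_valid` helper is ours, for illustration only):

```python
from typing import Literal, Union, get_args, get_origin

# Mirror of the config field's type: bool, or the exact string "unsloth".
GradientCheckpointing = Union[Literal["unsloth"], bool]

def is_valid(value) -> bool:
    for arm in get_args(GradientCheckpointing):
        if get_origin(arm) is Literal:
            if value in get_args(arm):  # one of the allowed literal values
                return True
        elif isinstance(value, arm):
            return True
    return False

print(is_valid(True), is_valid("unsloth"), is_valid("always"))  # True True False
```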
@@ -970,9 +975,16 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_fsdp_w_8bit_optimizer(cls, data):
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
def check_fsdp_offload_w_8bit_optimizer(cls, data):
if (
data.get("fsdp")
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
):
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
return data
@model_validator(mode="before")

View File

@@ -379,14 +379,15 @@ def load_tokenized_prepared_datasets(
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
if isinstance(ds, DatasetDict):
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
else:
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
# support for using a subset of the data
if config_dataset.shards:
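
The fix above first confirms the loaded object is a `DatasetDict` before probing for splits, so a load that already returned a single `Dataset` passes through untouched. The selection order, mimicked with a plain dict standing in for a `DatasetDict`:

```python
def resolve_split(ds, requested, default="train"):
    # A plain dict stands in for datasets.DatasetDict; anything else is
    # treated as an already-materialized single split.
    if isinstance(ds, dict):
        if requested and requested in ds:
            return ds[requested]       # explicit 'split:' wins
        if default in ds:
            return ds[default]         # fall back to the default split
        raise ValueError(f"no {default} split found; specify one with 'split:'")
    return ds

dataset_dict = {"train": ["a", "b"], "validation": ["c"]}
print(resolve_split(dataset_dict, requested=None))          # ['a', 'b']
print(resolve_split(dataset_dict, requested="validation"))  # ['c']
print(resolve_split(["x"], requested=None))                 # ['x'] (not a dict)
```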

View File

@@ -4,27 +4,25 @@ utility helpers for distributed checks
import os
import pickle # nosec
from contextlib import contextmanager
from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import Accelerator
from accelerate import PartialState
accelerate = None # pylint: disable=invalid-name
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
distributed_state = None # pylint: disable=invalid-name
def is_distributed():
"""
Check if distributed training is initialized.
"""
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
global distributed_state # pylint: disable=global-statement
if not distributed_state:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
return distributed_state.use_distributed and distributed_state.initialized
def barrier():
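
The `PartialState` above is now created lazily with a timeout read from `AXOLOTL_NCCL_TIMEOUT` (in seconds, defaulting to 1800). The environment parsing on its own:

```python
import os
from datetime import timedelta

def nccl_timeout() -> timedelta:
    # Same default as above: 1800 s (30 minutes) when the variable is unset.
    seconds = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
    return timedelta(seconds=seconds)

print(nccl_timeout())  # 0:30:00 with the variable unset

os.environ["AXOLOTL_NCCL_TIMEOUT"] = "7200"
print(nccl_timeout())  # 2:00:00
```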

View File

@@ -0,0 +1,13 @@
"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
*args,
)

View File

@@ -0,0 +1,52 @@
"""Unsloth checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non-blocking calls.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)
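
The autograd trick above generalizes: save a (possibly offloaded) copy of the input in `forward`, then in `backward` restore it, re-run the wrapped function under `torch.enable_grad()`, and backprop through the recomputation. A CPU-only sketch of the same pattern, minus the CUDA offload and the AMP decorators, so the mechanics are visible in isolation:

```python
import torch

class RecomputeInBackward(torch.autograd.Function):
    """Sketch of unsloth-style checkpointing without the CPU<->GPU offload."""

    @staticmethod
    def forward(ctx, fn, x):
        # Save a detached copy of the input (the real version sends it to CPU
        # with non_blocking=True) and run fn without building a graph.
        ctx.fn = fn
        ctx.save_for_backward(x.detach())
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Restore the input, recompute fn with grad enabled, and backprop.
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.fn(x)
        torch.autograd.backward(out, grad_out)
        return None, x.grad  # no grad for the fn argument itself

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = RecomputeInBackward.apply(lambda t: (t * t).sum(), x)
y.backward()
print(x.grad)  # tensor([2., 4.]), i.e. d/dx sum(x^2) = 2x
```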

View File

@@ -0,0 +1,259 @@
"""
module to handle loading model on cpu/meta device for FSDP
"""
import os
import time
from typing import List, Optional, Type, Union
import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
def _replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of module names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
_replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
to_cpu: bool = False,
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
"""
if not skip_names:
skip_names = []
def place_on_device(value):
if to_meta:
device = "meta"
elif to_cpu:
device = "cpu"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if to_meta:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif to_cpu:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def n_loading_workers(quant_method: str, param_count: float):
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
model_params_b = 70
right = int(
(4 if quant_method == "hqq" else 8)
* (devprops.total_memory / 1e9 / 40)
* (model_params_b / (param_count / 1e9))
)
return min(left, right)
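
`n_loading_workers` caps the worker count by two bounds: CPU headroom (cores per GPU) and a GPU-memory heuristic scaled against a 70B-parameter / 40 GB baseline. A stand-in version with the hardware numbers passed in (instead of queried from `torch.cuda` / `os.cpu_count()`), using made-up values:

```python
def n_loading_workers(quant_method, param_count, cpu_cores, n_gpus, gpu_mem_gb):
    # Hypothetical mirror of the function above with hardware injected.
    left = int(cpu_cores / n_gpus)           # CPU headroom per GPU
    model_params_b = 70                      # baseline: 70B params on a 40 GB card
    right = int(
        (4 if quant_method == "hqq" else 8)  # hqq workers cost roughly 2x memory
        * (gpu_mem_gb / 40)                  # scale by actual GPU memory
        * (model_params_b / (param_count / 1e9))  # smaller model -> more workers
    )
    return min(left, right)

# 8B model, bnb quantization, 64 cores, 8x 80 GB GPUs:
print(n_loading_workers("bnb", 8e9, cpu_cores=64, n_gpus=8, gpu_mem_gb=80))  # 8
```

Here the CPU bound (64 / 8 = 8) is the binding constraint; the memory bound comes out at int(8 * 2 * 8.75) = 140.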
def load_sharded_model(
model_name,
model_config,
cfg,
torch_dtype=torch.bfloat16,
low_memory=True,
):
if (low_memory and cfg.local_rank == 0) or not low_memory:
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_cache=False,
torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
dtype = torch_dtype if not cfg.float32 else None
model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
else:
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code,
)
return model
def load_sharded_model_quant(
model_name,
model_config,
cfg,
compute_dtype=torch.bfloat16,
quant_storage=torch.float32,
low_memory=True,
verbose=False,
loading_workers=2,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
model.transformer = _replace_linear(
model.transformer,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
else:
# this is the more common case with HF transformers
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
quant_method = "bnb"
param_count = sum((p.numel() for n, p in model.named_parameters()))
n_workers = (
n_loading_workers(quant_method, param_count)
if loading_workers == -1
else loading_workers
)
if cfg.local_rank == 0 and verbose:
print(f"Using n_workers: {n_workers} for loading")
start = time.time()
for filename in tqdm(
files,
desc="Loading & Quantizing Model Shards",
disable=cfg.local_rank != 0,
position=0,
):
weights = safetensors.torch.load_file(filename)
parallel(
load_and_quantize_parallel,
iter(weights.items()),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=quant_storage,
device=cfg.local_rank,
skip_names=[],
to_cpu=(low_memory and cfg.local_rank == 0),
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()
return model

View File

@@ -11,6 +11,7 @@ import addict
import bitsandbytes as bnb
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from bitsandbytes.nn import Params4bit
from peft import (
@@ -44,11 +45,37 @@ from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl")
# copied from accelerator.FullyShardedDataParallelPlugin
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name.
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
if len(modules_children) == 0:
return None
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
return None
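
`get_module_class_from_name` is a plain depth-first search over `module.children()` that returns the first class whose name matches; that is what lets the z3 leaf-module fix look up `MixtralSparseMoeBlock` or `DbrxFFN` without importing either model's module. The same search (slightly condensed) on a toy tree of stand-in classes, so no torch is needed:

```python
class Node:
    """Minimal stand-in for torch.nn.Module's .children() interface."""
    def __init__(self, *children):
        self._children = list(children)
    def children(self):
        return self._children

class Linear(Node): ...
class MoeBlock(Node): ...

def get_module_class_from_name(module, name):
    # Depth-first: match the node itself first, then recurse into children.
    if module.__class__.__name__ == name:
        return module.__class__
    for child in module.children():
        found = get_module_class_from_name(child, name)
        if found is not None:
            return found
    return None

tree = Node(Node(Linear()), MoeBlock(Linear()))
print(get_module_class_from_name(tree, "MoeBlock"))  # the MoeBlock class
print(get_module_class_from_name(tree, "DbrxFFN"))   # None (not in this tree)
```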
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
quant_config_exists = (
hasattr(model_config, "quantization_config")
@@ -285,6 +312,9 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -459,7 +489,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if not cfg.deepspeed:
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -470,6 +500,13 @@ def load_model(
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif cfg.adapter == "lora" and cfg.load_in_8bit:
bnb_config = {
"load_in_8bit": True,
}
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True
@@ -517,7 +554,31 @@ def load_model(
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
try:
skip_move_to_device = False
if (
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
) and not qlora_fsdp:
model = load_sharded_model(
base_model,
model_config,
cfg,
torch_dtype=cfg.torch_dtype,
)
skip_move_to_device = True
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
):
quant_storage = cfg.torch_dtype
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
)
skip_move_to_device = True
elif (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
@@ -597,6 +658,11 @@ def load_model(
**model_kwargs,
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
@@ -670,13 +736,17 @@ def load_model(
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -686,7 +756,8 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True
if qlora_fsdp:
if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
# make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
@@ -727,7 +798,7 @@ def load_model(
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp
and not skip_move_to_device
):
# TODO revalidate this conditional
model.to(f"cuda:{cfg.local_rank}")
@@ -883,7 +954,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
if (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir:
@@ -908,7 +984,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
elif cfg.fsdp and cfg.adapter == "qlora":
elif (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_peft_meta_for_training(model)
return model, lora_config

View File

@@ -198,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
if update:
cfg.total_num_tokens = total_num_tokens
@@ -212,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
main_process_only=True,
)
if update:
@@ -239,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.num_epochs
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
main_process_only=True,
)
else:
@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
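
The logging tweaks earlier in this file swap `{total_num_tokens}` for `{total_num_tokens:_}`: Python's underscore grouping format spec, which makes large token counts readable at a glance. For example:

```python
total_num_tokens = 123456789
print(f"total_num_tokens: {total_num_tokens:_}")  # total_num_tokens: 123_456_789

# Works for floats too, grouping only the integer part:
print(f"{1234567.89:_.2f}")  # 1_234_567.89
```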

View File

@@ -7,8 +7,6 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -21,7 +19,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("Skipping test due to timeout.")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
"""
Test case for Llama models using S2 Attn

View File

@@ -30,7 +30,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -74,7 +74,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -22,7 +22,7 @@ class TestModelPatches(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -33,7 +33,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -87,7 +87,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -141,7 +141,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"adapter": "lora",
@@ -198,7 +198,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"adapter": "lora",
@@ -255,7 +255,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,

View File

@@ -27,7 +27,9 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
# pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(

View File

@@ -43,7 +43,9 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_tokens(
[
AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),

View File

@@ -96,7 +96,9 @@ def fixture_multi_role_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(

View File

@@ -454,7 +454,9 @@ class OrpoTokenizationTest(unittest.TestCase):
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = LlamaTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(