Compare commits


1 Commit

Author     SHA1        Message                  Date
Wing Lian  3202f19f52  add save_only_model arg  2024-04-10 16:09:08 -04:00
39 changed files with 65 additions and 1351 deletions

View File

@@ -7,7 +7,6 @@ on:
       - 'requirements.txt'
       - '.github/workflows/*.yml'
       - "*.md"
-      - "examples/**/*.y[a]?ml"
   workflow_dispatch:
 jobs:

View File

@@ -44,7 +44,6 @@ Features:
 - Advanced Topics
   - [Multipack](./docs/multipack.qmd)
   - [RLHF & DPO](./docs/rlhf.qmd)
-  - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
   - [Common Errors](#common-errors-)
   - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
   - [Debugging Axolotl](#debugging-axolotl)
@@ -82,7 +81,6 @@ Features:
 | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
-| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
@@ -427,7 +425,7 @@ deepspeed: deepspeed_configs/zero1.json
 ```
 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```
 ##### FSDP

View File

@@ -1,6 +1,4 @@
 {
-  "zero_force_ds_cpu_optimizer": false,
-  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_optimizer": {

View File

@@ -1,6 +1,4 @@
 {
-  "zero_force_ds_cpu_optimizer": false,
-  "zero_allow_untested_optimizer": true,
   "zero_optimization": {
     "stage": 3,
     "offload_param": {

View File

@@ -412,7 +412,6 @@ special_tokens:
   # bos_token: "<s>"
   # eos_token: "</s>"
   # unk_token: "<unk>"
-  # pad_token: "[PAD]"
 
 # Add extra tokens.
 tokens:

View File

@@ -1,35 +0,0 @@
---
title: Dataset Preprocessing
description: How datasets are processed
---
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](../dataset-formats/) and prompt strategies to:

- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
- shuffle and merge multiple datasets together if using more than one

The processing of the datasets can happen one of two ways:

1. Before kicking off training, by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
2. When training is started

What are the benefits of pre-processing? When training interactively or for sweeps
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
slow. Pre-processing caches the tokenized/formatted datasets according to a hash of dependent
training parameters so that it can intelligently pull from its cache when possible.

The path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example
YAMLs, as this leads to a more robust setup that prevents unexpectedly reusing cached data.

If `dataset_prepared_path:` is left empty, the processed dataset is cached during training in a
default path of `./last_run_prepared/`, but anything already cached there is ignored. By explicitly
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
data is in the cache.

What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, the calculated hash for
the pre-processed dataset does not change. If you have `dataset_prepared_path: ...` set and change
your prompt templating logic, the trainer may not pick up the changes you made and you will be
training over the old prompt.
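A minimal, hedged sketch of the two pieces described above (the dataset path and config file name are only illustrative examples):

```yaml
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
# reuse the pre-processed cache on subsequent runs
dataset_prepared_path: ./last_run_prepared
```

```shell
# run pre-processing ahead of training
python -m axolotl.cli.preprocess your_config.yaml --debug
```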

View File

@@ -1,81 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

View File

@@ -1,81 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

View File

@@ -1,26 +0,0 @@
# DBRX MoE
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj`, and `layer` Linear layers are trainable.
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.
### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
- 16-bit LoRA w/ FSDP
  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
  - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
### Deepspeed
WIP
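To make the note on trainable layers concrete, here is a hedged sketch of the adapter targets, mirroring the example DBRX configs earlier in this comparison (the commented-out expert layers are the ones reported to hang the trainer):

```yaml
adapter: lora
lora_target_modules:
  - q_proj    # attention
  - k_proj    # attention
  - v_proj    # attention
  - out_proj  # attention
  - layer     # router
# - w1        # expert layers: training these currently hangs
# - w2
# - v1
```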

View File

@@ -1,56 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
unfrozen_parameters:
- transformer.blocks.[0-7].
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -65,14 +65,12 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
   - full_shard
-  - auto_wrap
 fsdp_config:
   fsdp_limit_all_gathers: true
   fsdp_sync_module_states: true
   fsdp_offload_params: true
   fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
-  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_state_dict_type: SHARDED_STATE_DICT
 special_tokens:

View File

@@ -1,13 +0,0 @@
# Llama-3
https://llama.meta.com/llama3/
[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [Full Fine Tune](./fft-8b.yaml)
- Single GPU @ 48GB VRAM
- [LoRA](./lora-8b.yml)
- Single GPU @ 11GB VRAM
[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [QLORA+FSDP](./qlora-fsdp-70b.yaml)
- Dual GPU @ 21GB VRAM

View File

@@ -1,58 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,67 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,80 +0,0 @@
base_model: casperhansen/llama-3-70b-fp16
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out/qlora-llama3-70b
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,67 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: aaditya/alpaca_subset_1
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,63 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"

View File

@@ -1,82 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -1,81 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -39,7 +39,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
 num_epochs: 1
-optimizer: adamw_torch
+optimizer: paged_adamw_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -47,7 +47,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
-tf32: true
+tf32: false
 gradient_checkpointing: true
 early_stopping_patience:
@@ -69,17 +69,6 @@ debug:
 weight_decay: 0.0
 fsdp:
   - full_shard
-  - auto_wrap
 fsdp_config:
-  fsdp_limit_all_gathers: true
-  fsdp_sync_module_states: true
-  fsdp_offload_params: true
-  fsdp_use_orig_params: false
-  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
-  fsdp_state_dict_type: FULL_STATE_DICT
-  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_sharding_strategy: FULL_SHARD
-  fsdp_forward_prefetch: false
-  fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:

View File

@@ -1,61 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./out
sequence_len: 8000
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"

View File

@@ -41,4 +41,3 @@ gcsfs
 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
 zstandard==0.22.0
-fastcore

View File

@@ -918,6 +918,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())
 
+        if self.cfg.use_wandb:
+            callbacks.append(
+                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
+            )
         if self.cfg.use_mlflow and is_mlflow_available():
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
@@ -1054,6 +1058,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
 
+        if self.cfg.save_only_model is not None:
+            training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
+
         if self.cfg.sample_packing_eff_est:
             training_arguments_kwargs[
                 "sample_packing_efficiency"

View File

@@ -516,18 +516,24 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
-            layer_outputs = (
-                self._gradient_checkpointing_func(  # pylint: disable=protected-access
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    past_key_value,
-                    output_attentions,
-                    None,
-                    cu_seqlens,
-                    max_seqlen,
-                )
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                None,
+                cu_seqlens,
+                max_seqlen,
             )
         else:
             layer_outputs = decoder_layer(

View File

@@ -9,7 +9,6 @@ from typing import Optional, Tuple, Union
 import torch
 import transformers.modelcard
-from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel
@@ -82,8 +81,6 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
-    # we wait unitl the last possible moment to setup Accelerator
-    Accelerator()
 
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True

View File

@@ -259,7 +259,6 @@ class ModelInputConfig(BaseModel):
     base_model: str
     base_model_config: Optional[str] = None
-    cls_model_config: Optional[str] = None
     tokenizer_config: Optional[str] = None
     tokenizer_use_fast: Optional[bool] = None
     tokenizer_legacy: Optional[bool] = None
@@ -356,6 +355,7 @@ class ModelOutputConfig(BaseModel):
     hub_model_id: Optional[str] = None
     hub_strategy: Optional[str] = None
     save_safetensors: Optional[bool] = None
+    save_only_model: Optional[bool] = None
 
 
 class MLFlowConfig(BaseModel):
@@ -479,7 +479,6 @@ class AxolotlInputConfig(
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
-    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None
@@ -495,9 +494,7 @@ class AxolotlInputConfig(
     # torch_dtype: Optional[torch.dtype]
-    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
-        default=False
-    )
+    gradient_checkpointing: Optional[bool] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
@@ -975,16 +972,9 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
-    def check_fsdp_offload_w_8bit_optimizer(cls, data):
-        if (
-            data.get("fsdp")
-            and "8bit" in data.get("optimizer", "")
-            and data.get("fsdp_config")
-            and data["fsdp_config"].get("fsdp_offload_params")
-        ):
-            raise ValueError(
-                f"FSDP Offload not compatible with {data.get('optimizer')}"
-            )
+    def check_fsdp_w_8bit_optimizer(cls, data):
+        if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
+            raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
         return data
 
     @model_validator(mode="before")
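As an illustration of what the reworked validator rejects, a hedged config fragment (any optimizer whose name contains "bnb" now fails validation when FSDP is enabled):

```yaml
fsdp:
  - full_shard
optimizer: adamw_bnb_8bit  # "bnb" optimizer + fsdp -> ValueError under the new check
```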

View File

@@ -4,25 +4,27 @@ utility helpers for distributed checks
 import os
 import pickle  # nosec
 from contextlib import contextmanager
-from datetime import timedelta
 
 import torch
 import torch.distributed as dist
-from accelerate import PartialState
+from accelerate import Accelerator
 
-distributed_state = None  # pylint: disable=invalid-name
+accelerate = None  # pylint: disable=invalid-name
+
+
+def load_accelerate():
+    global accelerate  # pylint: disable=global-statement
+    accelerate = Accelerator()
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global distributed_state  # pylint: disable=global-statement
-    if not distributed_state:
-        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
-        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
-    return distributed_state.use_distributed and distributed_state.initialized
+    global accelerate  # pylint: disable=global-statement
+    if not accelerate:
+        accelerate = Accelerator()
+    return dist.is_available() and dist.is_initialized()
 
 
 def barrier():

View File

@@ -1,13 +0,0 @@
"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
*args,
)

View File

@@ -1,52 +0,0 @@
"""Unsloth checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non blocking calls.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)

View File

@@ -1,259 +0,0 @@
"""
module to handle loading model on cpu/meta device for FSDP
"""
import os
import time
from typing import List, Optional, Type, Union
import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
def _replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
_replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
to_cpu: bool = False,
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
"""
if not skip_names:
skip_names = []
def place_on_device(value):
if to_meta:
device = "meta"
elif to_cpu:
device = "cpu"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if to_meta:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif to_cpu:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def n_loading_workers(quant_method: str, param_count: float):
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
model_params_b = 70
right = int(
(4 if quant_method == "hqq" else 8)
* (devprops.total_memory / 1e9 / 40)
* (model_params_b / (param_count / 1e9))
)
return min(left, right)
def load_sharded_model(
model_name,
model_config,
cfg,
torch_dtype=torch.bfloat16,
low_memory=True,
):
if (low_memory and cfg.local_rank == 0) or not low_memory:
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_cache=False,
torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
dtype = torch_dtype if not cfg.float32 else None
model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
else:
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code,
)
return model
def load_sharded_model_quant(
model_name,
model_config,
cfg,
compute_dtype=torch.bfloat16,
quant_storage=torch.float32,
low_memory=True,
verbose=False,
loading_workers=2,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
model.transformer = _replace_linear(
model.transformer,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
else:
# this is the more common case with HF transformers
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
quant_method = "bnb"
param_count = sum((p.numel() for n, p in model.named_parameters()))
n_workers = (
n_loading_workers(quant_method, param_count)
if loading_workers == -1
else loading_workers
)
if cfg.local_rank == 0 and verbose:
print(f"Using n_workers: {n_workers} for loading")
start = time.time()
for filename in tqdm(
files,
desc="Loading & Quantizing Model Shards",
disable=cfg.local_rank != 0,
position=0,
):
weights = safetensors.torch.load_file(filename)
parallel(
load_and_quantize_parallel,
iter(weights.items()),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=quant_storage,
device=cfg.local_rank,
skip_names=[],
to_cpu=(low_memory and cfg.local_rank == 0),
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()
return model

View File

@@ -11,7 +11,6 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
-import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (
@@ -45,37 +44,11 @@ from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
-from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
 LOG = logging.getLogger("axolotl")
 
 
-# copied from accelerator.FullyShardedDataParallelPlugin
-def get_module_class_from_name(module, name):
-    """
-    Gets a class from a module by its name.
-
-    Args:
-        module (`torch.nn.Module`): The module to get the class from.
-        name (`str`): The name of the class.
-    """
-    modules_children = list(module.children())
-    if module.__class__.__name__ == name:
-        return module.__class__
-    if len(modules_children) == 0:
-        return None
-    for child_module in modules_children:
-        module_class = get_module_class_from_name(child_module, name)
-        if module_class is not None:
-            return module_class
-    return None
-
-
 def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
     quant_config_exists = (
         hasattr(model_config, "quantization_config")
@@ -312,9 +285,6 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
-    if cfg.gradient_checkpointing == "unsloth":
-        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
-
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -489,7 +459,7 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
-        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
+        if not cfg.deepspeed:
             # for some reason, this causes the loss to be off by an order of magnitude
             # but deepspeed needs this still in bfloat16
             bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -500,13 +470,6 @@ def load_model(
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
-    elif cfg.adapter == "lora" and cfg.load_in_8bit:
-        bnb_config = {
-            "load_in_8bit": True,
-        }
-        model_kwargs["quantization_config"] = BitsAndBytesConfig(
-            **bnb_config,
-        )
 
     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True
@@ -554,31 +517,7 @@ def load_model(
     qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
-        skip_move_to_device = False
         if (
-            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        ) and not qlora_fsdp:
-            model = load_sharded_model(
-                base_model,
-                model_config,
-                cfg,
-                torch_dtype=cfg.torch_dtype,
-            )
-            skip_move_to_device = True
-        elif (
-            qlora_fsdp
-            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-            and cfg.model_config_type == "dbrx"
-        ):
-            quant_storage = cfg.torch_dtype
-            model = load_sharded_model_quant(
-                base_model,
-                model_config,
-                cfg,
-                quant_storage=quant_storage,
-            )
-            skip_move_to_device = True
-        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -658,11 +597,6 @@ def load_model(
                 **model_kwargs,
             )
         else:
-            if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
-                skip_move_to_device = True
-                if "device_map" in model_kwargs:
-                    del model_kwargs["device_map"]
-
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,
@@ -736,17 +670,13 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if is_deepspeed_zero3_enabled():
+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
         from deepspeed.utils import (  # pylint: disable=no-name-in-module
             set_z3_leaf_modules,
         )
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        if cfg.model_config_type == "mixtral":
-            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
-            set_z3_leaf_modules(model, [moe_block])
-        elif cfg.model_config_type == "dbrx":
-            moe_block = get_module_class_from_name(model, "DbrxFFN")
-            set_z3_leaf_modules(model, [moe_block])
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
@@ -756,8 +686,7 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True
 
-    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
-        # make sure everything is in the same dtype
+    if qlora_fsdp:
         skip_prepare_model_for_kbit_training = True
 
     if cfg.adapter in ["lora", "qlora"]:
@@ -798,7 +727,7 @@ def load_model(
         cfg.ddp
         and not load_in_8bit
        and not (cfg.rl and cfg.load_in_4bit)
-        and not skip_move_to_device
+        and not qlora_fsdp
     ):
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")
@@ -954,12 +883,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if (
-        cfg.fsdp
-        and cfg.adapter
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and rank != 0
-    ):
+    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:
@@ -984,12 +908,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
         LOG.warning(
             "Exception caught during model.print_trainable_parameters(): %s", exc
         )
-    elif (
-        cfg.fsdp
-        and cfg.adapter
-        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and rank != 0
-    ):
+    elif cfg.fsdp and cfg.adapter == "qlora":
        setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config

View File

@@ -306,8 +306,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
-    if cfg.fsdp_config.fsdp_activation_checkpointing:
-        os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
     if cfg.fsdp_config.fsdp_offload_params:
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:

View File

@@ -30,7 +30,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
@@ -74,7 +74,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,

View File

@@ -22,7 +22,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,

View File

@@ -33,7 +33,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "load_in_4bit": True,
@@ -87,7 +87,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": False,
                 "sequence_len": 1024,
                 "load_in_4bit": True,
@@ -141,7 +141,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "adapter": "lora",
@@ -198,7 +198,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": False,
                 "sequence_len": 1024,
                 "adapter": "lora",
@@ -255,7 +255,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "val_set_size": 0.1,

View File

@@ -27,9 +27,7 @@ def fixture_alpaca_dataset():
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
     # pylint: disable=all
-    tokenizer = AutoTokenizer.from_pretrained(
-        "casperhansen/mistral-7b-instruct-v0.1-awq"
-    )
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
     tokenizer.add_special_tokens(
         {
             "eos_token": AddedToken(

View File

@@ -43,9 +43,7 @@ def fixture_sharegpt_dataset():
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained(
-        "casperhansen/mistral-7b-instruct-v0.1-awq"
-    )
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
     tokenizer.add_tokens(
         [
             AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),

View File

@@ -96,9 +96,7 @@ def fixture_multi_role_dataset():
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained(
-        "casperhansen/mistral-7b-instruct-v0.1-awq"
-    )
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
     tokenizer.add_special_tokens(
         {
             "eos_token": AddedToken(

View File

@@ -454,9 +454,7 @@ class OrpoTokenizationTest(unittest.TestCase):
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
-        tokenizer = LlamaTokenizer.from_pretrained(
-            "casperhansen/mistral-7b-instruct-v0.1-awq"
-        )
+        tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
        tokenizer.add_special_tokens(
             {
                 "eos_token": AddedToken(