DBRX Model Support (#1462)
* wip for dbrx finetuning * add fastcore for parallel loading of sharded weights * fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback * update to use v2 of the converted model * more fixes for dbrx loras * make sure to enable fsdp activation checkpointing * fix support for 8bit loras too for dbrx * apply z3 leaf moe fix for DBRX with deepspeed * don't raise value error since child module searches could fail and be ok * revert a previous change to fix fsdp * update mistral/mistral qlora+fsdp yamls * fix qlora+fsdp quant storage type * more edge cases for qlora-fsdp * fixes for fsdp+qlora w optimizer in 8bit * add bigstral z3 config and make sure to use full_state_dict for fsdp
This commit is contained in:
81
examples/dbrx/16bit-lora.yaml
Normal file
81
examples/dbrx/16bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
# w1, w2, & v1 will hang the trainer
|
||||
lora_target_modules:
|
||||
- q_proj # attn
|
||||
- k_proj # attn
|
||||
- v_proj # attn
|
||||
- out_proj # attn
|
||||
- layer # router
|
||||
# - w1
|
||||
# - w2
|
||||
# - v1
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_activation_checkpointing: true
|
||||
81
examples/dbrx/8bit-lora.yaml
Normal file
81
examples/dbrx/8bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
# w1, w2, & v1 will hang the trainer
|
||||
lora_target_modules:
|
||||
- q_proj # attn
|
||||
- k_proj # attn
|
||||
- v_proj # attn
|
||||
- out_proj # attn
|
||||
- layer # router
|
||||
# - w1
|
||||
# - w2
|
||||
# - v1
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_activation_checkpointing: true
|
||||
26
examples/dbrx/README.md
Normal file
26
examples/dbrx/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# DBRX MoE
|
||||
|
||||
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
|
||||
|
||||
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
|
||||
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
|
||||
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
|
||||
results in the trainer hanging.
|
||||
|
||||
|
||||
### FSDP
|
||||
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
|
||||
|
||||
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
|
||||
|
||||
- 16-bit LoRA w/ FSDP
|
||||
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
|
||||
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
|
||||
- ✅ 8-bit LoRA w/ FSDP
|
||||
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
|
||||
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
|
||||
|
||||
|
||||
### Deepspeed
|
||||
|
||||
WIP
|
||||
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
@@ -0,0 +1,56 @@
|
||||
base_model: LnL-AI/dbrx-base-converted-v2
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
unfrozen_parameters:
|
||||
- transformer.blocks.[0-7].
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
@@ -65,12 +65,14 @@ deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
special_tokens:
|
||||
|
||||
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
unfrozen_parameters:
|
||||
- ^lm_head.weight$
|
||||
- ^model.embed_tokens.weight$
|
||||
- model.layers.4[4-9]+.block_sparse_moe.gate
|
||||
- model.layers.4[4-9]+.block_sparse_moe.experts
|
||||
- model.layers.5[0-5]+.block_sparse_moe.gate
|
||||
- model.layers.5[0-5]+.block_sparse_moe.experts
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
save_total_limit: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
eos_token: "<|im_end|>"
|
||||
tokens:
|
||||
- "<|im_start|>"
|
||||
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
@@ -0,0 +1,82 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
output_dir: ./qlora-out
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
special_tokens:
|
||||
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
output_dir: ./qlora-out
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
special_tokens:
|
||||
@@ -39,7 +39,7 @@ wandb_log_model:
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
@@ -47,7 +47,7 @@ train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
@@ -69,6 +69,17 @@ debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
special_tokens:
|
||||
|
||||
Reference in New Issue
Block a user