DBRX Model Support (#1462)
* wip for dbrx finetuning * add fastcore for parallel loading of sharded weights * fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback * update to use v2 of the converted model * more fixes for dbrx loras * make sure to enable fsdp activation checkpointing * fix support for 8bit loras too for dbrx * apply z3 leaf moe fix for DBRX with deepspeed * don't raise value error since child module searches could fail and be ok * revert a previous change to fix fsdp * update mistral/mistral qlora+fsdp yamls * fix qlora+fsdp quant storage type * more edge cases for qlora-fsdp * fixes for fsdp+qlora w optimizer in 8bit * add bigstral z3 config and make sure to use full_state_dict for fsdp
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
{
|
{
|
||||||
|
"zero_force_ds_cpu_optimizer": false,
|
||||||
|
"zero_allow_untested_optimizer": true,
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
"offload_optimizer": {
|
"offload_optimizer": {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
{
|
{
|
||||||
|
"zero_force_ds_cpu_optimizer": false,
|
||||||
|
"zero_allow_untested_optimizer": true,
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
"offload_param": {
|
"offload_param": {
|
||||||
|
|||||||
81
examples/dbrx/16bit-lora.yaml
Normal file
81
examples/dbrx/16bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# w1, w2, & v1 will hang the trainer
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj # attn
|
||||||
|
- k_proj # attn
|
||||||
|
- v_proj # attn
|
||||||
|
- out_proj # attn
|
||||||
|
- layer # router
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - v1
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_activation_checkpointing: true
|
||||||
81
examples/dbrx/8bit-lora.yaml
Normal file
81
examples/dbrx/8bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# w1, w2, & v1 will hang the trainer
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj # attn
|
||||||
|
- k_proj # attn
|
||||||
|
- v_proj # attn
|
||||||
|
- out_proj # attn
|
||||||
|
- layer # router
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - v1
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_activation_checkpointing: true
|
||||||
26
examples/dbrx/README.md
Normal file
26
examples/dbrx/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# DBRX MoE
|
||||||
|
|
||||||
|
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
|
||||||
|
|
||||||
|
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
|
||||||
|
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
|
||||||
|
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
|
||||||
|
results in the trainer hanging.
|
||||||
|
|
||||||
|
|
||||||
|
### FSDP
|
||||||
|
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
|
||||||
|
|
||||||
|
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
|
||||||
|
|
||||||
|
- 16-bit LoRA w/ FSDP
|
||||||
|
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
|
||||||
|
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
|
||||||
|
- ✅ 8-bit LoRA w/ FSDP
|
||||||
|
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
|
||||||
|
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
|
||||||
|
|
||||||
|
|
||||||
|
### Deepspeed
|
||||||
|
|
||||||
|
WIP
|
||||||
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
unfrozen_parameters:
|
||||||
|
- transformer.blocks.[0-7].
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||||
@@ -65,12 +65,14 @@ deepspeed:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
- full_shard
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_limit_all_gathers: true
|
fsdp_limit_all_gathers: true
|
||||||
fsdp_sync_module_states: true
|
fsdp_sync_module_states: true
|
||||||
fsdp_offload_params: true
|
fsdp_offload_params: true
|
||||||
fsdp_use_orig_params: false
|
fsdp_use_orig_params: false
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
unfrozen_parameters:
|
||||||
|
- ^lm_head.weight$
|
||||||
|
- ^model.embed_tokens.weight$
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.gate
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.experts
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.gate
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.experts
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
save_total_limit: 1
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
eos_token: "<|im_end|>"
|
||||||
|
tokens:
|
||||||
|
- "<|im_start|>"
|
||||||
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.02
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
special_tokens:
|
||||||
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.02
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
special_tokens:
|
||||||
@@ -39,7 +39,7 @@ wandb_log_model:
|
|||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: paged_adamw_8bit
|
optimizer: adamw_torch
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ train_on_inputs: false
|
|||||||
group_by_length: false
|
group_by_length: false
|
||||||
bf16: auto
|
bf16: auto
|
||||||
fp16:
|
fp16:
|
||||||
tf32: false
|
tf32: true
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
@@ -69,6 +69,17 @@ debug:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
- full_shard
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_forward_prefetch: false
|
||||||
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -41,3 +41,4 @@ gcsfs
|
|||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
|
fastcore
|
||||||
|
|||||||
@@ -918,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
):
|
):
|
||||||
callbacks.append(SaveBetterTransformerModelCallback())
|
callbacks.append(SaveBetterTransformerModelCallback())
|
||||||
|
|
||||||
if self.cfg.use_wandb:
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
@@ -81,6 +82,8 @@ def train(
|
|||||||
if cfg.adapter:
|
if cfg.adapter:
|
||||||
msg += " and peft_config..."
|
msg += " and peft_config..."
|
||||||
LOG.debug(msg)
|
LOG.debug(msg)
|
||||||
|
# we wait unitl the last possible moment to setup Accelerator
|
||||||
|
Accelerator()
|
||||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
|
|||||||
|
|
||||||
base_model: str
|
base_model: str
|
||||||
base_model_config: Optional[str] = None
|
base_model_config: Optional[str] = None
|
||||||
|
cls_model_config: Optional[str] = None
|
||||||
tokenizer_config: Optional[str] = None
|
tokenizer_config: Optional[str] = None
|
||||||
tokenizer_use_fast: Optional[bool] = None
|
tokenizer_use_fast: Optional[bool] = None
|
||||||
tokenizer_legacy: Optional[bool] = None
|
tokenizer_legacy: Optional[bool] = None
|
||||||
@@ -971,9 +972,16 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_w_8bit_optimizer(cls, data):
|
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||||
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
|
if (
|
||||||
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
|
data.get("fsdp")
|
||||||
|
and "8bit" in data.get("optimizer", "")
|
||||||
|
and data.get("fsdp_config")
|
||||||
|
and data["fsdp_config"].get("fsdp_offload_params")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -4,27 +4,25 @@ utility helpers for distributed checks
|
|||||||
import os
|
import os
|
||||||
import pickle # nosec
|
import pickle # nosec
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import Accelerator
|
from accelerate import PartialState
|
||||||
|
|
||||||
accelerate = None # pylint: disable=invalid-name
|
distributed_state = None # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def load_accelerate():
|
|
||||||
global accelerate # pylint: disable=global-statement
|
|
||||||
accelerate = Accelerator()
|
|
||||||
|
|
||||||
|
|
||||||
def is_distributed():
|
def is_distributed():
|
||||||
"""
|
"""
|
||||||
Check if distributed training is initialized.
|
Check if distributed training is initialized.
|
||||||
"""
|
"""
|
||||||
global accelerate # pylint: disable=global-statement
|
global distributed_state # pylint: disable=global-statement
|
||||||
if not accelerate:
|
if not distributed_state:
|
||||||
accelerate = Accelerator()
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||||
return dist.is_available() and dist.is_initialized()
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||||
|
|
||||||
|
return distributed_state.use_distributed and distributed_state.initialized
|
||||||
|
|
||||||
|
|
||||||
def barrier():
|
def barrier():
|
||||||
|
|||||||
259
src/axolotl/utils/model_shard_quant.py
Normal file
259
src/axolotl/utils/model_shard_quant.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
module to handle loading model on cpu/meta device for FSDP
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Type, Union
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from bitsandbytes.nn import Linear4bit, Params4bit
|
||||||
|
from fastcore.parallel import parallel
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_linear(
|
||||||
|
model: nn.Module,
|
||||||
|
linear_replacement: Type[nn.Module],
|
||||||
|
quant_config: Union[dict, None] = None,
|
||||||
|
skip_modules=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Replace linear modules with a new Linear module.
|
||||||
|
Parameters:
|
||||||
|
model (`torch.nn.Module`):
|
||||||
|
Input model or `torch.nn.Module` as the function is run recursively.
|
||||||
|
linear_replacement (`torch.nn.Module`):
|
||||||
|
The linear module that replaces the old one. Only expects standard arguments.
|
||||||
|
If other arguments need to be passed, use a lambda.
|
||||||
|
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||||
|
List of modules names not to convert. Defaults to `lm_head`.
|
||||||
|
"""
|
||||||
|
if skip_modules is None:
|
||||||
|
skip_modules = ["lm_head"]
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if len(list(module.children())) > 0:
|
||||||
|
_replace_linear(
|
||||||
|
module, linear_replacement, quant_config, skip_modules, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||||
|
if issubclass(linear_replacement, Linear4bit):
|
||||||
|
model._modules[ # pylint: disable=protected-access
|
||||||
|
name
|
||||||
|
] = linear_replacement(
|
||||||
|
module.in_features,
|
||||||
|
module.out_features,
|
||||||
|
module.bias is not None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported linear replacement: {type(linear_replacement)}"
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_quantize(
|
||||||
|
module: nn.Module,
|
||||||
|
name: str,
|
||||||
|
value: Tensor,
|
||||||
|
device: torch.device = None,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
skip_names: Optional[List[str]] = None,
|
||||||
|
to_cpu: bool = False,
|
||||||
|
to_meta: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
quant_method: str = "bnb",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||||
|
|
||||||
|
Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not skip_names:
|
||||||
|
skip_names = []
|
||||||
|
|
||||||
|
def place_on_device(value):
|
||||||
|
if to_meta:
|
||||||
|
device = "meta"
|
||||||
|
elif to_cpu:
|
||||||
|
device = "cpu"
|
||||||
|
return value.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if any(skip_name in name for skip_name in skip_names):
|
||||||
|
if verbose:
|
||||||
|
print(f"Skipping {name} because it is in skip_names")
|
||||||
|
return
|
||||||
|
|
||||||
|
module_key, _, value_key = name.rpartition(".")
|
||||||
|
try:
|
||||||
|
submodule = module.get_submodule(module_key)
|
||||||
|
except AttributeError as exc:
|
||||||
|
print(f"Module {module_key} not found:\n{exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if quant_method == "bnb":
|
||||||
|
param = submodule.get_parameter(value_key)
|
||||||
|
if isinstance(param, Params4bit):
|
||||||
|
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
|
||||||
|
# shape as the quantized Params4bit with an initialized quant_state. However,
|
||||||
|
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||||
|
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||||
|
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||||
|
value = type(param)(
|
||||||
|
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||||
|
).cuda(device)
|
||||||
|
if to_meta:
|
||||||
|
value = type(param)(value.data.to("meta"), **value.__dict__)
|
||||||
|
elif to_cpu:
|
||||||
|
value = type(param)(value.data.to("cpu"), **value.__dict__)
|
||||||
|
else:
|
||||||
|
value = type(param)(place_on_device(value).data)
|
||||||
|
|
||||||
|
except AttributeError:
|
||||||
|
# it's a buffer
|
||||||
|
value = place_on_device(value)
|
||||||
|
|
||||||
|
setattr(submodule, value_key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def n_loading_workers(quant_method: str, param_count: float):
|
||||||
|
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
left = int(os.cpu_count() / torch.cuda.device_count())
|
||||||
|
model_params_b = 70
|
||||||
|
right = int(
|
||||||
|
(4 if quant_method == "hqq" else 8)
|
||||||
|
* (devprops.total_memory / 1e9 / 40)
|
||||||
|
* (model_params_b / (param_count / 1e9))
|
||||||
|
)
|
||||||
|
return min(left, right)
|
||||||
|
|
||||||
|
|
||||||
|
def load_sharded_model(
|
||||||
|
model_name,
|
||||||
|
model_config,
|
||||||
|
cfg,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
low_memory=True,
|
||||||
|
):
|
||||||
|
if (low_memory and cfg.local_rank == 0) or not low_memory:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
use_cache=False,
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
_attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
|
||||||
|
trust_remote_code=cfg.trust_remote_code,
|
||||||
|
)
|
||||||
|
dtype = torch_dtype if not cfg.float32 else None
|
||||||
|
model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
|
||||||
|
else:
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_config(
|
||||||
|
model_config,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=cfg.trust_remote_code,
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_sharded_model_quant(
|
||||||
|
model_name,
|
||||||
|
model_config,
|
||||||
|
cfg,
|
||||||
|
compute_dtype=torch.bfloat16,
|
||||||
|
quant_storage=torch.float32,
|
||||||
|
low_memory=True,
|
||||||
|
verbose=False,
|
||||||
|
loading_workers=2,
|
||||||
|
):
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_config(
|
||||||
|
model_config,
|
||||||
|
trust_remote_code=cfg.trust_remote_code,
|
||||||
|
)
|
||||||
|
if hasattr(model, "transformer"):
|
||||||
|
model.transformer = _replace_linear(
|
||||||
|
model.transformer,
|
||||||
|
Linear4bit,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
quant_type="nf4",
|
||||||
|
quant_storage=quant_storage,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# this is the more common case with HF transformers
|
||||||
|
model.model = _replace_linear(
|
||||||
|
model.model,
|
||||||
|
Linear4bit,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
quant_type="nf4",
|
||||||
|
quant_storage=quant_storage,
|
||||||
|
)
|
||||||
|
model.is_loaded_in_4bit = True
|
||||||
|
|
||||||
|
# Grab the safetensors files that hold the weights
|
||||||
|
try:
|
||||||
|
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
|
||||||
|
except OSError:
|
||||||
|
try:
|
||||||
|
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
|
||||||
|
files = []
|
||||||
|
files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
|
||||||
|
except OSError as exc:
|
||||||
|
# This means the model probably doesn't have a safetensors file
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
|
||||||
|
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
|
||||||
|
def load_and_quantize_parallel(name_param, model, **kwargs):
|
||||||
|
name, param = name_param
|
||||||
|
load_and_quantize(model, name, param, **kwargs)
|
||||||
|
|
||||||
|
quant_method = "bnb"
|
||||||
|
param_count = sum((p.numel() for n, p in model.named_parameters()))
|
||||||
|
|
||||||
|
n_workers = (
|
||||||
|
n_loading_workers(quant_method, param_count)
|
||||||
|
if loading_workers == -1
|
||||||
|
else loading_workers
|
||||||
|
)
|
||||||
|
if cfg.local_rank == 0 and verbose:
|
||||||
|
print(f"Using n_workers: {n_workers} for loading")
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for filename in tqdm(
|
||||||
|
files,
|
||||||
|
desc="Loading & Quantizing Model Shards",
|
||||||
|
disable=cfg.local_rank != 0,
|
||||||
|
position=0,
|
||||||
|
):
|
||||||
|
weights = safetensors.torch.load_file(filename)
|
||||||
|
parallel(
|
||||||
|
load_and_quantize_parallel,
|
||||||
|
iter(weights.items()),
|
||||||
|
n_workers=n_workers,
|
||||||
|
threadpool=True,
|
||||||
|
model=model,
|
||||||
|
dtype=quant_storage,
|
||||||
|
device=cfg.local_rank,
|
||||||
|
skip_names=[],
|
||||||
|
to_cpu=(low_memory and cfg.local_rank == 0),
|
||||||
|
to_meta=(low_memory and cfg.local_rank != 0),
|
||||||
|
verbose=verbose,
|
||||||
|
quant_method=quant_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.local_rank == 0 and verbose:
|
||||||
|
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
||||||
|
# cleanup any extra memory usage from parallel loading
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return model
|
||||||
@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import zero_only
|
from axolotl.utils.distributed import zero_only
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
|
def get_module_class_from_name(module, name):
|
||||||
|
"""
|
||||||
|
Gets a class from a module by its name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (`torch.nn.Module`): The module to get the class from.
|
||||||
|
name (`str`): The name of the class.
|
||||||
|
"""
|
||||||
|
modules_children = list(module.children())
|
||||||
|
if module.__class__.__name__ == name:
|
||||||
|
return module.__class__
|
||||||
|
|
||||||
|
if len(modules_children) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for child_module in modules_children:
|
||||||
|
module_class = get_module_class_from_name(child_module, name)
|
||||||
|
if module_class is not None:
|
||||||
|
return module_class
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -459,7 +484,7 @@ def load_model(
|
|||||||
"bnb_4bit_quant_type": "nf4",
|
"bnb_4bit_quant_type": "nf4",
|
||||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||||
}
|
}
|
||||||
if not cfg.deepspeed:
|
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
|
||||||
# for some reason, this causes the loss to be off by an order of magnitude
|
# for some reason, this causes the loss to be off by an order of magnitude
|
||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
@@ -470,6 +495,13 @@ def load_model(
|
|||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
elif cfg.adapter == "lora" and cfg.load_in_8bit:
|
||||||
|
bnb_config = {
|
||||||
|
"load_in_8bit": True,
|
||||||
|
}
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
**bnb_config,
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.load_in_8bit and cfg.adapter is not None:
|
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||||
model_kwargs["load_in_8bit"] = True
|
model_kwargs["load_in_8bit"] = True
|
||||||
@@ -517,7 +549,31 @@ def load_model(
|
|||||||
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
|
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
skip_move_to_device = False
|
||||||
if (
|
if (
|
||||||
|
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
|
) and not qlora_fsdp:
|
||||||
|
model = load_sharded_model(
|
||||||
|
base_model,
|
||||||
|
model_config,
|
||||||
|
cfg,
|
||||||
|
torch_dtype=cfg.torch_dtype,
|
||||||
|
)
|
||||||
|
skip_move_to_device = True
|
||||||
|
elif (
|
||||||
|
qlora_fsdp
|
||||||
|
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
|
and cfg.model_config_type == "dbrx"
|
||||||
|
):
|
||||||
|
quant_storage = cfg.torch_dtype
|
||||||
|
model = load_sharded_model_quant(
|
||||||
|
base_model,
|
||||||
|
model_config,
|
||||||
|
cfg,
|
||||||
|
quant_storage=quant_storage,
|
||||||
|
)
|
||||||
|
skip_move_to_device = True
|
||||||
|
elif (
|
||||||
model_config.model_type == "llama"
|
model_config.model_type == "llama"
|
||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
@@ -597,6 +653,11 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
skip_move_to_device = True
|
||||||
|
if "device_map" in model_kwargs:
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
@@ -670,13 +731,17 @@ def load_model(
|
|||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
skip_prepare_model_for_kbit_training = False
|
skip_prepare_model_for_kbit_training = False
|
||||||
|
|
||||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||||
set_z3_leaf_modules,
|
set_z3_leaf_modules,
|
||||||
)
|
)
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|
||||||
|
|
||||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
if cfg.model_config_type == "mixtral":
|
||||||
|
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
||||||
|
set_z3_leaf_modules(model, [moe_block])
|
||||||
|
elif cfg.model_config_type == "dbrx":
|
||||||
|
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
||||||
|
set_z3_leaf_modules(model, [moe_block])
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
@@ -686,7 +751,8 @@ def load_model(
|
|||||||
if cfg.adapter == "lora" and loftq_bits:
|
if cfg.adapter == "lora" and loftq_bits:
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if qlora_fsdp:
|
if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
|
||||||
|
# make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
@@ -727,7 +793,7 @@ def load_model(
|
|||||||
cfg.ddp
|
cfg.ddp
|
||||||
and not load_in_8bit
|
and not load_in_8bit
|
||||||
and not (cfg.rl and cfg.load_in_4bit)
|
and not (cfg.rl and cfg.load_in_4bit)
|
||||||
and not qlora_fsdp
|
and not skip_move_to_device
|
||||||
):
|
):
|
||||||
# TODO revaldate this conditional
|
# TODO revaldate this conditional
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
|
||||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
if (
|
||||||
|
cfg.fsdp
|
||||||
|
and cfg.adapter
|
||||||
|
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
|
and rank != 0
|
||||||
|
):
|
||||||
setup_quantized_meta_for_peft(model)
|
setup_quantized_meta_for_peft(model)
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||||
)
|
)
|
||||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
elif (
|
||||||
|
cfg.fsdp
|
||||||
|
and cfg.adapter
|
||||||
|
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||||
|
and rank != 0
|
||||||
|
):
|
||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|||||||
@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
|
|
||||||
def setup_fsdp_envs(cfg):
|
def setup_fsdp_envs(cfg):
|
||||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||||
|
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||||
|
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
|
||||||
if cfg.fsdp_config.fsdp_offload_params:
|
if cfg.fsdp_config.fsdp_offload_params:
|
||||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||||
if cfg.fsdp_config.fsdp_sync_module_states:
|
if cfg.fsdp_config.fsdp_sync_module_states:
|
||||||
|
|||||||
Reference in New Issue
Block a user