DBRX Model Support (#1462)

* wip for dbrx finetuning

* add fastcore for parallel loading of sharded weights

* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback

* update to use v2 of the converted model

* more fixes for dbrx loras

* make sure to enable fsdp activation checkpointing

* fix support for 8bit loras too for dbrx

* apply z3 leaf moe fix for DBRX with deepspeed

* don't raise value error since child module searches could fail and be ok

* revert a previous change to fix fsdp

* update mistral/mistral qlora+fsdp yamls

* fix qlora+fsdp quant storage type

* more edge cases for qlora-fsdp

* fixes for fsdp+qlora w optimizer in 8bit

* add bigstral z3 config and make sure to use full_state_dict for fsdp
This commit is contained in:
Wing Lian
2024-04-12 09:02:36 -04:00
committed by GitHub
parent 5ed29393e3
commit 132eb740f0
19 changed files with 859 additions and 29 deletions

View File

@@ -1,4 +1,6 @@
{ {
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
"offload_optimizer": { "offload_optimizer": {

View File

@@ -1,4 +1,6 @@
{ {
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
"offload_param": { "offload_param": {

View File

@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

View File

@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

26
examples/dbrx/README.md Normal file
View File

@@ -0,0 +1,26 @@
# DBRX MoE
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.
### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
- 16-bit LoRA w/ FSDP
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
### Deepspeed
WIP

View File

@@ -0,0 +1,56 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
unfrozen_parameters:
- transformer.blocks.[0-7].
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -65,12 +65,14 @@ deepspeed:
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
- full_shard - full_shard
- auto_wrap
fsdp_config: fsdp_config:
fsdp_limit_all_gathers: true fsdp_limit_all_gathers: true
fsdp_sync_module_states: true fsdp_sync_module_states: true
fsdp_offload_params: true fsdp_offload_params: true
fsdp_use_orig_params: false fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
special_tokens: special_tokens:

View File

@@ -0,0 +1,63 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"

View File

@@ -0,0 +1,82 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -0,0 +1,81 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 1 num_epochs: 1
optimizer: paged_adamw_8bit optimizer: adamw_torch
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -47,7 +47,7 @@ train_on_inputs: false
group_by_length: false group_by_length: false
bf16: auto bf16: auto
fp16: fp16:
tf32: false tf32: true
gradient_checkpointing: true gradient_checkpointing: true
early_stopping_patience: early_stopping_patience:
@@ -69,6 +69,17 @@ debug:
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
- full_shard - full_shard
- auto_wrap
fsdp_config: fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_sharding_strategy: FULL_SHARD
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens: special_tokens:

View File

@@ -41,3 +41,4 @@ gcsfs
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
zstandard==0.22.0 zstandard==0.22.0
fastcore

View File

@@ -918,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
): ):
callbacks.append(SaveBetterTransformerModelCallback()) callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available(): if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import ( from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoMlflowCallback,

View File

@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from peft import PeftModel from peft import PeftModel
@@ -81,6 +82,8 @@ def train(
if cfg.adapter: if cfg.adapter:
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True model.generation_config.do_sample = True

View File

@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
base_model: str base_model: str
base_model_config: Optional[str] = None base_model_config: Optional[str] = None
cls_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None tokenizer_legacy: Optional[bool] = None
@@ -971,9 +972,16 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_fsdp_w_8bit_optimizer(cls, data): def check_fsdp_offload_w_8bit_optimizer(cls, data):
if data.get("fsdp") and "bnb" in data.get("optimizer", ""): if (
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}") data.get("fsdp")
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
):
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")

View File

@@ -4,27 +4,25 @@ utility helpers for distributed checks
import os import os
import pickle # nosec import pickle # nosec
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import PartialState
accelerate = None # pylint: disable=invalid-name distributed_state = None # pylint: disable=invalid-name
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed(): def is_distributed():
""" """
Check if distributed training is initialized. Check if distributed training is initialized.
""" """
global accelerate # pylint: disable=global-statement global distributed_state # pylint: disable=global-statement
if not accelerate: if not distributed_state:
accelerate = Accelerator() timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
return dist.is_available() and dist.is_initialized() distributed_state = PartialState(timeout=timedelta(seconds=timeout))
return distributed_state.use_distributed and distributed_state.initialized
def barrier(): def barrier():

View File

@@ -0,0 +1,259 @@
"""
module to handle loading model on cpu/meta device for FSDP
"""
import os
import time
from typing import List, Optional, Type, Union
import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
def _replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
_replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
to_cpu: bool = False,
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
"""
if not skip_names:
skip_names = []
def place_on_device(value):
if to_meta:
device = "meta"
elif to_cpu:
device = "cpu"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if to_meta:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif to_cpu:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def n_loading_workers(quant_method: str, param_count: float):
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
model_params_b = 70
right = int(
(4 if quant_method == "hqq" else 8)
* (devprops.total_memory / 1e9 / 40)
* (model_params_b / (param_count / 1e9))
)
return min(left, right)
def load_sharded_model(
model_name,
model_config,
cfg,
torch_dtype=torch.bfloat16,
low_memory=True,
):
if (low_memory and cfg.local_rank == 0) or not low_memory:
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_cache=False,
torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
dtype = torch_dtype if not cfg.float32 else None
model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
else:
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code,
)
return model
def load_sharded_model_quant(
model_name,
model_config,
cfg,
compute_dtype=torch.bfloat16,
quant_storage=torch.float32,
low_memory=True,
verbose=False,
loading_workers=2,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
model.transformer = _replace_linear(
model.transformer,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
else:
# this is the more common case with HF transformers
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
quant_method = "bnb"
param_count = sum((p.numel() for n, p in model.named_parameters()))
n_workers = (
n_loading_workers(quant_method, param_count)
if loading_workers == -1
else loading_workers
)
if cfg.local_rank == 0 and verbose:
print(f"Using n_workers: {n_workers} for loading")
start = time.time()
for filename in tqdm(
files,
desc="Loading & Quantizing Model Shards",
disable=cfg.local_rank != 0,
position=0,
):
weights = safetensors.torch.load_file(filename)
parallel(
load_and_quantize_parallel,
iter(weights.items()),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=quant_storage,
device=cfg.local_rank,
skip_names=[],
to_cpu=(low_memory and cfg.local_rank == 0),
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()
return model

View File

@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only from axolotl.utils.distributed import zero_only
from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
# copied from accelerator.FullyShardedDataParallelPlugin
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name.
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
if len(modules_children) == 0:
return None
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
return None
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
quant_config_exists = ( quant_config_exists = (
hasattr(model_config, "quantization_config") hasattr(model_config, "quantization_config")
@@ -459,7 +484,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16, "bnb_4bit_quant_storage": torch.bfloat16,
} }
if not cfg.deepspeed: if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude # for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32 bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -470,6 +495,13 @@ def load_model(
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
) )
elif cfg.adapter == "lora" and cfg.load_in_8bit:
bnb_config = {
"load_in_8bit": True,
}
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None: if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True model_kwargs["load_in_8bit"] = True
@@ -517,7 +549,31 @@ def load_model(
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
try: try:
skip_move_to_device = False
if ( if (
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
) and not qlora_fsdp:
model = load_sharded_model(
base_model,
model_config,
cfg,
torch_dtype=cfg.torch_dtype,
)
skip_move_to_device = True
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
):
quant_storage = cfg.torch_dtype
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
)
skip_move_to_device = True
elif (
model_config.model_type == "llama" model_config.model_type == "llama"
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
@@ -597,6 +653,11 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
else: else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config, config=model_config,
@@ -670,13 +731,17 @@ def load_model(
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules, set_z3_leaf_modules,
) )
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora": if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled # Qwen doesn't play nicely with LoRA if this is enabled
@@ -686,7 +751,8 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits: if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if qlora_fsdp: if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
# make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]: if cfg.adapter in ["lora", "qlora"]:
@@ -727,7 +793,7 @@ def load_model(
cfg.ddp cfg.ddp
and not load_in_8bit and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit) and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp and not skip_move_to_device
): ):
# TODO revaldate this conditional # TODO revaldate this conditional
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
rank = int(os.environ.get("LOCAL_RANK", 0)) rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0: if (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_meta_for_peft(model) setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir: if cfg.lora_model_dir:
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
LOG.warning( LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc "Exception caught during model.print_trainable_parameters(): %s", exc
) )
elif cfg.fsdp and cfg.adapter == "qlora": elif (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_peft_meta_for_training(model) setup_quantized_peft_meta_for_training(model)
return model, lora_config return model, lora_config

View File

@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_fsdp_envs(cfg): def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true" os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params: if cfg.fsdp_config.fsdp_offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true" os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states: if cfg.fsdp_config.fsdp_sync_module_states: