DBRX Model Support (#1462)
* wip for dbrx finetuning * add fastcore for parallel loading of sharded weights * fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback * update to use v2 of the converted model * more fixes for dbrx loras * make sure to enable fsdp activation checkpointing * fix support for 8bit loras too for dbrx * apply z3 leaf moe fix for DBRX with deepspeed * don't raise value error since child module searches could fail and be ok * revert a previous change to fix fsdp * update mistral/mistral qlora+fsdp yamls * fix qlora+fsdp quant storage type * more edge cases for qlora-fsdp * fixes for fsdp+qlora w optimizer in 8bit * add bigstral z3 config and make sure to use full_state_dict for fsdp
This commit is contained in:
@@ -918,10 +918,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
@@ -81,6 +82,8 @@ def train(
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
# we wait until the last possible moment to set up Accelerator
|
||||
Accelerator()
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
|
||||
@@ -259,6 +259,7 @@ class ModelInputConfig(BaseModel):
|
||||
|
||||
base_model: str
|
||||
base_model_config: Optional[str] = None
|
||||
cls_model_config: Optional[str] = None
|
||||
tokenizer_config: Optional[str] = None
|
||||
tokenizer_use_fast: Optional[bool] = None
|
||||
tokenizer_legacy: Optional[bool] = None
|
||||
@@ -971,9 +972,16 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_w_8bit_optimizer(cls, data):
|
||||
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
|
||||
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and "8bit" in data.get("optimizer", "")
|
||||
and data.get("fsdp_config")
|
||||
and data["fsdp_config"].get("fsdp_offload_params")
|
||||
):
|
||||
raise ValueError(
|
||||
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -4,27 +4,25 @@ utility helpers for distributed checks
|
||||
import os
|
||||
import pickle # nosec
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import Accelerator
|
||||
from accelerate import PartialState
|
||||
|
||||
accelerate = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def load_accelerate():
|
||||
global accelerate # pylint: disable=global-statement
|
||||
accelerate = Accelerator()
|
||||
distributed_state = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
"""
|
||||
global accelerate # pylint: disable=global-statement
|
||||
if not accelerate:
|
||||
accelerate = Accelerator()
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
global distributed_state # pylint: disable=global-statement
|
||||
if not distributed_state:
|
||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||
|
||||
return distributed_state.use_distributed and distributed_state.initialized
|
||||
|
||||
|
||||
def barrier():
|
||||
|
||||
259
src/axolotl/utils/model_shard_quant.py
Normal file
259
src/axolotl/utils/model_shard_quant.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
module to handle loading model on cpu/meta device for FSDP
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from bitsandbytes.nn import Linear4bit, Params4bit
|
||||
from fastcore.parallel import parallel
|
||||
from torch import Tensor, nn
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||
|
||||
|
||||
def _replace_linear(
|
||||
model: nn.Module,
|
||||
linear_replacement: Type[nn.Module],
|
||||
quant_config: Union[dict, None] = None,
|
||||
skip_modules=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of modules names not to convert. Defaults to `lm_head`.
|
||||
"""
|
||||
if skip_modules is None:
|
||||
skip_modules = ["lm_head"]
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
_replace_linear(
|
||||
module, linear_replacement, quant_config, skip_modules, **kwargs
|
||||
)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
if issubclass(linear_replacement, Linear4bit):
|
||||
model._modules[ # pylint: disable=protected-access
|
||||
name
|
||||
] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported linear replacement: {type(linear_replacement)}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: torch.device = None,
    dtype: torch.dtype = None,
    skip_names: Optional[List[str]] = None,
    to_cpu: bool = False,
    to_meta: bool = False,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.

    Args:
        module: root module that owns the parameter or buffer.
        name: dotted path of the tensor inside `module` (e.g. a state-dict key).
        value: tensor to load.
        device: device used for quantization, and the final device when neither
            `to_cpu` nor `to_meta` is set.
        dtype: dtype the tensor is converted to; `None` keeps its dtype.
        skip_names: substrings; any `name` containing one of them is skipped.
        to_cpu: place the loaded tensor on "cpu".
        to_meta: place the loaded tensor on "meta"; takes precedence over `to_cpu`.
        verbose: print a message when a name is skipped.
        quant_method: only "bnb" triggers conversion; for any other value the
            tensor is assigned as-is.

    Returns:
        None. The tensor is assigned onto the owning submodule in place.
    """
    if not skip_names:
        skip_names = []

    # Resolve the final placement once. Falling back to `device` avoids the
    # unbound-local error the helper used to hit when both flags were False.
    if to_meta:
        target_device = "meta"
    elif to_cpu:
        target_device = "cpu"
    else:
        target_device = device

    def place_on_device(tensor):
        return tensor.to(device=target_device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    if quant_method == "bnb":
        # Only the lookup is guarded, so an AttributeError raised while
        # quantizing is not mistaken for "this is a buffer".
        try:
            param = submodule.get_parameter(value_key)
        except AttributeError:
            # it's a buffer
            value = place_on_device(value)
        else:
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                value = type(param)(
                    value.to(device=device, dtype=dtype).data, **param.__dict__
                ).cuda(device)
                if to_meta:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif to_cpu:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)

    setattr(submodule, value_key, value)
|
||||
|
||||
|
||||
def n_loading_workers(quant_method: str, param_count: float):
    """
    Heuristically pick how many workers to use for parallel weight loading.

    The result is the smaller of two limits: the CPU cores available per
    visible CUDA device, and a GPU-memory-based budget scaled by model size.

    Args:
        quant_method: quantization backend name; "hqq" halves the memory-based
            budget compared to any other value.
        param_count: total number of model parameters.

    Returns:
        The worker count as an int. It can be 0 when either limit rounds down
        to zero.
    """
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    # os.cpu_count() returns None when the count cannot be determined.
    cpu_total = os.cpu_count() or 1
    left = int(cpu_total / torch.cuda.device_count())
    # NOTE(review): the 70 (billion params) and 40 (GB) constants look like a
    # reference calibration point; confirm against the upstream heuristic.
    model_params_b = 70
    right = int(
        (4 if quant_method == "hqq" else 8)
        * (devprops.total_memory / 1e9 / 40)
        * (model_params_b / (param_count / 1e9))
    )
    return min(left, right)
|
||||
|
||||
|
||||
def load_sharded_model(
    model_name,
    model_config,
    cfg,
    torch_dtype=torch.bfloat16,
    low_memory=True,
):
    """
    Build the model with real weights or as an empty shell, depending on rank.

    With `low_memory` enabled only local rank 0 reads the pretrained weights
    (kept on CPU); every other rank instantiates the architecture from
    `model_config` under `init_empty_weights()`. With `low_memory` disabled
    every rank loads the weights and moves them to its own `cfg.local_rank`
    device.

    Args:
        model_name: identifier passed to `from_pretrained`.
        model_config: model config; also supplies `_attn_implementation`.
        cfg: run config; reads `local_rank`, `trust_remote_code` and `float32`.
        torch_dtype: dtype to cast loaded weights to (skipped when
            `cfg.float32` is set) and dtype used for the empty shell.
        low_memory: enable the rank-0-only loading scheme described above.

    Returns:
        The instantiated model.
    """
    loads_real_weights = not low_memory or cfg.local_rank == 0

    if not loads_real_weights:
        # Architecture only, no weight storage allocated.
        with init_empty_weights():
            return AutoModelForCausalLM.from_config(
                model_config,
                torch_dtype=torch_dtype,
                trust_remote_code=cfg.trust_remote_code,
            )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False,
        torch_dtype=torch.float32,
        _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
        trust_remote_code=cfg.trust_remote_code,
    )
    # Weights are read as float32 and cast afterwards unless float32 was requested.
    cast_dtype = None if cfg.float32 else torch_dtype
    destination = "cpu" if low_memory else cfg.local_rank
    model.to(dtype=cast_dtype, device=destination)
    return model
|
||||
|
||||
|
||||
def load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.float32,
    low_memory=True,
    verbose=False,
    loading_workers=2,
):
    """
    Instantiate an empty model, swap its linear layers for `Linear4bit`, then
    stream the safetensors checkpoint into it shard by shard.

    Args:
        model_name: identifier used to locate the safetensors file(s).
        model_config: config used to instantiate the empty model.
        cfg: run config; reads `trust_remote_code` and `local_rank`.
        compute_dtype: compute dtype passed to the `Linear4bit` layers.
        quant_storage: storage dtype passed to the `Linear4bit` layers and used
            as the dtype when loading each tensor.
        low_memory: when set, local rank 0 keeps loaded tensors on CPU and all
            other ranks keep them on the meta device.
        verbose: print the worker count and timing on local rank 0, and
            per-tensor skip messages.
        loading_workers: worker count for parallel loading; -1 derives it from
            `n_loading_workers`.

    Returns:
        The model with `is_loaded_in_4bit` set to True.

    Raises:
        OSError: if neither a sharded index nor a single safetensors file can
            be resolved for `model_name`.
    """
    # `import safetensors` alone does not guarantee the `torch` submodule is
    # loaded, so import the loader explicitly.
    from safetensors.torch import (  # pylint: disable=import-outside-toplevel
        load_file,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=cfg.trust_remote_code,
        )
        # Some architectures expose the backbone as `transformer`; `model` is
        # the more common case with HF transformers.
        backbone_attr = "transformer" if hasattr(model, "transformer") else "model"
        setattr(
            model,
            backbone_attr,
            _replace_linear(
                getattr(model, backbone_attr),
                Linear4bit,
                compute_dtype=compute_dtype,
                quant_type="nf4",
                quant_storage=quant_storage,
            ),
        )
    model.is_loaded_in_4bit = True

    # Grab the safetensors files that hold the weights
    try:
        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
    except OSError:
        # No model.safetensors.index.json means the model is not sharded. If
        # this lookup also raises OSError, the model probably has no
        # safetensors file, and the error propagates to the caller.
        files = [hub.cached_file(model_name, SAFE_WEIGHTS_NAME)]

    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
    def load_and_quantize_parallel(name_param, model, **kwargs):
        name, param = name_param
        load_and_quantize(model, name, param, **kwargs)

    quant_method = "bnb"
    if loading_workers == -1:
        # The parameter count is only needed by the worker heuristic.
        param_count = sum(p.numel() for p in model.parameters())
        n_workers = n_loading_workers(quant_method, param_count)
    else:
        n_workers = loading_workers
    if cfg.local_rank == 0 and verbose:
        print(f"Using n_workers: {n_workers} for loading")

    start = time.time()
    for filename in tqdm(
        files,
        desc="Loading & Quantizing Model Shards",
        disable=cfg.local_rank != 0,
        position=0,
    ):
        weights = load_file(filename)
        parallel(
            load_and_quantize_parallel,
            iter(weights.items()),
            n_workers=n_workers,
            threadpool=True,
            model=model,
            dtype=quant_storage,
            device=cfg.local_rank,
            skip_names=[],
            to_cpu=(low_memory and cfg.local_rank == 0),
            to_meta=(low_memory and cfg.local_rank != 0),
            verbose=verbose,
            quant_method=quant_method,
        )

    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
    torch.cuda.empty_cache()

    return model
|
||||
@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import zero_only
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||
def get_module_class_from_name(module, name):
    """
    Find the class whose `__name__` equals `name` within a module tree.

    The module itself is checked first, then its children are searched
    depth-first in order; the first match wins.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.

    Returns:
        The matching class, or `None` when nothing in the tree has that name.
    """
    children = list(module.children())

    own_class = module.__class__
    if own_class.__name__ == name:
        return own_class

    # Lazy search: stops at the first child subtree that yields a match, and
    # falls through to None for leaves or when no subtree matches.
    candidates = (get_module_class_from_name(child, name) for child in children)
    return next((found for found in candidates if found is not None), None)
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
@@ -459,7 +484,7 @@ def load_model(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if not cfg.deepspeed:
|
||||
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
@@ -470,6 +495,13 @@ def load_model(
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif cfg.adapter == "lora" and cfg.load_in_8bit:
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_8bit"] = True
|
||||
@@ -517,7 +549,31 @@ def load_model(
|
||||
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
|
||||
|
||||
try:
|
||||
skip_move_to_device = False
|
||||
if (
|
||||
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
) and not qlora_fsdp:
|
||||
model = load_sharded_model(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.model_config_type == "dbrx"
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
model = load_sharded_model_quant(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
@@ -597,6 +653,11 @@ def load_model(
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
if "device_map" in model_kwargs:
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
@@ -670,13 +731,17 @@ def load_model(
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
if cfg.model_config_type == "mixtral":
|
||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
elif cfg.model_config_type == "dbrx":
|
||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -686,7 +751,8 @@ def load_model(
|
||||
if cfg.adapter == "lora" and loftq_bits:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if qlora_fsdp:
|
||||
if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
|
||||
# make sure everything is in the same dtype
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
@@ -727,7 +793,7 @@ def load_model(
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
and not (cfg.rl and cfg.load_in_4bit)
|
||||
and not qlora_fsdp
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revalidate this conditional
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
||||
if (
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
LOG.warning(
|
||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||
)
|
||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||
elif (
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
@@ -306,6 +306,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
|
||||
if cfg.fsdp_config.fsdp_offload_params:
|
||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||
if cfg.fsdp_config.fsdp_sync_module_states:
|
||||
|
||||
Reference in New Issue
Block a user