DBRX Model Support (#1462)
* wip for dbrx finetuning * add fastcore for parallel loading of sharded weights * fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback * update to use v2 of the converted model * more fixes for dbrx loras * make sure to enable fsdp activation checkpointing * fix support for 8bit loras too for dbrx * apply z3 leaf moe fix for DBRX with deepspeed * don't raise value error since child module searches could fail and be ok * revert a previous change to fix fsdp * update mistral/mistral qlora+fsdp yamls * fix qlora+fsdp quant storage type * more edge cases for qlora-fsdp * fixes for fsdp+qlora w optimizer in 8bit * add bigstral z3 config and make sure to use full_state_dict for fsdp
This commit is contained in:
@@ -45,10 +45,35 @@ from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import zero_only
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||
def get_module_class_from_name(module, name):
|
||||
"""
|
||||
Gets a class from a module by its name.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module to get the class from.
|
||||
name (`str`): The name of the class.
|
||||
"""
|
||||
modules_children = list(module.children())
|
||||
if module.__class__.__name__ == name:
|
||||
return module.__class__
|
||||
|
||||
if len(modules_children) == 0:
|
||||
return None
|
||||
|
||||
for child_module in modules_children:
|
||||
module_class = get_module_class_from_name(child_module, name)
|
||||
if module_class is not None:
|
||||
return module_class
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
@@ -459,7 +484,7 @@ def load_model(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if not cfg.deepspeed:
|
||||
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
@@ -470,6 +495,13 @@ def load_model(
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif cfg.adapter == "lora" and cfg.load_in_8bit:
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_8bit"] = True
|
||||
@@ -517,7 +549,31 @@ def load_model(
|
||||
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
|
||||
|
||||
try:
|
||||
skip_move_to_device = False
|
||||
if (
|
||||
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
) and not qlora_fsdp:
|
||||
model = load_sharded_model(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.model_config_type == "dbrx"
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
model = load_sharded_model_quant(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
@@ -597,6 +653,11 @@ def load_model(
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
if "device_map" in model_kwargs:
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
@@ -670,13 +731,17 @@ def load_model(
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
if is_deepspeed_zero3_enabled():
|
||||
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
if cfg.model_config_type == "mixtral":
|
||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
elif cfg.model_config_type == "dbrx":
|
||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -686,7 +751,8 @@ def load_model(
|
||||
if cfg.adapter == "lora" and loftq_bits:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if qlora_fsdp:
|
||||
if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
|
||||
# make sure everything is in the same dtype
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
@@ -727,7 +793,7 @@ def load_model(
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
and not (cfg.rl and cfg.load_in_4bit)
|
||||
and not qlora_fsdp
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
@@ -883,7 +949,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
||||
if (
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
@@ -908,7 +979,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
LOG.warning(
|
||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||
)
|
||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||
elif (
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
Reference in New Issue
Block a user