various batch of fixes (#1785)
* various batch of fixes * more tweaks * fix autoawq requirement for torch flexibility * simplify conditionals * multi-node fixes wip * bump transformers and include 405b qlora+fsdp yaml
This commit is contained in:
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
PreTrainedModel,
|
||||
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -510,7 +512,25 @@ def load_model(
|
||||
model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
if (
|
||||
cfg.adapter in ["qlora", "lora"]
|
||||
and hasattr(model_config, "quantization_config")
|
||||
and model_config.quantization_config["quant_method"]
|
||||
in ["gptq", "awq", "bitsandbytes"]
|
||||
):
|
||||
if model_config.quantization_config["quant_method"] == "gptq":
|
||||
model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif model_config.quantization_config["quant_method"] == "awq":
|
||||
model_kwargs["quantization_config"] = AwqConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**model_config.quantization_config
|
||||
)
|
||||
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -785,12 +805,14 @@ def load_model(
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "mixtral":
|
||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
elif cfg.model_config_type == "dbrx":
|
||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
||||
set_z3_leaf_modules(model, [moe_block])
|
||||
if cfg.model_config_type in MOE_ARCH_BLOCK:
|
||||
set_z3_leaf_modules(
|
||||
model,
|
||||
[
|
||||
get_module_class_from_name(model, module_name)
|
||||
for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
|
||||
],
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -804,6 +826,9 @@ def load_model(
|
||||
# make sure everything is in the same dtype
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
@@ -838,6 +863,9 @@ def load_model(
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
|
||||
Reference in New Issue
Block a user