Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
7a7c56f018 fixes to support fsdp-qdora 2024-04-23 08:37:04 -04:00
5 changed files with 26 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 8 lora_r: 8

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
lora_r: 8 lora_r: 8

View File

@@ -188,6 +188,7 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_fsdp_alt_loader: Optional[bool] = None
lora_on_cpu: Optional[bool] = None lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None gptq: Optional[bool] = None

View File

@@ -70,6 +70,7 @@ def load_and_quantize(
to_meta: bool = False, to_meta: bool = False,
verbose: bool = False, verbose: bool = False,
quant_method: str = "bnb", quant_method: str = "bnb",
is_dora: bool = False,
): ):
""" """
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`. Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
@@ -108,6 +109,12 @@ def load_and_quantize(
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then # workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0. # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
if is_dora:
setattr(
submodule,
"dora_scale",
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
)
value = type(param)( value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__ value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device) ).cuda(device)
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
model_config, model_config,
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code, trust_remote_code=cfg.trust_remote_code,
) )
if hasattr(model, "transformer"): if hasattr(model, "transformer"):
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
to_meta=(low_memory and cfg.local_rank != 0), to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose, verbose=verbose,
quant_method=quant_method, quant_method=quant_method,
is_dora=cfg.peft_use_dora,
) )
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:

View File

@@ -34,6 +34,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.quantizers import AutoHfQuantizer
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -568,7 +569,7 @@ def load_model(
elif ( elif (
qlora_fsdp qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx" and cfg.qlora_fsdp_alt_loader
): ):
quant_storage = cfg.torch_dtype quant_storage = cfg.torch_dtype
model = load_sharded_model_quant( model = load_sharded_model_quant(
@@ -577,6 +578,11 @@ def load_model(
cfg, cfg,
quant_storage=quant_storage, quant_storage=quant_storage,
) )
if model_kwargs["quantization_config"]:
hf_quantizer = AutoHfQuantizer.from_config(
model_kwargs["quantization_config"]
)
model.hf_quantizer = hf_quantizer
skip_move_to_device = True skip_move_to_device = True
elif ( elif (
model_config.model_type == "llama" model_config.model_type == "llama"
@@ -1003,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
module.to(dtype) module.to(dtype)
except AttributeError: except AttributeError:
pass pass
for name, param in model.named_parameters():
try:
if param.data.dtype != dtype:
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
param.data = param.data.to(dtype)
except AttributeError:
pass