Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
7a7c56f018 fixes to support fsdp-qdora 2024-04-23 08:37:04 -04:00
5 changed files with 26 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora
lora_model_dir:
lora_r: 8

View File

@@ -22,6 +22,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
qlora_fsdp_alt_loader: true
adapter: lora
lora_model_dir:
lora_r: 8

View File

@@ -188,6 +188,7 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_fsdp_alt_loader: Optional[bool] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None

View File

@@ -70,6 +70,7 @@ def load_and_quantize(
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
is_dora: bool = False,
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
@@ -108,6 +109,12 @@ def load_and_quantize(
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
if is_dora:
setattr(
submodule,
"dora_scale",
value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"),
)
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
@@ -177,6 +184,7 @@ def load_sharded_model_quant(
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
@@ -249,6 +257,7 @@ def load_sharded_model_quant(
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
is_dora=cfg.peft_use_dora,
)
if cfg.local_rank == 0 and verbose:

View File

@@ -34,6 +34,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.quantizers import AutoHfQuantizer
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
@@ -568,7 +569,7 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and cfg.qlora_fsdp_alt_loader
):
quant_storage = cfg.torch_dtype
model = load_sharded_model_quant(
@@ -577,6 +578,11 @@ def load_model(
cfg,
quant_storage=quant_storage,
)
if model_kwargs["quantization_config"]:
hf_quantizer = AutoHfQuantizer.from_config(
model_kwargs["quantization_config"]
)
model.hf_quantizer = hf_quantizer
skip_move_to_device = True
elif (
model_config.model_type == "llama"
@@ -1003,3 +1009,10 @@ def ensure_dtype(model, dtype=torch.bfloat16):
module.to(dtype)
except AttributeError:
pass
for name, param in model.named_parameters():
try:
if param.data.dtype != dtype:
print(f"Converting module {name}: {param.data.dtype} -> {dtype}")
param.data = param.data.to(dtype)
except AttributeError:
pass