qlora-fsdp ram efficient loading with hf trainer (#1791)

* fix 405b with lower cpu ram requirements

* make sure to use doouble quant and only skip output embeddings

* set model attributes

* more fixes for sharded fsdp loading

* update the base model in example to use pre-quantized nf4-bf16 weights

* upstream fixes  for qlora+fsdp
This commit is contained in:
Wing Lian
2024-07-30 19:21:38 -04:00
committed by GitHub
parent dbf8fb549e
commit 3ebf22464b
10 changed files with 52 additions and 14 deletions

View File

@@ -624,14 +624,21 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
):
quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (