qlora-fsdp ram efficient loading with hf trainer (#1791)
* fix 405b with lower cpu ram requirements
* make sure to use double quant and only skip output embeddings
* set model attributes
* more fixes for sharded fsdp loading
* update the base model in example to use pre-quantized nf4-bf16 weights
* upstream fixes for qlora+fsdp
This commit is contained in:
@@ -624,14 +624,21 @@ def load_model(
|
||||
elif (
|
||||
qlora_fsdp
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and cfg.model_config_type == "dbrx"
|
||||
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
|
||||
):
|
||||
quant_storage = cfg.torch_dtype
|
||||
quantization_config = hasattr(
|
||||
model_config, "quantization_config"
|
||||
) and getattr(model_config, "quantization_config")
|
||||
quantization_config = (
|
||||
quantization_config or model_kwargs["quantization_config"]
|
||||
)
|
||||
model = load_sharded_model_quant(
|
||||
base_model,
|
||||
model_config,
|
||||
cfg,
|
||||
quant_storage=quant_storage,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
skip_move_to_device = True
|
||||
elif (
|
||||
|
||||
Reference in New Issue
Block a user