Fix FSDP loading of models, especially 70B (#1780)

This commit is contained in:
Wing Lian
2024-07-23 19:54:28 -04:00
committed by GitHub
parent e6b299dd79
commit fe250ada78

View File

@@ -619,7 +619,7 @@ def load_model(
and not cfg.trust_remote_code
and not cfg.gptq
):
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
@@ -701,7 +701,7 @@ def load_model(
**model_kwargs,
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to a VRAM spike before settling back down
skip_move_to_device = True
if "device_map" in model_kwargs: