Fix FSDP loading of models, especially 70B (#1780)

This commit is contained in:
Wing Lian
2024-07-23 19:54:28 -04:00
committed by GitHub
parent e6b299dd79
commit fe250ada78

View File

@@ -619,7 +619,7 @@ def load_model(
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
): ):
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]
@@ -701,7 +701,7 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
else: else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to VRAM spike before settling back down # disabling either of these two still leads to VRAM spike before settling back down
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs: