fix fsdp loading of models, esp 70b (#1780)
This commit is contained in:
@@ -619,7 +619,7 @@ def load_model(
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
):
|
||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
if "device_map" in model_kwargs:
|
||||
del model_kwargs["device_map"]
|
||||
@@ -701,7 +701,7 @@ def load_model(
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
# disabling either of these two still leads to VRAM spike before setting back down
|
||||
skip_move_to_device = True
|
||||
if "device_map" in model_kwargs:
|
||||
|
||||
Reference in New Issue
Block a user