fix fsdp loading of models, esp 70b (#1780)
This commit is contained in:
@@ -619,7 +619,7 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
@@ -701,7 +701,7 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
# disabling either of these two still leads to VRAM spike before setting back down
|
# disabling either of these two still leads to VRAM spike before setting back down
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
|
|||||||
Reference in New Issue
Block a user