use low_cpu_mem_usage with ds zero 1 or 2

This commit is contained in:
Wing Lian
2024-01-16 19:33:44 -05:00
parent 1b59a3e698
commit 1b33588f09

View File

@@ -355,7 +355,9 @@ def load_model(
# else:
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
if is_deepspeed_zero3_enabled() or cfg.deepspeed:
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
elif cfg.deepspeed:
del model_kwargs["device_map"]
model_kwargs["low_cpu_mem_usage"] = True