diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 55721f820..394e335f4 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -355,8 +355,9 @@ def load_model( # else: # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() or cfg.deepspeed: del model_kwargs["device_map"] + model_kwargs["low_cpu_mem_usage"] = True if cfg.model_revision: model_kwargs["revision"] = cfg.model_revision