diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 988ed29ba..3057d5da5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -340,6 +340,7 @@ def load_model( base_model, config=config, trust_remote_code=cfg.trust_remote_code or False, + low_cpu_mem_usage=True, ).half() model = tp.tensor_parallel(model, sharded=False) else: