From 787880215b3ab32ccaf81c1b2e9588c6f3e6e764 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 3 Jun 2025 14:27:09 -0700 Subject: [PATCH] fix(deepspeed): deepspeed config not being set for z3 (#2754) * fix(deepspeed): deepspeed config not being set for z3 * fix: comments --- src/axolotl/loaders/model.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1d26a99dd..3b2a455ca 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -556,11 +556,18 @@ class ModelLoader: if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True - def _configure_zero3_memory_efficient_loading(self): - """Set the deepspeed config to load the model into RAM first before moving - to VRAM. + def _configure_zero3_memory_efficient_loading( + self, + ) -> HfTrainerDeepSpeedConfig | None: + """ + Set the deepspeed config to load the model into RAM first before moving to VRAM. - We need to return `hf_ds_cfg` as it needs to exist before model loading. + IMPORTANT + ========== + + We need to return `hf_ds_cfg` as it needs to exist before model loading for zero3. + HfTrainerDeepSpeedConfig is a class that is used to configure the DeepSpeed training. + It is not passed anywhere in the model loading function, just need to exist. """ hf_ds_cfg = None @@ -625,7 +632,8 @@ class ModelLoader: if "device_map" in self.model_kwargs: del self.model_kwargs["device_map"] - self._configure_zero3_memory_efficient_loading() + # Please don't remove underscore binding without reading the fn docstring. + _ = self._configure_zero3_memory_efficient_loading() # Load model with random initialization if specified if self.cfg.random_init_weights: @@ -695,7 +703,8 @@ class ModelLoader: if "device_map" in self.model_kwargs: del self.model_kwargs["device_map"] - self._configure_zero3_memory_efficient_loading() + # Please don't remove underscore binding without reading the fn docstring. + _ = self._configure_zero3_memory_efficient_loading() self.model = self.auto_model_loader.from_pretrained( self.base_model,