diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 339195df7..436b31fef 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -619,7 +619,7 @@ def load_model( and not cfg.trust_remote_code and not cfg.gptq ): - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True if "device_map" in model_kwargs: del model_kwargs["device_map"] @@ -701,7 +701,7 @@ def load_model( **model_kwargs, ) else: - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True if "device_map" in model_kwargs: