From fe250ada78ff3d5404e053f2ae050d66f3943248 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 19:54:28 -0400 Subject: [PATCH] fix fsdp loading of models, esp 70b (#1780) --- src/axolotl/utils/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 339195df7..436b31fef 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -619,7 +619,7 @@ def load_model( and not cfg.trust_remote_code and not cfg.gptq ): - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True if "device_map" in model_kwargs: del model_kwargs["device_map"] @@ -701,7 +701,7 @@ def load_model( **model_kwargs, ) else: - if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True if "device_map" in model_kwargs: