From a4a5bf057ff3800e32cc0ca5eecd40198f1266a3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jul 2024 09:53:13 -0400 Subject: [PATCH] fixes to prevent vram spike when train starts (#1742) --- src/axolotl/utils/models.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d479d425d..d8eac1ce1 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -599,9 +599,12 @@ def load_model( and not cfg.trust_remote_code and not cfg.gptq ): - from transformers import LlamaForCausalLM + if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + skip_move_to_device = True + if "device_map" in model_kwargs: + del model_kwargs["device_map"] - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( base_model, config=model_config, **model_kwargs, @@ -634,7 +637,11 @@ def load_model( base_model, **model_kwargs, ) - elif model_type and not cfg.trust_remote_code: + elif ( + model_type + and model_type != "AutoModelForCausalLM" + and not cfg.trust_remote_code + ): if cfg.gptq: model = AutoModelForCausalLM.from_pretrained( base_model, @@ -675,6 +682,7 @@ def load_model( ) else: if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True if "device_map" in model_kwargs: del model_kwargs["device_map"]