Fixes to prevent a VRAM spike when training starts (#1742)

This commit is contained in:
Wing Lian
2024-07-13 09:53:13 -04:00
committed by GitHub
parent 137d84d1b4
commit a4a5bf057f

View File

@@ -599,9 +599,12 @@ def load_model(
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
): ):
from transformers import LlamaForCausalLM if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = LlamaForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config, config=model_config,
**model_kwargs, **model_kwargs,
@@ -634,7 +637,11 @@ def load_model(
base_model, base_model,
**model_kwargs, **model_kwargs,
) )
elif model_type and not cfg.trust_remote_code: elif (
model_type
and model_type != "AutoModelForCausalLM"
and not cfg.trust_remote_code
):
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
@@ -675,6 +682,7 @@ def load_model(
) )
else: else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to a VRAM spike before settling back down
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]