fixes to prevent vram spike when train starts (#1742)

Author: Wing Lian (committed by GitHub)
Date:   2024-07-13 09:53:13 -04:00
Parent: 137d84d1b4
Commit: a4a5bf057f


@@ -599,9 +599,12 @@ def load_model(
         and not cfg.trust_remote_code
         and not cfg.gptq
     ):
-        from transformers import LlamaForCausalLM
-
-        model = LlamaForCausalLM.from_pretrained(
+        if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+            skip_move_to_device = True
+            if "device_map" in model_kwargs:
+                del model_kwargs["device_map"]
+
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=model_config,
             **model_kwargs,
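
The first hunk does two things. It drops the hard-coded LlamaForCausalLM import in favor of AutoModelForCausalLM, and it guards the QLoRA-over-FSDP path: when fsdp_cpu_ram_efficient_loading is enabled, any device_map is removed from model_kwargs and the later move-to-device is skipped, so FSDP can shard the weights from CPU instead of materializing the full model on one GPU at load time. A minimal sketch of the resulting pattern, with the function name and signature assumed for illustration:

# Minimal sketch of the loading pattern above; the function name and
# signature are assumptions, only the body mirrors the diff.
from transformers import AutoModelForCausalLM

def load_llama_base(base_model, model_config, model_kwargs, cfg, qlora_fsdp):
    skip_move_to_device = False
    if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
        # With CPU-RAM-efficient loading, a device_map would eagerly
        # place weights on GPU, which is exactly the spike to avoid.
        skip_move_to_device = True
        model_kwargs.pop("device_map", None)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=model_config,
        **model_kwargs,
    )
    return model, skip_move_to_device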
@@ -634,7 +637,11 @@ def load_model(
             base_model,
             **model_kwargs,
         )
-    elif model_type and not cfg.trust_remote_code:
+    elif (
+        model_type
+        and model_type != "AutoModelForCausalLM"
+        and not cfg.trust_remote_code
+    ):
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
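
The widened condition matters because a config with an explicit model_type of "AutoModelForCausalLM" used to match this branch and load through the named transformers class, bypassing the FSDP-aware handling below. Roughly, the dispatch now behaves like this (branch labels are illustrative, not axolotl's code):

# Illustrative dispatch only; the returned labels stand in for the
# actual loading branches.
def choose_loader_branch(model_type: str, trust_remote_code: bool) -> str:
    if (
        model_type
        and model_type != "AutoModelForCausalLM"
        and not trust_remote_code
    ):
        return "load via the named transformers class"
    # "AutoModelForCausalLM", a missing model_type, or trust_remote_code
    # all fall through to the generic branch patched in the next hunk.
    return "load via AutoModelForCausalLM with the VRAM-spike guard"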
@@ -675,6 +682,7 @@ def load_model(
             )
         else:
             if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+                # disabling either of these two still leads to a VRAM spike before settling back down
                 skip_move_to_device = True
                 if "device_map" in model_kwargs:
                     del model_kwargs["device_map"]
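
For context, skip_move_to_device pays off later in load_model, where it suppresses the eager transfer of the assembled model onto the accelerator. A paraphrase of that consumer, not a quote of axolotl's exact call site:

# Paraphrased consumer of the flag set above; the call site and the
# device string are assumptions.
if not skip_move_to_device:
    # An eager .to() would land the full, unsharded model on a single
    # GPU before FSDP shards it. That brief full-model residency is the
    # VRAM spike this commit avoids.
    model.to(f"cuda:{cfg.local_rank}")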