fixes to prevent vram spike when train starts (#1742)
This commit is contained in:
@@ -599,9 +599,12 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
from transformers import LlamaForCausalLM
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
skip_move_to_device = True
|
||||||
|
if "device_map" in model_kwargs:
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -634,7 +637,11 @@ def load_model(
|
|||||||
base_model,
|
base_model,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif (
|
||||||
|
model_type
|
||||||
|
and model_type != "AutoModelForCausalLM"
|
||||||
|
and not cfg.trust_remote_code
|
||||||
|
):
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -675,6 +682,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
# disabling either of these two still leads to VRAM spike before setting back down
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|||||||
Reference in New Issue
Block a user