fixes to prevent vram spike when train starts (#1742)
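Summary (inferred from the hunks below, hedged): when loading for QLoRA over FSDP with fsdp_cpu_ram_efficient_loading enabled, the loader now deletes any device_map from model_kwargs and sets skip_move_to_device instead, so the model is not pushed onto the GPU before FSDP takes over. The widened elif guard also routes model_type == "AutoModelForCausalLM" to the final else branch, which applies the same handling.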
@@ -599,9 +599,12 @@ def load_model(
         and not cfg.trust_remote_code
         and not cfg.gptq
     ):
-        from transformers import LlamaForCausalLM
+        if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+            skip_move_to_device = True
+            if "device_map" in model_kwargs:
+                del model_kwargs["device_map"]

-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=model_config,
             **model_kwargs,
@@ -634,7 +637,11 @@ def load_model(
             base_model,
             **model_kwargs,
         )
-    elif model_type and not cfg.trust_remote_code:
+    elif (
+        model_type
+        and model_type != "AutoModelForCausalLM"
+        and not cfg.trust_remote_code
+    ):
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
@@ -675,6 +682,7 @@ def load_model(
             )
     else:
         if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+            # disabling either of these two still leads to VRAM spike before setting back down
             skip_move_to_device = True
             if "device_map" in model_kwargs:
                 del model_kwargs["device_map"]
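For readers without the full file, here is a minimal sketch of the control flow these hunks converge on. It assumes, based only on names visible in the diff, that qlora_fsdp, model_kwargs, and skip_move_to_device are locals of load_model and that skip_move_to_device is consumed later to suppress an explicit move of the model onto the GPU. It is an illustration, not the actual axolotl source:

from transformers import AutoModelForCausalLM

def load_model_sketch(cfg, model_config, base_model, model_kwargs, qlora_fsdp):
    # Illustrative reduction of the patched branches; names mirror the diff.
    skip_move_to_device = False

    if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
        # Keep the weights where from_pretrained leaves them and let FSDP
        # place them later; an explicit device_map would copy the full model
        # onto the GPU up front, which is the VRAM spike this commit targets.
        skip_move_to_device = True
        if "device_map" in model_kwargs:
            del model_kwargs["device_map"]

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=model_config,
        **model_kwargs,
    )
    return model, skip_move_to_device

The same guard-and-delete block appears in both the Llama fast path and the generic else branch, so whichever branch loads the model, device placement is deferred whenever QLoRA FSDP with CPU RAM efficient loading is active.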