fsdp fft loading on meta device

This commit is contained in:
Wing Lian
2024-08-11 22:18:04 -04:00
parent 1853d6021d
commit 2b890ead05

View File

@@ -655,6 +655,14 @@ def load_model(
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]
if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0:
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config, config=model_config,