Set FSDP state dict type to FULL_STATE_DICT before saving the model (#584)

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
This commit is contained in:
Jan Philipp Harries
2023-09-15 23:47:36 +02:00
committed by GitHub
parent aeec7c4688
commit be75668400

View File

@@ -117,6 +117,10 @@ def train(
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()