set fsdp state dict (#584)
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
aeec7c4688
commit
be75668400
@@ -117,6 +117,10 @@ def train(
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
Reference in New Issue
Block a user