Compare commits

...

1 Commit

Author      SHA1        Message                       Date
Wing Lian   459f407e69  avoid crash/oom on train end  2025-05-15 15:53:35 -04:00


@@ -289,16 +289,18 @@ def save_trained_model(
                 os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
             except FileNotFoundError:
                 pass
-    elif cfg.local_rank == 0:
-        if cfg.flash_optimum and BetterTransformer:
-            model = BetterTransformer.reverse(model)
+    else:
+        if cfg.local_rank == 0:
+            if cfg.flash_optimum and BetterTransformer:
+                model = BetterTransformer.reverse(model)

-        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-            trainer.model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+            if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+                trainer.model.save_pretrained(
+                    cfg.output_dir, safe_serialization=safe_serialization
+                )
+            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        trainer.accelerator.wait_for_everyone()

     if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         # TODO: add integration support so this can be implemented completely within the plugin
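The change turns the former `elif cfg.local_rank == 0:` branch into an `else:` block that every rank enters: only rank 0 performs the save, but all ranks then meet at `trainer.accelerator.wait_for_everyone()`. Below is a minimal, self-contained sketch of this rank-0-save-plus-barrier pattern using Hugging Face accelerate; the model and output path are placeholders for illustration, not axolotl's actual objects.

# Minimal sketch of the save-then-barrier pattern the commit adopts.
# Assumes Hugging Face `accelerate`; the model and checkpoint path
# below are stand-ins, not axolotl's real code.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)  # stand-in for the trained model

# Only the main process writes the checkpoint, so ranks neither clobber
# each other's files nor duplicate the serialization work in memory.
if accelerator.is_main_process:
    torch.save(model.state_dict(), "model.pt")

# Every rank must reach this barrier, which is why the commit moves the
# rank check *inside* a branch shared by all ranks: if non-zero ranks
# skip the barrier and tear down while rank 0 is still saving, the job
# can crash or hang at the end of training.
accelerator.wait_for_everyone()

The same barrier call already existed on the deepspeed path earlier in `save_trained_model`; this commit extends the synchronization to the plain rank-0 save path as well.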