avoid crash/oom on train end
This commit is contained in:
@@ -289,16 +289,18 @@ def save_trained_model(
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
else:
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||
# TODO: add integration support so this can be implemented completely within the plugin
|
||||
|
||||
Reference in New Issue
Block a user