Compare commits

...

1 Commit

Author SHA1 Message Date
Wing Lian
459f407e69 avoid crash/oom on train end 2025-05-15 15:53:35 -04:00

View File

@@ -289,7 +289,8 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
else:
if cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -299,6 +300,7 @@ def save_trained_model(
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
trainer.accelerator.wait_for_everyone()
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin