From 459f407e692736943a81d81a2aec4b6cbc133730 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 15 May 2025 15:53:35 -0400 Subject: [PATCH] avoid crash/oom on train end --- src/axolotl/train.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f..3d044ece6 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -289,16 +289,18 @@ def save_trained_model( os.remove(os.path.join(cfg.output_dir, "model.safetensors")) except FileNotFoundError: pass - elif cfg.local_rank == 0: - if cfg.flash_optimum and BetterTransformer: - model = BetterTransformer.reverse(model) + else: + if cfg.local_rank == 0: + if cfg.flash_optimum and BetterTransformer: + model = BetterTransformer.reverse(model) - if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: - trainer.model.save_pretrained( - cfg.output_dir, safe_serialization=safe_serialization - ) + if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: + trainer.model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + trainer.accelerator.wait_for_everyone() if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: # TODO: add integration support so this can be implemented completely within the plugin