Compare commits
1 commit
custom-mod...wait-distr

| Author | SHA1 | Date |
|---|---|---|
| | 459f407e69 | |
```diff
@@ -289,16 +289,18 @@ def save_trained_model(
                 os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
             except FileNotFoundError:
                 pass
-    elif cfg.local_rank == 0:
-        if cfg.flash_optimum and BetterTransformer:
-            model = BetterTransformer.reverse(model)
+    else:
+        if cfg.local_rank == 0:
+            if cfg.flash_optimum and BetterTransformer:
+                model = BetterTransformer.reverse(model)
 
-        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-            trainer.model.save_pretrained(
-                cfg.output_dir, safe_serialization=safe_serialization
-            )
+            if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+                trainer.model.save_pretrained(
+                    cfg.output_dir, safe_serialization=safe_serialization
+                )
 
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        trainer.accelerator.wait_for_everyone()
 
     if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         # TODO: add integration support so this can be implemented completely within the plugin
```
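The restructuring appears aimed at making every rank reach a synchronization barrier after the checkpoint is written: the `elif cfg.local_rank == 0:` branch becomes `else:` with the rank check nested inside, so the added `trainer.accelerator.wait_for_everyone()` runs on all ranks, not just rank 0. Below is a minimal sketch of that save-then-barrier pattern using Hugging Face Accelerate directly; the `save_checkpoint` helper and `output_dir` name are illustrative, not this repo's own code.

```python
# Minimal sketch of the save-then-barrier pattern this commit applies,
# assuming Hugging Face `accelerate` is installed. Names here are
# illustrative, not taken from the repo.
from accelerate import Accelerator

accelerator = Accelerator()


def save_checkpoint(model, output_dir: str) -> None:
    if accelerator.is_main_process:
        # Only rank 0 touches the filesystem, so concurrent writers
        # cannot corrupt the checkpoint.
        model.save_pretrained(output_dir)
    # Every rank blocks here until rank 0 returns from the save;
    # without this barrier, other ranks could exit (or read the
    # checkpoint) while the files are still being written.
    accelerator.wait_for_everyone()
```

Keeping the barrier outside the `if accelerator.is_main_process:` block is the essential point: a `wait_for_everyone()` that only rank 0 reaches would deadlock or be a no-op, depending on the backend.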