make sure to save the lora adapter at the end of RL/dpo training (#1573)

Wing Lian
2024-05-08 10:39:33 -04:00
committed by GitHub
parent cb78a36374
commit 796a085b2f

@@ -212,6 +212,10 @@ def train(
     if cfg.flash_optimum and BetterTransformer:
         model = BetterTransformer.reverse(model)
 
+    if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+        trainer.model.save_pretrained(
+            cfg.output_dir, safe_serialization=safe_serialization
+        )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if not cfg.hub_model_id:
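
For context, a minimal sketch of how the adapter saved by this change could be reloaded after RL/DPO training, assuming PEFT LoRA weights were written to cfg.output_dir; the base model name and output path below are hypothetical placeholders, not taken from the commit:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Load the base model, then attach the adapter weights that
    # trainer.model.save_pretrained() wrote to cfg.output_dir.
    base_model = AutoModelForCausalLM.from_pretrained("my-base-model")  # hypothetical
    model = PeftModel.from_pretrained(base_model, "path/to/output_dir")  # hypothetical path

    # Optionally fold the LoRA weights into the base model for inference.
    model = model.merge_and_unload()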