Make sure to save the LoRA adapter at the end of RL/DPO training (#1573)
This commit is contained in:
@@ -212,6 +212,10 @@ def train(
|
|||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
|
trainer.model.save_pretrained(
|
||||||
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
|||||||
Reference in New Issue
Block a user