From 13eea21f9bc5b9f675461de016d22b02a548f610 Mon Sep 17 00:00:00 2001 From: Philip May Date: Tue, 6 Feb 2024 06:38:43 +0100 Subject: [PATCH] Add more save strategies for DPO training. (#1255) * Set save_strategy and save_steps in HFDPOTrainerBuilder * fix duplicate save_steps --- src/axolotl/core/trainer_builder.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 85c0dc7db..73eddd426 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1096,13 +1096,21 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): "use_reentrant": False } + # set save_strategy and save_steps + if self.cfg.save_steps: + training_args_kwargs["save_strategy"] = "steps" + training_args_kwargs["save_steps"] = self.cfg.save_steps + elif self.cfg.save_strategy: + training_args_kwargs["save_strategy"] = self.cfg.save_strategy + else: + # default to saving each epoch if not defined + training_args_kwargs["save_strategy"] = "epoch" + training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, - save_strategy="steps", - save_steps=self.cfg.save_steps, output_dir=self.cfg.output_dir, warmup_steps=self.cfg.warmup_steps, logging_first_step=True,