Add more save strategies for DPO training. (#1255)
* Set save_strategy and save_steps in HFDPOTrainerBuilder * fix doublicate save_steps
This commit is contained in:
@@ -1096,13 +1096,21 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
# set save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
save_strategy="steps",
|
||||
save_steps=self.cfg.save_steps,
|
||||
output_dir=self.cfg.output_dir,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
logging_first_step=True,
|
||||
|
||||
Reference in New Issue
Block a user