diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 881114634..5345165f1 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1032,10 +1032,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] + max_steps = self.cfg.max_steps or total_num_steps or -1 training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, - max_steps=self.cfg.max_steps or total_num_steps, + max_steps=max_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, warmup_steps=self.cfg.warmup_steps,