From bb5a6135eb5047dac1d6838d8316121ba1bd1f3f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 6 Feb 2025 17:23:13 -0500 Subject: [PATCH] don't set total num steps for grpo --- src/axolotl/common/datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index cbc0d127c..db07eb43b 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -122,9 +122,11 @@ def load_preference_datasets( `total_num_steps`. """ train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) - total_num_steps = int( + total_num_steps: Optional[int] = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) + if cfg.rl == "grpo": + total_num_steps = None if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...")