diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index cbc0d127c..db07eb43b 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -122,9 +122,11 @@ def load_preference_datasets( `total_num_steps`. """ train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) - total_num_steps = int( + total_num_steps: Optional[int] = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) + if cfg.rl == "grpo": + total_num_steps = None if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...")