Don't set `total_num_steps` for GRPO

This commit is contained in:
Wing Lian
2025-02-06 17:23:13 -05:00
parent e637f9b1a4
commit bb5a6135eb

View File

@@ -122,9 +122,11 @@ def load_preference_datasets(
`total_num_steps`. `total_num_steps`.
""" """
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int( total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")