Select input_ids explicitly after pandas conversion (#2335)

Without selecting the column, `len` is applied to the DataFrame rather than to each row's `input_ids` list, so every row counts as 1 and the result is the total number of samples instead of the token count.
This commit is contained in:
Seungduk Kim
2025-02-17 14:07:27 +09:00
committed by GitHub
parent a98526ef78
commit 97a2fa2781

View File

@@ -396,8 +396,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
):
total_num_tokens = np.sum(
train_dataset.select_columns("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.to_pandas()["input_ids"]
.apply(len)
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)