diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f0669565f..f91f4e318 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -401,6 +401,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") cfg.total_num_tokens = total_num_tokens + if not cfg.total_supervised_tokens: + total_supervised_tokens = ( + train_dataset.data.column("labels") + .to_pandas() + .apply(lambda x: np.sum(np.array(x) != -100)) + .sum() + ) + LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`") + cfg.total_supervised_tokens = total_supervised_tokens + if cfg.sample_packing_eff_est: total_num_steps = ( # match count to len est in dataloader