diff --git a/scripts/finetune.py b/scripts/finetune.py
index 850606356..bb96d9789 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -209,7 +209,13 @@ def train(
         cfg, train_dataset, eval_dataset
     )
     barrier()
-    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+    if cfg.max_steps:
+        total_num_steps = min(
+            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+        )
+        LOG.info(f"Maximum number of steps set at {total_num_steps}")
+    else:
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
 
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index b2551028b..54ad058f5 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         evaluation_strategy = "steps"
 
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-        # max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
+        max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
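
Note for reviewers: a minimal standalone sketch of the step-resolution logic this
patch introduces. The names resolve_max_steps and estimated_steps are hypothetical,
standing in for axolotl's calculate_total_num_steps path; the -1 sentinel follows
the Hugging Face TrainingArguments convention, where max_steps=-1 (the default)
means "no step cap, train for num_train_epochs". Returning -1 rather than the raw
estimate when no cap is configured leaves the trainer's epoch-based scheduling
untouched.

    # Hypothetical helper (names not from the diff) mirroring the new logic.
    def resolve_max_steps(estimated_steps: int, max_steps: int) -> int:
        # When the config sets a positive max_steps, cap the computed estimate;
        # otherwise return -1 so the HF Trainer falls back to num_train_epochs.
        if max_steps:
            return min(estimated_steps, max_steps)
        return -1

    assert resolve_max_steps(1000, 200) == 200    # cap is tighter than the estimate
    assert resolve_max_steps(1000, 5000) == 1000  # estimate is tighter than the cap
    assert resolve_max_steps(1000, 0) == -1       # unset cap: defer to epoch count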