From 3c2ad00d0763b42ed063290d6916d0c92199c994 Mon Sep 17 00:00:00 2001 From: Gabriel Puliatti Date: Mon, 14 Aug 2023 10:19:29 -0500 Subject: [PATCH] Feat(config): add max steps (#387) --- scripts/finetune.py | 8 +++++++- src/axolotl/utils/trainer.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 850606356..bb96d9789 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -209,7 +209,13 @@ def train( cfg, train_dataset, eval_dataset ) barrier() - total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) + if cfg.max_steps: + total_num_steps = min( + calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps + ) + LOG.info(f"Maximum number of steps set at {total_num_steps}") + else: + total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) if cfg.debug or "debug" in kwargs: LOG.info("check_dataset_labels...") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index b2551028b..54ad058f5 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ evaluation_strategy = "steps" training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg - # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps + max_steps=total_num_steps if cfg.max_steps else -1, max_seq_length=cfg.sequence_len, per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size