Feat(config): add max steps (#387)

This commit is contained in:
Gabriel Puliatti
2023-08-14 10:19:29 -05:00
committed by GitHub
parent 5d48a10548
commit 3c2ad00d07
2 changed files with 8 additions and 2 deletions

View File

@@ -209,7 +209,13 @@ def train(
cfg, train_dataset, eval_dataset
)
barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")

View File

@@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size