Feat(config): add max steps (#387)
This commit is contained in:
@@ -209,7 +209,13 @@ def train(
|
|||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
barrier()
|
barrier()
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
if cfg.max_steps:
|
||||||
|
total_num_steps = min(
|
||||||
|
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||||
|
)
|
||||||
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
|
else:
|
||||||
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|||||||
@@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
evaluation_strategy = "steps"
|
evaluation_strategy = "steps"
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
|
|||||||
Reference in New Issue
Block a user