support adamw and grad norm hyperparams

This commit is contained in:
Wing Lian
2023-06-15 08:40:41 -04:00
parent a81f52d575
commit 6d0ee4ba34

View File

@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size