From 6d0ee4ba34fbf20e9846ce24875448019f8dba65 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 15 Jun 2023 08:40:41 -0400
Subject: [PATCH] support adamw and grad norm hyperparams

---
 src/axolotl/utils/trainer.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5152e649b..5cf3107f3 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
+    if cfg.adam_beta1:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
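
Below the patch, for reference: a minimal, self-contained sketch of the pass-through pattern the hunk adds. The SimpleNamespace stand-in for axolotl's cfg object, the hyperparameter values, and the output_dir are illustrative assumptions, not part of the patch. The point is that only keys set in the config are forwarded to transformers.TrainingArguments, so anything left unset keeps the transformers defaults (adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8, max_grad_norm=1.0).

    from types import SimpleNamespace

    import transformers

    # Hypothetical config mirroring the new axolotl keys; axolotl's real
    # config object yields a falsy value for unset keys, which the `if`
    # checks in the patch rely on.
    cfg = SimpleNamespace(
        adam_beta1=0.9,
        adam_beta2=0.95,  # e.g. overriding the 0.999 default
        adam_epsilon=1e-8,
        max_grad_norm=1.0,
    )

    # Same pass-through pattern as the hunk above.
    training_arguments_kwargs = {}
    if cfg.adam_beta1:
        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
    if cfg.adam_beta2:
        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
    if cfg.adam_epsilon:
        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
    if cfg.max_grad_norm:
        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm

    training_args = transformers.TrainingArguments(
        output_dir="./out",  # illustrative; required by TrainingArguments
        **training_arguments_kwargs,
    )
    print(training_args.adam_beta2, training_args.max_grad_norm)  # 0.95 1.0

One caveat of the truthiness checks: an explicit 0 or 0.0 in the config (e.g. max_grad_norm set to 0) is falsy, so it would be silently skipped and the transformers default used instead.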