diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5152e649b..5cf3107f3 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -115,6 +115,17 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             # TODO search Path("./") for one
             training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
+    # Forward optimizer overrides only when they are explicitly configured.
+    # Compare against None (not truthiness) so a deliberate 0 / 0.0 is kept.
+    if cfg.adam_beta1 is not None:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2 is not None:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon is not None:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm is not None:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size