From 6d0ee4ba34fbf20e9846ce24875448019f8dba65 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 15 Jun 2023 08:40:41 -0400 Subject: [PATCH 1/4] support adamw and grad norm hyperparams --- src/axolotl/utils/trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 5152e649b..5cf3107f3 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): # TODO search Path("./") for one training_arguments_kwargs["deepspeed"] = "./ds_config.json" + if cfg.adam_beta1: + training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1 + if cfg.adam_beta2: + training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2 + if cfg.adam_epsilon: + training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon + if cfg.max_grad_norm: + training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm + training_args = transformers.TrainingArguments( per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size From c969f0a9dc28c9f095a2bb6b3ecede0216d909b5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 15 Jun 2023 08:43:20 -0400 Subject: [PATCH 2/4] add docs --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index d6c9cfefb..5fbac1a48 100644 --- a/README.md +++ b/README.md @@ -422,6 +422,12 @@ log_sweep_max_lr: optimizer: # specify weight decay weight_decay: +# adamw hyperparams +adam_beta1: +adam_beta2: +adam_epsilon: +# Gradient clipping max norm +max_grad_norm: # whether to bettertransformers flash_optimum: From cb9d3af5c00e0189f95c03d64efdc283aec54679 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 15 Jun 2023 09:39:42 -0400 Subject: [PATCH 3/4] add validation and tests for adamw hyperparam --- src/axolotl/utils/validation.py | 5 ++++ tests/test_validation.py | 42 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 298d36c4e..2e0da69b3 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -87,6 +87,11 @@ def validate_config(cfg): "You probably want to disable group_by_length as it will force a streamed dataset to download completely." ) + if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and ( + not cfg.optimizer or "adamw" not in cfg.optimizer + ): + logging.warning("adamw hyperparameters found, but no adamw optimizer set") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index dba54586e..cc6d29a23 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -263,3 +263,45 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) + + def test_adamw_hyperparams(self): + cfg = DictDefault( + { + "optimizer": None, + "adamw_epsilon": 0.0001, + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "adamw hyperparameters found, but no adamw optimizer set" + in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "optimizer": "adafactor", + "adamw_beta1": 0.0001, + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "adamw hyperparameters found, but no adamw optimizer set" + in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "optimizer": "adamw_bnb_8bit", + "adamw_beta1": 0.0001, + "adamw_beta2": 0.0001, + "adamw_epsilon": 0.0001, + } + ) + + validate_config(cfg) From ad5ca4f734721d66b9c10a58ba7141bf13694452 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 15 Jun 2023 10:12:47 -0400 Subject: [PATCH 4/4] Additional test case per pr --- tests/test_validation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_validation.py b/tests/test_validation.py index cc6d29a23..d39a4618e 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -305,3 +305,11 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + cfg = DictDefault( + { + "optimizer": "adafactor", + } + ) + + validate_config(cfg)