From fb12895a172d95bf9303d36c35f4ca129c23882e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 25 Nov 2023 12:15:43 +0900 Subject: [PATCH] Feat: Add warmup_ratio (#893) * Feat: Add warmup_ratio * fix: update readme with more details on conflict --- README.md | 3 ++- src/axolotl/core/trainer_builder.py | 13 ++++++++----- src/axolotl/utils/config.py | 3 +++ tests/test_validation.py | 30 +++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index fc300d605..717c99bf0 100644 --- a/README.md +++ b/README.md @@ -675,7 +675,8 @@ gradient_accumulation_steps: 1 micro_batch_size: 2 eval_batch_size: num_epochs: 4 -warmup_steps: 100 +warmup_steps: 100 # cannot use with warmup_ratio +warmup_ratio: 0.05 # cannot use with warmup_steps learning_rate: 0.00003 lr_quadratic_warmup: logging_steps: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6b78f1f1a..62e527beb 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlTrainer def build(self, total_num_steps): - warmup_steps = ( - self.cfg.warmup_steps - if self.cfg.warmup_steps is not None - else min(int(0.03 * total_num_steps), 100) - ) + warmup_steps = None + if self.cfg.warmup_steps is not None: + warmup_steps = self.cfg.warmup_steps + elif self.cfg.warmup_ratio is not None: + warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) + else: + warmup_steps = min(int(0.03 * total_num_steps), 100) + logging_steps = ( self.cfg.logging_steps if self.cfg.logging_steps is not None diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index d2db92a63..c41e059cd 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -372,6 +372,9 @@ def validate_config(cfg): if cfg.rope_scaling: LOG.warning("`rope_scaling` should now be be a key under `model_config`") + if cfg.warmup_steps and cfg.warmup_ratio: + raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index dbd030da3..5a4ef427b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -649,3 +649,33 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_warmup_step_no_conflict(self): + cfg = DictDefault( + { + "warmup_steps": 10, + "warmup_ratio": 0.1, + } + ) + + with pytest.raises( + ValueError, + match=r".*warmup_steps and warmup_ratio are mutually exclusive*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "warmup_steps": 10, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "warmup_ratio": 0.1, + } + ) + + validate_config(cfg)