From 3aad5f3b3e171fa42846a91b2339848c1154d0ed Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 23:24:37 -0400 Subject: [PATCH] add support for gradient accumulation steps --- scripts/finetune.py | 6 +++--- src/axolotl/utils/validation.py | 4 ++++ tests/test_validation.py | 27 +++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 6c42b3061..974c94117 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -149,8 +149,10 @@ def train( else: cfg[k] = kwargs[k] + validate_config(cfg) + # setup some derived config / hyperparams - cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size + cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) choose_device(cfg) @@ -168,8 +170,6 @@ def train( cfg.fp16 = True cfg.bf16 = False - validate_config(cfg) - # load the tokenizer first logging.info("loading tokenizer...") tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index c4bc4f952..0d9610aae 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -4,6 +4,10 @@ import logging def validate_config(cfg): + if cfg.gradient_accumulation_steps and cfg.batch_size: + raise ValueError( + "please set only one of gradient_accumulation_steps or batch_size" + ) if cfg.load_4bit: raise ValueError( "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq" diff --git a/tests/test_validation.py b/tests/test_validation.py index 15bc07f84..95c98543c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -117,3 +117,30 @@ class ValidationTest(unittest.TestCase): } ) validate_config(cfg) + + def test_gradient_accumulations_or_batch_size(self): + cfg = DictDefault( + { + "gradient_accumulation_steps": 1, + "batch_size": 1, + } + ) + + with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"): + validate_config(cfg) + + cfg = DictDefault( + { + "batch_size": 1, + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "gradient_accumulation_steps": 1, + } + ) + + validate_config(cfg)