diff --git a/scripts/finetune.py b/scripts/finetune.py
index 6c42b3061..974c94117 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -149,8 +149,10 @@ def train(
         else:
             cfg[k] = kwargs[k]
 
+    validate_config(cfg)
+
     # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
+    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size)
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
@@ -168,8 +170,6 @@ def train(
         cfg.fp16 = True
         cfg.bf16 = False
 
-    validate_config(cfg)
-
     # load the tokenizer first
     logging.info("loading tokenizer...")
     tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index c4bc4f952..0d9610aae 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -4,6 +4,10 @@ import logging
 
 
 def validate_config(cfg):
+    if cfg.gradient_accumulation_steps and cfg.batch_size:
+        raise ValueError(
+            "please set only one of gradient_accumulation_steps or batch_size"
+        )
     if cfg.load_4bit:
         raise ValueError(
             "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 15bc07f84..95c98543c 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -117,3 +117,30 @@ class ValidationTest(unittest.TestCase):
             }
         )
         validate_config(cfg)
+
+    def test_gradient_accumulations_or_batch_size(self):
+        cfg = DictDefault(
+            {
+                "gradient_accumulation_steps": 1,
+                "batch_size": 1,
+            }
+        )
+
+        with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "batch_size": 1,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "gradient_accumulation_steps": 1,
+            }
+        )
+
+        validate_config(cfg)
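
A minimal sketch (not part of the patch) of how the two options interact after this change. `make_cfg` and `resolve_gradient_accumulation_steps` are hypothetical stand-ins for illustration; in the real code `cfg` is axolotl's `DictDefault`, where unset keys read as falsy.

```python
# Illustrative sketch only; names here are stand-ins, not axolotl APIs.
from types import SimpleNamespace


def make_cfg(**kwargs):
    # Unset keys default to None, mimicking DictDefault's falsy-miss behavior.
    base = {
        "gradient_accumulation_steps": None,
        "batch_size": None,
        "micro_batch_size": None,
    }
    base.update(kwargs)
    return SimpleNamespace(**base)


def resolve_gradient_accumulation_steps(cfg):
    # The two options are now mutually exclusive: setting both is a config error.
    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )
    # Prefer the explicit value; otherwise derive it from batch_size, as the
    # updated line in train() does.
    return cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size)


# Explicit steps win; batch_size alone still derives them the old way.
assert resolve_gradient_accumulation_steps(make_cfg(gradient_accumulation_steps=4)) == 4
assert resolve_gradient_accumulation_steps(make_cfg(batch_size=32, micro_batch_size=8)) == 4
```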