From 6fa40bf8adbf248ffa8bbd5f666fc771a147cd41 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 23:33:37 -0400 Subject: [PATCH] black formatting --- scripts/finetune.py | 4 +++- tests/test_validation.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 974c94117..e34e7378c 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -152,7 +152,9 @@ def train( validate_config(cfg) # setup some derived config / hyperparams - cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size) + cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( + cfg.batch_size // cfg.micro_batch_size + ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) choose_device(cfg) diff --git a/tests/test_validation.py b/tests/test_validation.py index 95c98543c..93ec15269 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase): } ) - with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"): + with pytest.raises( + ValueError, match=r".*gradient_accumulation_steps or batch_size.*" + ): validate_config(cfg) cfg = DictDefault(