black formatting

This commit is contained in:
Wing Lian
2023-05-30 23:33:37 -04:00
parent 3aad5f3b3e
commit 6fa40bf8ad
2 changed files with 6 additions and 2 deletions

View File

@@ -152,7 +152,9 @@ def train(
validate_config(cfg) validate_config(cfg)
# setup some derived config / hyperparams # setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size) cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg) choose_device(cfg)

View File

@@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase):
} }
) )
with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"): with pytest.raises(
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
):
validate_config(cfg) validate_config(cfg)
cfg = DictDefault( cfg = DictDefault(