black formatting
This commit is contained in:
@@ -152,7 +152,9 @@ def train(
|
|||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
# setup some derived config / hyperparams
|
# setup some derived config / hyperparams
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size)
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
choose_device(cfg)
|
choose_device(cfg)
|
||||||
|
|||||||
@@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"):
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
||||||
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user