add support for gradient accumulation steps

This commit is contained in:
Wing Lian
2023-05-30 23:24:37 -04:00
parent c5b0af1a7e
commit 3aad5f3b3e
3 changed files with 34 additions and 3 deletions

View File

@@ -117,3 +117,30 @@ class ValidationTest(unittest.TestCase):
}
)
validate_config(cfg)
def test_gradient_accumulations_or_batch_size(self):
cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"batch_size": 1,
}
)
with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"):
validate_config(cfg)
cfg = DictDefault(
{
"batch_size": 1,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
}
)
validate_config(cfg)