add support for gradient accumulation steps
This commit is contained in:
@@ -117,3 +117,30 @@ class ValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
validate_config(cfg)
|
||||
|
||||
def test_gradient_accumulations_or_batch_size(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user