Deprecate max packed sequence len (#1141)

This commit is contained in:
Wing Lian
2024-01-20 05:11:50 -05:00
committed by GitHub
parent 3db5f2fd17
commit 2ce5c0d68a
6 changed files with 38 additions and 170 deletions

View File

@@ -324,20 +324,19 @@ class ValidationTest(BaseValidation):
validate_config(cfg)
def test_packing(self):
def test_deprecated_packing(self):
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"max_packed_sequence_len": 1024,
}
)
with self._caplog.at_level(logging.WARNING):
with pytest.raises(
DeprecationWarning,
match=r"`max_packed_sequence_len` is no longer supported",
):
validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)
def test_packing(self):
cfg = DictDefault(
{
"sample_packing": True,
@@ -352,16 +351,6 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
@pytest.mark.skipif(
is_torch_bf16_gpu_available(),
reason="test should only run on gpus w/o bf16 support",