Deprecate max packed sequence len (#1141)
This commit is contained in:
@@ -324,20 +324,19 @@ class ValidationTest(BaseValidation):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_packing(self):
|
||||
def test_deprecated_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"max_packed_sequence_len": 2048,
|
||||
"max_packed_sequence_len": 1024,
|
||||
}
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with pytest.raises(
|
||||
DeprecationWarning,
|
||||
match=r"`max_packed_sequence_len` is no longer supported",
|
||||
):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
def test_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
@@ -352,16 +351,6 @@ class ValidationTest(BaseValidation):
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"max_packed_sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
is_torch_bf16_gpu_available(),
|
||||
reason="test should only run on gpus w/o bf16 support",
|
||||
|
||||
Reference in New Issue
Block a user