automatically set pad_to_sequence_len when use packing (#2607)

* automatically set pad_to_sequence_len when use packing

* update tests
This commit is contained in:
Wing Lian
2025-05-01 13:24:38 -04:00
committed by GitHub
parent 6a3e6f8c53
commit bcb59c70e2
2 changed files with 32 additions and 5 deletions

View File

@@ -512,10 +512,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"pad_to_sequence_len": False,
"flash_attention": True,
}
)
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.