Automatically set pad_to_sequence_len when using sample packing (#2607)

* Automatically set pad_to_sequence_len when using sample packing

* update tests
This commit is contained in:
Wing Lian
2025-05-01 13:24:38 -04:00
committed by GitHub
parent 6a3e6f8c53
commit bcb59c70e2
2 changed files with 32 additions and 5 deletions

View File

@@ -512,10 +512,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
    """Reconcile `pad_to_sequence_len` with `sample_packing` on the raw config.

    Runs before field validation, so `data` is the unvalidated input mapping.
    When `sample_packing` is truthy:

    * an explicit `pad_to_sequence_len: False` is left untouched, but a
      warning recommending `true` is logged;
    * a missing or `None` `pad_to_sequence_len` is set to `True` and an
      info message is logged.

    Any other value of `pad_to_sequence_len` is left as-is with no logging.

    Args:
        data: The raw config mapping; must support `.get()` and item
            assignment.

    Returns:
        The same mapping, possibly with `pad_to_sequence_len` set to `True`.
    """
    if data.get("sample_packing"):
        pad_to_sequence_len = data.get("pad_to_sequence_len")
        # Identity checks on purpose: an explicit `False` (user opted out)
        # must be told apart from `None` (user never set the option).
        if pad_to_sequence_len is False:
            LOG.warning(
                "`pad_to_sequence_len: true` is recommended when using sample_packing"
            )
        elif pad_to_sequence_len is None:
            LOG.info(
                "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
            )
            data["pad_to_sequence_len"] = True
    return data
@model_validator(mode="before") @model_validator(mode="before")

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault( DictDefault(
{ {
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": None, "pad_to_sequence_len": False,
"flash_attention": True, "flash_attention": True,
} }
) )
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
for record in self._caplog.records for record in self._caplog.records
) )
def test_packing_autoset(self, minimal_cfg):
    """An unset `pad_to_sequence_len` is switched on when sample packing is enabled.

    Validates a config with `sample_packing: True` and
    `pad_to_sequence_len: None`, then checks that the auto-set info message
    was logged and that the validated config reports `True`.
    """
    expected_message = (
        "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
    )
    packing_overrides = {
        "sample_packing": True,
        "pad_to_sequence_len": None,
        "flash_attention": True,
    }
    cfg = DictDefault(packing_overrides) | minimal_cfg

    # Capture at INFO level: the auto-set notice is logged via LOG.info.
    with self._caplog.at_level(logging.INFO):
        cfg = validate_config(cfg)
        logged_messages = [record.message for record in self._caplog.records]
        assert any(expected_message in message for message in logged_messages)
        assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg): def test_merge_lora_no_bf16_fail(self, minimal_cfg):
""" """
This is assumed to be run on a CPU machine, so bf16 is not supported. This is assumed to be run on a CPU machine, so bf16 is not supported.