automatically set pad_to_sequence_len when use packing (#2607)

* automatically set pad_to_sequence_len when use packing

* update tests
This commit is contained in:
Wing Lian
2025-05-01 13:24:38 -04:00
committed by GitHub
parent 6a3e6f8c53
commit bcb59c70e2
2 changed files with 32 additions and 5 deletions

View File

@@ -512,10 +512,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")