automatically set pad_to_sequence_len when use packing (#2607)
* automatically set pad_to_sequence_len when use packing * update tests
This commit is contained in:
@@ -512,10 +512,17 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_sample_packing_padding(cls, data):
|
def hint_sample_packing_padding(cls, data):
|
||||||
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
if data.get("sample_packing"):
|
||||||
LOG.warning(
|
pad_to_sequence_len = data.get("pad_to_sequence_len")
|
||||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
if pad_to_sequence_len is False:
|
||||||
)
|
LOG.warning(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||||
|
)
|
||||||
|
elif pad_to_sequence_len is None:
|
||||||
|
LOG.info(
|
||||||
|
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||||
|
)
|
||||||
|
data["pad_to_sequence_len"] = True
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
|
|||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": None,
|
"pad_to_sequence_len": False,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
|
|||||||
for record in self._caplog.records
|
for record in self._caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_packing_autoset(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": None,
|
||||||
|
"flash_attention": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.INFO):
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||||
|
in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
assert cfg.pad_to_sequence_len is True
|
||||||
|
|
||||||
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
||||||
"""
|
"""
|
||||||
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||||
|
|||||||
Reference in New Issue
Block a user