diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index bc7ee7e72..02308695c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -512,10 +512,17 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def hint_sample_packing_padding(cls, data):
-        if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
-            LOG.warning(
-                "`pad_to_sequence_len: true` is recommended when using sample_packing"
-            )
+        if data.get("sample_packing"):
+            pad_to_sequence_len = data.get("pad_to_sequence_len")
+            if pad_to_sequence_len is False:
+                LOG.warning(
+                    "`pad_to_sequence_len: true` is recommended when using sample_packing"
+                )
+            elif pad_to_sequence_len is None:
+                LOG.info(
+                    "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
+                )
+                data["pad_to_sequence_len"] = True
         return data
 
     @model_validator(mode="before")
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 3262a6981..683db61b2 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
             DictDefault(
                 {
                     "sample_packing": True,
-                    "pad_to_sequence_len": None,
+                    "pad_to_sequence_len": False,
                     "flash_attention": True,
                 }
             )
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
                 for record in self._caplog.records
             )
 
+    def test_packing_autoset(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "sample_packing": True,
+                    "pad_to_sequence_len": None,
+                    "flash_attention": True,
+                }
+            )
+            | minimal_cfg
+        )
+        with self._caplog.at_level(logging.INFO):
+            cfg = validate_config(cfg)
+            assert any(
+                "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
+                in record.message
+                for record in self._caplog.records
+            )
+            assert cfg.pad_to_sequence_len is True
+
     def test_merge_lora_no_bf16_fail(self, minimal_cfg):
         """
         This is assumed to be run on a CPU machine, so bf16 is not supported.