From 4b7ad9927f50d4430d0d4304c77adeec2b079f2c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 22 Jul 2023 03:35:06 -0400 Subject: [PATCH] validation for sample packing and doc --- README.md | 3 +++ src/axolotl/utils/validation.py | 13 +++++++++++++ tests/test_validation.py | 24 ++++++++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/README.md b/README.md index fe22bbc31..c471494bb 100644 --- a/README.md +++ b/README.md @@ -375,7 +375,10 @@ dataset_shard_idx: sequence_len: 2048 # max sequence length to concatenate training samples together up to # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning +# soon to be DEPRECATED max_packed_sequence_len: 1024 +# use efficient multi-packing with block diagonal attention and per sequence position_ids +sample_packing: # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model adapter: lora diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 06669cba2..3ea59f391 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl") def validate_config(cfg): + if cfg.max_packed_sequence_len and cfg.sample_packing: + raise ValueError( + "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" + ) + if cfg.max_packed_sequence_len: + LOG.warning( + str( + PendingDeprecationWarning( + "max_packed_sequence_len will be deprecated in favor of sample_packing" + ) + ) + ) + if cfg.gradient_accumulation_steps and cfg.batch_size: raise ValueError( "please set only one of gradient_accumulation_steps or batch_size" diff --git a/tests/test_validation.py b/tests/test_validation.py index 88c97f0b7..e956d7b40 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_packing(self): + cfg = DictDefault( + { + "max_packed_sequence_len": 2048, + } + ) + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "max_packed_sequence_len will be deprecated in favor of sample_packing" + in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "max_packed_sequence_len": 2048, + "sample_packing": True, + } + ) + regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg)