validation for sample packing and doc
This commit is contained in:
@@ -375,7 +375,10 @@ dataset_shard_idx:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
# maximum sequence length up to which training samples are concatenated together
|
# maximum sequence length up to which training samples are concatenated together
|
||||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||||
|
# soon to be DEPRECATED in favor of sample_packing
|
||||||
max_packed_sequence_len: 1024
|
max_packed_sequence_len: 1024
|
||||||
|
# use efficient multi-packing with block-diagonal attention and per-sequence position_ids
|
||||||
|
sample_packing:
|
||||||
|
|
||||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
|
raise ValueError(
|
||||||
|
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
||||||
|
)
|
||||||
|
if cfg.max_packed_sequence_len:
|
||||||
|
LOG.warning(
|
||||||
|
str(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"please set only one of gradient_accumulation_steps or batch_size"
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
|
|||||||
@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_packing(self):
    """
    validate_config warns when max_packed_sequence_len is set and rejects combining it with sample_packing
    """
    deprecation_notice = (
        "max_packed_sequence_len will be deprecated in favor of sample_packing"
    )

    # legacy option on its own: config is accepted, but a warning must be logged
    legacy_only = DictDefault({"max_packed_sequence_len": 2048})
    with self._caplog.at_level(logging.WARNING):
        validate_config(legacy_only)
        assert any(
            deprecation_notice in record.message for record in self._caplog.records
        )

    # legacy option together with sample_packing: validation must fail
    both_set = DictDefault(
        {
            "max_packed_sequence_len": 2048,
            "sample_packing": True,
        }
    )
    expected_error = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
    with pytest.raises(ValueError, match=expected_error):
        validate_config(both_set)
|
||||||
|
|||||||
Reference in New Issue
Block a user