diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index a1c2f9c85..83836bf0b 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler from transformers import PreTrainedTokenizerBase from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.trainer import process_pretraining_datasets_for_packing @@ -261,7 +262,7 @@ def encode_packed_pretraining( drop_attention_mask=multipack_attn, # pass through handling mode from config via ds_wrapper function handling=getattr(ds_wrapper, "cfg", {}).get( - "sequence_len_overflow_handling", "drop" + "sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING ), ) diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index b26a8942b..c31529260 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -17,6 +17,8 @@ from axolotl.utils.trainer import truncate_or_drop_long_seq LOG = logging.getLogger(__name__) +DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop" + class RetryStrategy(Enum): """