From 5d7a61576dbc13bff8ae5eeec4a5f572029e61e6 Mon Sep 17 00:00:00 2001
From: mhenrhcsen
Date: Thu, 15 May 2025 12:55:09 +0200
Subject: [PATCH] Refactor sequence length overflow handling in pretraining module

- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
---
 src/axolotl/utils/data/pretraining.py | 3 ++-
 src/axolotl/utils/data/utils.py       | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index a1c2f9c85..83836bf0b 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import process_pretraining_datasets_for_packing
 
@@ -261,7 +262,7 @@ def encode_packed_pretraining(
         drop_attention_mask=multipack_attn,
         # pass through handling mode from config via ds_wrapper function
         handling=getattr(ds_wrapper, "cfg", {}).get(
-            "sequence_len_overflow_handling", "drop"
+            "sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
         ),
     )
 
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index b26a8942b..c31529260 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -17,6 +17,8 @@ from axolotl.utils.trainer import truncate_or_drop_long_seq
 
 LOG = logging.getLogger(__name__)
 
+DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
+
 
 class RetryStrategy(Enum):
     """