Refactor sequence length overflow handling in pretraining module

- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
This commit is contained in:
mhenrhcsen
2025-05-15 12:55:09 +02:00
parent 5ecf22b54e
commit 5d7a61576d
2 changed files with 4 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
@@ -261,7 +262,7 @@ def encode_packed_pretraining(
drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function
handling=getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling", "drop"
"sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
),
)

View File

@@ -17,6 +17,8 @@ from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = logging.getLogger(__name__)
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
class RetryStrategy(Enum):
"""