Refactor sequence length overflow handling in pretraining module
- Introduced a `DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING` constant in `utils.py`.
- Updated the `encode_packed_pretraining` function to use this constant instead of a hardcoded value.
This commit is contained in:
@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase

 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import process_pretraining_datasets_for_packing

@@ -261,7 +262,7 @@ def encode_packed_pretraining(
         drop_attention_mask=multipack_attn,
         # pass through handling mode from config via ds_wrapper function
         handling=getattr(ds_wrapper, "cfg", {}).get(
-            "sequence_len_overflow_handling", "drop"
+            "sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
         ),
     )

@@ -17,6 +17,8 @@ from axolotl.utils.trainer import truncate_or_drop_long_seq

 LOG = logging.getLogger(__name__)

+DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
+

 class RetryStrategy(Enum):
     """
||||
Reference in New Issue
Block a user