Refactor sequence length overflow handling in pretraining module
- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py. - Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
This commit is contained in:
@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||||
|
|
||||||
@@ -261,7 +262,7 @@ def encode_packed_pretraining(
|
|||||||
drop_attention_mask=multipack_attn,
|
drop_attention_mask=multipack_attn,
|
||||||
# pass through handling mode from config via ds_wrapper function
|
# pass through handling mode from config via ds_wrapper function
|
||||||
handling=getattr(ds_wrapper, "cfg", {}).get(
|
handling=getattr(ds_wrapper, "cfg", {}).get(
|
||||||
"sequence_len_overflow_handling", "drop"
|
"sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from axolotl.utils.trainer import truncate_or_drop_long_seq
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
|
||||||
|
|
||||||
|
|
||||||
class RetryStrategy(Enum):
|
class RetryStrategy(Enum):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user