diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 85db2bace..cc3c73c9f 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -37,7 +37,7 @@ from axolotl.utils.collators import (
     DataCollatorForSeq2Seq,
     MambaDataCollator,
 )
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -170,12 +170,7 @@ class AxolotlTrainer(Trainer):
                 self.args.train_batch_size,
                 drop_last=True,
                 batch_max_len=self._train_batch_size * self.args.max_seq_length,
-                lengths=(
-                    self.train_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()
@@ -189,12 +184,7 @@ class AxolotlTrainer(Trainer):
                 self.args.per_device_eval_batch_size,
                 drop_last=True,
                 batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
-                lengths=(
-                    eval_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(eval_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_eval_sampler(eval_dataset)
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index b3c7606eb..8ef3a7f78 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -44,7 +44,7 @@ from axolotl.prompters import (
 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
-from axolotl.utils.samplers.multipack import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -889,12 +889,7 @@ def encode_packed_pretraining(
         batch_size=batch_size,
         drop_last=True,
         batch_max_len=batch_size * max_seq_length,
-        lengths=(
-            train_dataset.data.column("position_ids")
-            .to_pandas()
-            .apply(lambda x: x[-1] + 1)
-            .values
-        ),
+        lengths=get_dataset_lengths(train_dataset),
     )
 
     chunked_data = defaultdict(list)
diff --git a/src/axolotl/utils/samplers/__init__.py b/src/axolotl/utils/samplers/__init__.py
index 4c102826f..96e00a5d2 100644
--- a/src/axolotl/utils/samplers/__init__.py
+++ b/src/axolotl/utils/samplers/__init__.py
@@ -2,3 +2,4 @@
 axolotl samplers module
 """
 from .multipack import MultipackBatchSampler  # noqa: F401
+from .utils import get_dataset_lengths  # noqa: F401
diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
new file mode 100755
index 000000000..926c7386a
--- /dev/null
+++ b/src/axolotl/utils/samplers/utils.py
@@ -0,0 +1,17 @@
+"""
+helper util to calculate dataset lengths
+"""
+import numpy as np
+
+
+def get_dataset_lengths(dataset):
+    if "length" in dataset.data.column_names:
+        lengths = np.array(dataset.data.column("length"))
+    else:
+        lengths = (
+            dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        )
+    return lengths
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3139f5600..17806de65 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, RandomSampler
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
 LOG = get_logger("axolotl")
 
@@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 drop_last=True,
                 batch_max_len=cfg.micro_batch_size
                 * (cfg.max_packed_sequence_len or cfg.sequence_len),
-                lengths=(
-                    train_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(train_dataset),
             )
 
             data_loader = DataLoader(