From 81d384598eb06978cd144fc035753e39f96fee7c Mon Sep 17 00:00:00 2001 From: Ricardo Dominguez-Olmedo Date: Mon, 8 Jan 2024 21:48:30 +0100 Subject: [PATCH] Efficiently get the length of the tokenized docs (#1063) * Efficiently get the length of the tokenized docs * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/core/trainer_builder.py | 16 +++------------- src/axolotl/utils/data.py | 9 ++------- src/axolotl/utils/samplers/__init__.py | 1 + src/axolotl/utils/samplers/utils.py | 17 +++++++++++++++++ src/axolotl/utils/trainer.py | 9 ++------- 5 files changed, 25 insertions(+), 27 deletions(-) create mode 100755 src/axolotl/utils/samplers/utils.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 85db2bace..cc3c73c9f 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -37,7 +37,7 @@ from axolotl.utils.collators import ( DataCollatorForSeq2Seq, MambaDataCollator, ) -from axolotl.utils.samplers import MultipackBatchSampler +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup try: @@ -170,12 +170,7 @@ class AxolotlTrainer(Trainer): self.args.train_batch_size, drop_last=True, batch_max_len=self._train_batch_size * self.args.max_seq_length, - lengths=( - self.train_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), + lengths=get_dataset_lengths(self.train_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) return super()._get_train_sampler() @@ -189,12 +184,7 @@ class AxolotlTrainer(Trainer): self.args.per_device_eval_batch_size, drop_last=True, batch_max_len=self.args.eval_batch_size * self.args.max_seq_length, - lengths=( - eval_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), + lengths=get_dataset_lengths(eval_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) return super()._get_eval_sampler(eval_dataset) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index b3c7606eb..8ef3a7f78 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -44,7 +44,7 @@ from axolotl.prompters import ( from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first -from axolotl.utils.samplers.multipack import MultipackBatchSampler +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -889,12 +889,7 @@ def encode_packed_pretraining( batch_size=batch_size, drop_last=True, batch_max_len=batch_size * max_seq_length, - lengths=( - train_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), + lengths=get_dataset_lengths(train_dataset), ) chunked_data = defaultdict(list) diff --git a/src/axolotl/utils/samplers/__init__.py b/src/axolotl/utils/samplers/__init__.py index 4c102826f..96e00a5d2 100644 --- a/src/axolotl/utils/samplers/__init__.py +++ b/src/axolotl/utils/samplers/__init__.py @@ -2,3 +2,4 @@ axolotl samplers module """ from .multipack import MultipackBatchSampler # noqa: F401 +from .utils import get_dataset_lengths # noqa: F401 diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py new file mode 100755 index 000000000..926c7386a --- /dev/null +++ b/src/axolotl/utils/samplers/utils.py @@ -0,0 +1,17 @@ +""" +helper util to calculate dataset lengths +""" +import numpy as np + + +def get_dataset_lengths(dataset): + if "length" in dataset.data.column_names: + lengths = np.array(dataset.data.column("length")) + else: + lengths = ( + dataset.data.column("position_ids") + .to_pandas() + .apply(lambda x: x[-1] + 1) + .values + ) + return lengths diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3139f5600..17806de65 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, RandomSampler from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first -from axolotl.utils.samplers import MultipackBatchSampler +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger("axolotl") @@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): drop_last=True, batch_max_len=cfg.micro_batch_size * (cfg.max_packed_sequence_len or cfg.sequence_len), - lengths=( - train_dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ), + lengths=get_dataset_lengths(train_dataset), ) data_loader = DataLoader(