Efficiently get the length of the tokenized docs (#1063)
* Efficiently get the length of the tokenized docs
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
committed by
GitHub
parent
732851f105
commit
81d384598e
@@ -37,7 +37,7 @@ from axolotl.utils.collators import (
|
|||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
)
|
)
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -170,12 +170,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
self.args.train_batch_size,
|
self.args.train_batch_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||||
lengths=(
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
self.train_dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
@@ -189,12 +184,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
self.args.per_device_eval_batch_size,
|
self.args.per_device_eval_batch_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||||
lengths=(
|
lengths=get_dataset_lengths(eval_dataset),
|
||||||
eval_dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from axolotl.prompters import (
|
|||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.samplers.multipack import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
@@ -889,12 +889,7 @@ def encode_packed_pretraining(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
batch_max_len=batch_size * max_seq_length,
|
batch_max_len=batch_size * max_seq_length,
|
||||||
lengths=(
|
lengths=get_dataset_lengths(train_dataset),
|
||||||
train_dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
chunked_data = defaultdict(list)
|
||||||
|
|||||||
@@ -2,3 +2,4 @@
|
|||||||
axolotl samplers module
|
axolotl samplers module
|
||||||
"""
|
"""
|
||||||
from .multipack import MultipackBatchSampler # noqa: F401
|
from .multipack import MultipackBatchSampler # noqa: F401
|
||||||
|
from .utils import get_dataset_lengths # noqa: F401
|
||||||
|
|||||||
17
src/axolotl/utils/samplers/utils.py
Executable file
17
src/axolotl/utils/samplers/utils.py
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
helper util to calculate dataset lengths
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_lengths(dataset):
    """
    Return one length value per row of ``dataset``.

    If the table behind ``dataset.data`` has a ``length`` column, that column
    is converted to a numpy array and returned as-is. Otherwise each length is
    derived from the row's ``position_ids`` as ``last element + 1``.

    NOTE(review): the fallback presumably relies on ``position_ids`` counting
    up from 0 within each row, so that ``last + 1`` equals the row's token
    count -- confirm against the code that produces ``position_ids``.
    """
    table = dataset.data

    if "length" not in table.column_names:
        # No stored lengths: recover them row by row from position_ids.
        position_ids = table.column("position_ids").to_pandas()
        return position_ids.apply(lambda ids: ids[-1] + 1).values

    # Stored lengths are available, so no per-row work is needed.
    return np.array(table.column("length"))
|
||||||
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
||||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger("axolotl")
|
||||||
|
|
||||||
@@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
batch_max_len=cfg.micro_batch_size
|
batch_max_len=cfg.micro_batch_size
|
||||||
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
||||||
lengths=(
|
lengths=get_dataset_lengths(train_dataset),
|
||||||
train_dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
|
|||||||
Reference in New Issue
Block a user