better handling so that all devices have the same dataloader len
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
# pylint: skip-file
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
@@ -110,6 +114,7 @@ class MultipackDistributedDataloader:
|
||||
seq_max_length: int = 2048,
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -130,6 +135,7 @@ class MultipackDistributedDataloader:
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
if self.sampler:
|
||||
@@ -160,6 +166,7 @@ class MultipackDistributedDataloader:
|
||||
def __iter__(self):
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batch in all_batches:
|
||||
concatenated = {}
|
||||
batched = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
@@ -190,15 +197,42 @@ class MultipackDistributedDataloader:
|
||||
}
|
||||
chunked_data.append(chunk)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
def _len_est(self):
    """Estimate how many batches this dataloader provides per device.

    The figure is derived from token counts only: the summed sample lengths
    are split evenly across ``WORLD_SIZE`` processes and then divided by the
    packing efficiency estimate, ``seq_max_length`` and ``batch_size``.
    Nothing here consults the sampler or the packed batches.

    Returns:
        A non-negative batch count. 1% plus one batch is shaved off the raw
        estimate, and the result is clamped at zero so it stays a legal
        length for very small (or empty) datasets.
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Assumes self.lengths holds one token count per dataset item (it is
    # assigned outside this view -- confirm). np.sum yields 0 for an empty
    # selection, whereas np.cumsum(...)[-1] raised IndexError.
    lengths_sum = np.sum(self.lengths[: len(self.dataset)])
    lengths_sum_per_device = lengths_sum // world_size
    LOG.info(
        f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
        f"total_num_tokens per device: {lengths_sum_per_device}"
    )

    # shave off 1% + 1 to absorb the packing variance between samplers
    estimate = (
        math.floor(
            0.99
            * lengths_sum_per_device
            / self.packing_efficiency_estimate
            / self.seq_max_length
            / self.batch_size
        )
        - 1
    )

    # The estimate is also returned as the loader's length; a negative value
    # would make len() raise ValueError, so never go below zero.
    return max(0, estimate)
|
||||
|
||||
def __len__(self):
|
||||
batches, _ = self.generate_batches()
|
||||
# shave off 1% for dealing with variance in packing and dataset length
|
||||
return math.floor(len(batches) * 0.99)
|
||||
|
||||
def num_batches(self):
|
||||
batches, _ = self.generate_batches()
|
||||
return math.floor(len(batches) * 0.99)
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"actual packing efficiency: {self.efficiency()}"
|
||||
)
|
||||
return self._len_est()
|
||||
|
||||
def efficiency(self):
    """Return the observed packing efficiency as a used/available ratio.

    Returns:
        ``eff_total_used / eff_total_slots``, or ``0.0`` when no slots have
        been recorded yet. Both counters start at zero, so an unguarded
        division would raise ZeroDivisionError before any stats exist.
    """
    if not self.eff_total_slots:
        return 0.0
    return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
@@ -187,11 +187,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||
)
|
||||
total_num_steps = (
|
||||
math.ceil(
|
||||
total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
/ cfg.batch_size
|
||||
# match count to len est in dataloader
|
||||
(
|
||||
0.99
|
||||
* math.ceil(
|
||||
total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
/ cfg.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
@@ -210,6 +215,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
|
||||
Reference in New Issue
Block a user