better handling so that all devices have the same dataloader len

This commit is contained in:
Wing Lian
2023-07-25 22:18:34 -04:00
parent daed942fe9
commit 945f2e5029
2 changed files with 52 additions and 12 deletions

View File

@@ -1,11 +1,15 @@
# pylint: skip-file
import logging
import math
import os
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
@@ -110,6 +114,7 @@ class MultipackDistributedDataloader:
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
):
# Dataset
self.dataset = dataset
@@ -130,6 +135,7 @@ class MultipackDistributedDataloader:
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate
def generate_batches(self, set_stats=False):
if self.sampler:
@@ -160,6 +166,7 @@ class MultipackDistributedDataloader:
def __iter__(self):
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batch in all_batches:
concatenated = {}
batched = [self.dataset[batch_idx] for batch_idx in batch]
@@ -190,15 +197,42 @@ class MultipackDistributedDataloader:
}
chunked_data.append(chunk)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
def _len_est(self):
indices = range(0, len(self.dataset))
lengths = self.lengths[indices]
lengths_sum = np.cumsum(lengths)[-1]
lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
/ self.seq_max_length
/ self.batch_size
)
- 1
)
def __len__(self):
batches, _ = self.generate_batches()
# shave off 1% for dealing with variance in packing and dataset length
return math.floor(len(batches) * 0.99)
def num_batches(self):
batches, _ = self.generate_batches()
return math.floor(len(batches) * 0.99)
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return self._len_est()
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -187,11 +187,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else sum(len(s["input_ids"]) for s in train_dataset)
)
total_num_steps = (
math.ceil(
total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
/ cfg.batch_size
# match count to len est in dataloader
(
0.99
* math.ceil(
total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
/ cfg.batch_size
)
- 1
)
* cfg.num_epochs
)
@@ -210,6 +215,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
padding="longest",
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")