Better handling so that all devices have the same dataloader length

This commit is contained in:
Wing Lian
2023-07-25 22:18:34 -04:00
parent daed942fe9
commit 945f2e5029
2 changed files with 52 additions and 12 deletions

View File

@@ -1,11 +1,15 @@
# pylint: skip-file # pylint: skip-file
import logging
import math import math
import os
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
import numpy as np import numpy as np
from torch.utils.data import DistributedSampler, Sampler from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit @numba.njit
def ffd_check(a: np.ndarray, c: int, n: int): def ffd_check(a: np.ndarray, c: int, n: int):
@@ -110,6 +114,7 @@ class MultipackDistributedDataloader:
seq_max_length: int = 2048, seq_max_length: int = 2048,
batch_size: int = 1, batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None, sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -130,6 +135,7 @@ class MultipackDistributedDataloader:
# statistics # statistics
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
if self.sampler: if self.sampler:
@@ -160,6 +166,7 @@ class MultipackDistributedDataloader:
def __iter__(self): def __iter__(self):
all_batches, _ = self.generate_batches(set_stats=True) all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est()
for batch in all_batches: for batch in all_batches:
concatenated = {} concatenated = {}
batched = [self.dataset[batch_idx] for batch_idx in batch] batched = [self.dataset[batch_idx] for batch_idx in batch]
@@ -190,15 +197,42 @@ class MultipackDistributedDataloader:
} }
chunked_data.append(chunk) chunked_data.append(chunk)
yield self.collate_fn(chunked_data) yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
def _len_est(self):
    """Estimate how many batches this dataloader yields on one device.

    The estimate is computed from token counts rather than from the batches
    that were actually packed: the entries of ``self.lengths`` for every
    dataset index are summed, split evenly by the ``WORLD_SIZE`` environment
    variable (treated as 1 when unset), and then divided by
    ``self.packing_efficiency_estimate``, ``self.seq_max_length`` and
    ``self.batch_size``.

    Returns:
        int: the estimated batch count, clamped so it is never negative
        (``len()`` raises ``ValueError`` when ``__len__`` returns a negative
        number).
    """
    indices = range(0, len(self.dataset))
    lengths = self.lengths[indices]
    # np.sum yields 0 for an empty selection; np.cumsum(...)[-1] would raise
    # IndexError there.
    lengths_sum = np.sum(lengths)
    lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
    LOG.info(
        f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
        f"total_num_tokens per device: {lengths_sum_per_device}"
    )
    # shave off 1% + 1 to absorb the variance in packing between one random
    # sampler and another
    estimate = (
        math.floor(
            0.99
            * lengths_sum_per_device
            / self.packing_efficiency_estimate
            / self.seq_max_length
            / self.batch_size
        )
        - 1
    )
    # Tiny (or empty) datasets would otherwise produce -1 here.
    return max(estimate, 0)
def __len__(self): def __len__(self):
batches, _ = self.generate_batches() # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# shave off 1% for dealing with variance in packing and dataset length # the same share of total tokens
return math.floor(len(batches) * 0.99) if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
def num_batches(self): LOG.info(
batches, _ = self.generate_batches() f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
return math.floor(len(batches) * 0.99) f"actual packing efficiency: {self.efficiency()}"
)
return self._len_est()
def efficiency(self):
    """Return the ratio ``self.eff_total_used / self.eff_total_slots``.

    Both counters start at 0, so when no slots have been recorded yet this
    returns 0.0 instead of raising ``ZeroDivisionError``.
    """
    if not self.eff_total_slots:
        return 0.0
    return self.eff_total_used / self.eff_total_slots

View File

@@ -187,11 +187,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else sum(len(s["input_ids"]) for s in train_dataset) else sum(len(s["input_ids"]) for s in train_dataset)
) )
total_num_steps = ( total_num_steps = (
math.ceil( # match count to len est in dataloader
total_num_tokens (
/ cfg.sample_packing_eff_est 0.99
/ 2048 * math.ceil(
/ cfg.batch_size total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
/ cfg.batch_size
)
- 1
) )
* cfg.num_epochs * cfg.num_epochs
) )
@@ -210,6 +215,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
padding="longest", padding="longest",
), ),
sampler=sampler, sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
) )
data_loader_len = len(data_loader) data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}") LOG.info(f"data_loader_len: {data_loader_len}")