better handling so that all devices have the same dataloader len
This commit is contained in:
@@ -1,11 +1,15 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
@numba.njit
|
||||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||||
@@ -110,6 +114,7 @@ class MultipackDistributedDataloader:
|
|||||||
seq_max_length: int = 2048,
|
seq_max_length: int = 2048,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
sampler: Union[Sampler, DistributedSampler] = None,
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
|
packing_efficiency_estimate: float = 1.0,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -130,6 +135,7 @@ class MultipackDistributedDataloader:
|
|||||||
# statistics
|
# statistics
|
||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
|
self.packing_efficiency_estimate = packing_efficiency_estimate
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -160,6 +166,7 @@ class MultipackDistributedDataloader:
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
|
len_remaining = self._len_est()
|
||||||
for batch in all_batches:
|
for batch in all_batches:
|
||||||
concatenated = {}
|
concatenated = {}
|
||||||
batched = [self.dataset[batch_idx] for batch_idx in batch]
|
batched = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
@@ -190,15 +197,42 @@ class MultipackDistributedDataloader:
|
|||||||
}
|
}
|
||||||
chunked_data.append(chunk)
|
chunked_data.append(chunk)
|
||||||
yield self.collate_fn(chunked_data)
|
yield self.collate_fn(chunked_data)
|
||||||
|
len_remaining -= 1
|
||||||
|
if not len_remaining:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _len_est(self):
|
||||||
|
indices = range(0, len(self.dataset))
|
||||||
|
lengths = self.lengths[indices]
|
||||||
|
lengths_sum = np.cumsum(lengths)[-1]
|
||||||
|
lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
LOG.info(
|
||||||
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
|
return (
|
||||||
|
math.floor(
|
||||||
|
0.99
|
||||||
|
* lengths_sum_per_device
|
||||||
|
/ self.packing_efficiency_estimate
|
||||||
|
/ self.seq_max_length
|
||||||
|
/ self.batch_size
|
||||||
|
)
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
batches, _ = self.generate_batches()
|
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||||
# shave off 1% for dealing with variance in packing and dataset length
|
# the same share of total tokens
|
||||||
return math.floor(len(batches) * 0.99)
|
if not self.eff_total_used:
|
||||||
|
batches, _ = self.generate_batches(set_stats=True)
|
||||||
def num_batches(self):
|
LOG.info(
|
||||||
batches, _ = self.generate_batches()
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
return math.floor(len(batches) * 0.99)
|
f"actual packing efficiency: {self.efficiency()}"
|
||||||
|
)
|
||||||
|
return self._len_est()
|
||||||
|
|
||||||
def efficiency(self):
|
def efficiency(self):
|
||||||
return self.eff_total_used / self.eff_total_slots
|
return self.eff_total_used / self.eff_total_slots
|
||||||
|
|||||||
@@ -187,11 +187,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
else sum(len(s["input_ids"]) for s in train_dataset)
|
else sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
)
|
)
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
math.ceil(
|
# match count to len est in dataloader
|
||||||
total_num_tokens
|
(
|
||||||
/ cfg.sample_packing_eff_est
|
0.99
|
||||||
/ 2048
|
* math.ceil(
|
||||||
/ cfg.batch_size
|
total_num_tokens
|
||||||
|
/ cfg.sample_packing_eff_est
|
||||||
|
/ 2048
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
)
|
)
|
||||||
@@ -210,6 +215,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
padding="longest",
|
padding="longest",
|
||||||
),
|
),
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader)
|
data_loader_len = len(data_loader)
|
||||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||||
|
|||||||
Reference in New Issue
Block a user