fix rounding of len of batches to int

This commit is contained in:
Wing Lian
2023-07-25 10:29:49 -04:00
parent df3eb645da
commit daed942fe9

View File

@@ -1,5 +1,5 @@
# pylint: skip-file
import math
from typing import Any, Callable, List, Union
import numba
@@ -193,13 +193,12 @@ class MultipackDistributedDataloader:
def __len__(self):
batches, _ = self.generate_batches()
return (
len(batches) * 0.99
) # shave off 1% for dealing with variance in packing and dataset length
# shave off 1% for dealing with variance in packing and dataset length
return math.floor(len(batches) * 0.99)
def num_batches(self):
batches, _ = self.generate_batches()
return len(batches) * 0.99
return math.floor(len(batches) * 0.99)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots