fix rounding of len of batches to int

This commit is contained in:
Wing Lian
2023-07-25 10:29:49 -04:00
parent df3eb645da
commit daed942fe9

View File

@@ -1,5 +1,5 @@
# pylint: skip-file # pylint: skip-file
import math
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
@@ -193,13 +193,12 @@ class MultipackDistributedDataloader:
def __len__(self): def __len__(self):
batches, _ = self.generate_batches() batches, _ = self.generate_batches()
return ( # shave off 1% for dealing with variance in packing and dataset length
len(batches) * 0.99 return math.floor(len(batches) * 0.99)
) # shave off 1% for dealing with variance in packing and dataset length
def num_batches(self): def num_batches(self):
batches, _ = self.generate_batches() batches, _ = self.generate_batches()
return len(batches) * 0.99 return math.floor(len(batches) * 0.99)
def efficiency(self): def efficiency(self):
return self.eff_total_used / self.eff_total_slots return self.eff_total_used / self.eff_total_slots