Default to dropping the last batch in the multipack batch sampler
This commit is contained in:
@@ -33,6 +33,7 @@ from axolotl.core.trainers.utils import (
|
|||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils import get_not_null
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -220,7 +221,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = get_not_null(
|
||||||
|
self.args.dataloader_drop_last, True
|
||||||
|
)
|
||||||
if sampler_fn is not None:
|
if sampler_fn is not None:
|
||||||
sampler = sampler_fn(dataset)
|
sampler = sampler_fn(dataset)
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
|||||||
@@ -52,3 +52,10 @@ def patch_optimized_env():
|
|||||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
set_pytorch_cuda_alloc_conf()
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
|
|
||||||
|
def get_not_null(value, default=None):
|
||||||
|
"""
|
||||||
|
return the value if it's not None, otherwise return the default value
|
||||||
|
"""
|
||||||
|
return value if value is not None else default
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
batch_max_len: int, # Maximum sequence length (bin capacity)
|
batch_max_len: int, # Maximum sequence length (bin capacity)
|
||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 16, # Number of times to estimate batch count
|
num_count_samples: int = 16, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
|
|||||||
Reference in New Issue
Block a user