diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 70e443cb3..e73d2af8b 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -33,6 +33,7 @@ from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
 )
+from axolotl.utils import get_not_null
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -220,7 +221,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         }
 
         if not isinstance(dataset, torch.utils.data.IterableDataset):
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["drop_last"] = get_not_null(
+                self.args.dataloader_drop_last, True
+            )
             if sampler_fn is not None:
                 sampler = sampler_fn(dataset)
                 if isinstance(sampler, BatchSampler):
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index 3d0ba7c9c..e669413f8 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -52,3 +52,10 @@ def patch_optimized_env():
     if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
         os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
     set_pytorch_cuda_alloc_conf()
+
+
+def get_not_null(value, default=None):
+    """
+    return the value if it's not None, otherwise return the default value
+    """
+    return value if value is not None else default
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index e488ed7d5..13c9d4ea1 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
         batch_max_len: int,  # Maximum sequence length (bin capacity)
         lengths: np.ndarray,  # Sequence lengths
         packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
-        drop_last: bool = False,  # Whether to drop final batches (might be incomplete)
+        drop_last: bool = True,  # Whether to drop final batches (might be incomplete)
         num_count_samples: int = 16,  # Number of times to estimate batch count
         sequential: bool = False,  # Whether to use sequential packing
         group_size: int = 100_000,  # Size of groups for parallel packing