From 3750fdcf79313f5c626d9508c72ea167f7da2985 Mon Sep 17 00:00:00 2001 From: Qingyang Wu Date: Wed, 22 Oct 2025 07:22:14 -0700 Subject: [PATCH] Fix trainer dataloader slow loading issue (#3219) * Fix trainer dataloader handling in src/axolotl/core/trainers/base.py * update comment to reflect torch version --------- Co-authored-by: Wing Lian --- setup.py | 2 +- src/axolotl/core/trainers/base.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index b2eeb92d6..a93d8d49e 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def parse_requirements(extras_require_map): try: torch_version = version("torch") except PackageNotFoundError: - torch_version = "2.6.0" # default to torch 2.6 + torch_version = "2.8.0" # default to torch 2.8.0 _install_requires.append(f"torch=={torch_version}") version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 11dfecb98..7d7420fb8 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -225,17 +225,6 @@ class AxolotlTrainer( data_collator = self.data_collator if is_training else self.eval_data_collator - if dataset.column_names and "length" in dataset.column_names: - dataset = dataset.remove_columns(["length"]) - if ( - dataset.column_names - and "position_ids" in dataset.column_names - and "attention_mask" in dataset.column_names - and self.args.sample_packing - and self.args.sample_packing_drop_attention_mask - ): - dataset = dataset.remove_columns(["attention_mask"]) - if isinstance(dataset, datasets.Dataset): if is_training: if not self.args.sample_packing or self.args.pretraining: @@ -294,6 +283,18 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False + if dataset.column_names and "length" in dataset.column_names: + dataset = dataset.remove_columns(["length"]) + + if ( + dataset.column_names + and "position_ids" in dataset.column_names + and "attention_mask" in dataset.column_names + and self.args.sample_packing + and self.args.sample_packing_drop_attention_mask + ): + dataset = dataset.remove_columns(["attention_mask"]) + dataloader = DataLoader(dataset, **dataloader_params) # Accelerator.free_memory() will destroy the references, so