diff --git a/setup.py b/setup.py index b2eeb92d6..a93d8d49e 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ def parse_requirements(extras_require_map): try: torch_version = version("torch") except PackageNotFoundError: - torch_version = "2.6.0" # default to torch 2.6 + torch_version = "2.8.0" # default to torch 2.8.0 _install_requires.append(f"torch=={torch_version}") version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 11dfecb98..7d7420fb8 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -225,17 +225,6 @@ class AxolotlTrainer( data_collator = self.data_collator if is_training else self.eval_data_collator - if dataset.column_names and "length" in dataset.column_names: - dataset = dataset.remove_columns(["length"]) - if ( - dataset.column_names - and "position_ids" in dataset.column_names - and "attention_mask" in dataset.column_names - and self.args.sample_packing - and self.args.sample_packing_drop_attention_mask - ): - dataset = dataset.remove_columns(["attention_mask"]) - if isinstance(dataset, datasets.Dataset): if is_training: if not self.args.sample_packing or self.args.pretraining: @@ -294,6 +283,18 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False + if dataset.column_names and "length" in dataset.column_names: + dataset = dataset.remove_columns(["length"]) + + if ( + dataset.column_names + and "position_ids" in dataset.column_names + and "attention_mask" in dataset.column_names + and self.args.sample_packing + and self.args.sample_packing_drop_attention_mask + ): + dataset = dataset.remove_columns(["attention_mask"]) + dataloader = DataLoader(dataset, **dataloader_params) # Accelerator.free_memory() will destroy the references, so