Fix trainer dataloader slow loading issue (#3219)
* Fix trainer dataloader handling in src/axolotl/core/trainers/base.py
* Update comment to reflect torch version

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
2
setup.py
2
setup.py
@@ -49,7 +49,7 @@ def parse_requirements(extras_require_map):
|
|||||||
try:
|
try:
|
||||||
torch_version = version("torch")
|
torch_version = version("torch")
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
torch_version = "2.6.0" # default to torch 2.6
|
torch_version = "2.8.0" # default to torch 2.8.0
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
|||||||
@@ -225,17 +225,6 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
data_collator = self.data_collator if is_training else self.eval_data_collator
|
data_collator = self.data_collator if is_training else self.eval_data_collator
|
||||||
|
|
||||||
if dataset.column_names and "length" in dataset.column_names:
|
|
||||||
dataset = dataset.remove_columns(["length"])
|
|
||||||
if (
|
|
||||||
dataset.column_names
|
|
||||||
and "position_ids" in dataset.column_names
|
|
||||||
and "attention_mask" in dataset.column_names
|
|
||||||
and self.args.sample_packing
|
|
||||||
and self.args.sample_packing_drop_attention_mask
|
|
||||||
):
|
|
||||||
dataset = dataset.remove_columns(["attention_mask"])
|
|
||||||
|
|
||||||
if isinstance(dataset, datasets.Dataset):
|
if isinstance(dataset, datasets.Dataset):
|
||||||
if is_training:
|
if is_training:
|
||||||
if not self.args.sample_packing or self.args.pretraining:
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
@@ -294,6 +283,18 @@ class AxolotlTrainer(
|
|||||||
):
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
|
if dataset.column_names and "length" in dataset.column_names:
|
||||||
|
dataset = dataset.remove_columns(["length"])
|
||||||
|
|
||||||
|
if (
|
||||||
|
dataset.column_names
|
||||||
|
and "position_ids" in dataset.column_names
|
||||||
|
and "attention_mask" in dataset.column_names
|
||||||
|
and self.args.sample_packing
|
||||||
|
and self.args.sample_packing_drop_attention_mask
|
||||||
|
):
|
||||||
|
dataset = dataset.remove_columns(["attention_mask"])
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, **dataloader_params)
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
|
|
||||||
# Accelerator.free_memory() will destroy the references, so
|
# Accelerator.free_memory() will destroy the references, so
|
||||||
|
|||||||
Reference in New Issue
Block a user