sampler / dataloader refactor
@@ -7,7 +7,7 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal

import datasets
import torch
@@ -435,20 +435,25 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
drop_last=True,
)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False):
"""Create a sampler for sequence parallelism"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree

return torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)

def _get_train_sampler(self) -> torch.utils.data.Sampler | None:
# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree

# Create base sampler for SP groups
base_sampler = torch.utils.data.distributed.DistributedSampler(
self.train_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if not self.args.curriculum_sampling else None,
shuffle=not self.args.curriculum_sampling,
drop_last=True,
base_sampler = self._create_sp_sampler(
self.train_dataset, shuffle=not self.args.curriculum_sampling
)

# Apply multipack wrapper if needed
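The `_create_sp_sampler` helper introduced above centralizes the sequence-parallel rank arithmetic: every `sequence_parallel_degree` consecutive ranks map to the same `DistributedSampler` replica, so all members of a sequence-parallel group draw identical samples while different groups split the dataset. A minimal standalone sketch of that mapping (illustrative values only, no distributed setup assumed):

    world_size = 8
    sequence_parallel_degree = 2
    num_sp_groups = world_size // sequence_parallel_degree  # 4 sampler replicas

    for rank in range(world_size):
        sp_group_id = rank // sequence_parallel_degree
        print(f"rank {rank} -> sampler replica {sp_group_id}")
    # ranks 0-1 share replica 0, ranks 2-3 share replica 1, and so on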
@@ -458,16 +463,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataset=self.train_dataset,
group_size=self.args.sample_packing_group_size,
)

return base_sampler

# Regular training sampler logic
if self.args.sample_packing and not self.args.pretraining:
base_sampler = (
SequentialSampler(self.train_dataset)
if self.args.curriculum_sampling
else RandomSampler(self.train_dataset)
)

return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
@@ -480,94 +484,76 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_train_sampler()

def _get_eval_sampler(
self, eval_dataset: Optional[Dataset] = None
) -> Optional[torch.utils.data.Sampler]:
self, eval_dataset: Dataset | None = None
) -> torch.utils.data.Sampler | None:
"""Get evaluation sampler"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Create sampler for SP groups
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree

# Create distributed sampler for the SP group
base_sampler = torch.utils.data.distributed.DistributedSampler(
eval_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
shuffle=False,
drop_last=False,
)

if self.args.sample_packing and self.args.eval_sample_packing is not False:
group_size = (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)

return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
)

return base_sampler

if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
group_size = (
# Get the appropriate group size for sample packing
def get_pack_group_size():
return (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)

# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
base_sampler = self._create_sp_sampler(
eval_dataset, shuffle=False, is_eval=True
)

if self.args.sample_packing and self.args.eval_sample_packing is not False:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=get_pack_group_size(),
)
return base_sampler

# Regular evaluation sampler logic
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
group_size=get_pack_group_size(),
)

return super()._get_eval_sampler(eval_dataset)

def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)

# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if (
self.args.sample_packing
and not self.args.pretraining
and "length" in train_dataset.features
):
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)

# Build common dataloader parameters
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers

# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
params["prefetch_factor"] = self.args.dataloader_prefetch_factor

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
sampler = self._get_train_sampler()
return params

def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)

# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
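The `_create_dataloader_params` / `_prepare_dataloader` pair factors the DataLoader keyword arguments shared by training and evaluation out of `get_train_dataloader` and `get_eval_dataloader`, including the rule that `batch_size` and `batch_sampler` are mutually exclusive. A stripped-down sketch of the same pattern, using hypothetical standalone names rather than the trainer's actual methods:

    from torch.utils.data import BatchSampler, DataLoader

    def build_loader(dataset, sampler, collate_fn, batch_size, num_workers=0):
        # Shared kwargs assembled once, then specialized by sampler type.
        params = {
            "collate_fn": collate_fn,
            "num_workers": num_workers,
            "pin_memory": True,
        }
        if isinstance(sampler, BatchSampler):
            # DataLoader rejects batch_size when batch_sampler is given.
            params["batch_sampler"] = sampler
        else:
            params["sampler"] = sampler
            params["batch_size"] = batch_size
        return DataLoader(dataset, **params)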
@@ -576,114 +562,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last

dataloader_params["worker_init_fn"] = seed_worker
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker

dataloader = DataLoader(train_dataset, **dataloader_params)
if self.args.sample_packing and not self.args.pretraining:
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)

if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False

# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
# Return unprepared dataloader if using sequence parallelism
if self.args.sequence_parallel_degree > 1:
return dataloader

# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)

def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore

# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)

# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)

def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset and "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"])
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)

return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)

eval_sampler = self._get_eval_sampler(eval_dataset)

# Only remove length column if it exists
if "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"])

data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}

if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)

if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last

self.accelerator.even_batches = False
dataloader = DataLoader(eval_dataset, **dataloader_params)

# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_degree > 1:
return dataloader

return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_degree > 1:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
data_collator = (
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
eval_dataset = eval_dataset.remove_columns(["length"])

# Handle dataset preprocessing as in the parent implementation
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="evaluation"
)
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if is_datasets_available() and isinstance(
eval_dataset, datasets.Dataset
):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)

# Build dataloader parameters
dataloader_params = {
"batch_size": self.args.per_device_eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)

if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
sampler = self._get_eval_sampler(eval_dataset)
dataloader_params["sampler"] = sampler

# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
return DataLoader(eval_dataset, **dataloader_params)
return dataloader

return super().get_eval_dataloader(eval_dataset)

def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
) -> torch.utils.data.Sampler | None:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
@@ -920,15 +902,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):

return res

def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
logs: The values to log.
start_time: The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -940,7 +920,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super().log(logs, start_time)

def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
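`store_metrics` appends each value into a nested mapping keyed first by train/eval and then by metric name, which `log` later consumes alongside the regular logs. A tiny self-contained sketch of that bookkeeping (the dict shape here is an assumption for illustration, not the trainer's actual initializer):

    from collections import defaultdict

    stored_metrics = {"train": defaultdict(list), "eval": defaultdict(list)}

    def store_metrics(metrics, train_eval="train"):
        for key, value in metrics.items():
            stored_metrics[train_eval][key].append(value)

    store_metrics({"some_metric": 0.42})
    store_metrics({"some_metric": 0.40})
    print(stored_metrics["train"]["some_metric"])  # [0.42, 0.4]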
@@ -1043,7 +1023,7 @@ class ReLoRATrainer(AxolotlTrainer):
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)