use distributed sampler, avoid accelerate prepare

Wing Lian
2023-07-19 12:16:19 -04:00
parent b02484a83e
commit 4ab9ab79fd
2 changed files with 18 additions and 20 deletions


@@ -109,7 +109,6 @@ class MultipackDistributedDataloader:
        seq_max_length: int = 2048,
        batch_size: int = 1,
        sampler: Union[Sampler, DistributedSampler] = None,
        seed: int = 0,
    ):
        # Dataset
        self.dataset = dataset
@@ -127,19 +126,10 @@ class MultipackDistributedDataloader:
        self.num_replicas = 1
        self.rank = 0

        # Seed
        self.seed = seed
        # Epoch
        self.epoch = 0

        # statistics
        self.eff_total_used = 0
        self.eff_total_slots = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def generate_batches(self, set_stats=False):
        if self.sampler:
            indices = [idx for idx in self.sampler]

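With the seed and epoch bookkeeping removed, the dataloader leans on whatever sampler it is handed for shuffling and per-rank sharding. A minimal sketch of that pattern (the class and method names below are illustrative, not the actual MultipackDistributedDataloader, which additionally packs sequences by length):

from torch.utils.data import Dataset, Sampler

class SamplerBackedBatcher:
    """Toy batcher that delegates shuffling and sharding to its sampler."""

    def __init__(self, dataset: Dataset, batch_size: int = 1, sampler: Sampler = None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler  # seed, epoch, and rank sharding live in the sampler

    def generate_batches(self):
        # A DistributedSampler yields only this rank's shard of indices,
        # already shuffled according to its seed and current epoch.
        if self.sampler is not None:
            indices = [idx for idx in self.sampler]
        else:
            indices = list(range(len(self.dataset)))
        for i in range(0, len(indices), self.batch_size):
            yield indices[i : i + self.batch_size]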

@@ -14,7 +14,7 @@ import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -87,18 +87,26 @@ class AxolotlTrainer(Trainer):
            return super().create_scheduler(num_training_steps, optimizer)
        return self.lr_scheduler
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.world_size > 1 and self.args.sample_packing:
            return DistributedSampler(
                self.train_dataset,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                seed=self.args.seed,
            )
        return super()._get_train_sampler()
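_get_train_sampler builds one DistributedSampler per process so every rank draws a disjoint, identically-seeded shard of the dataset. Outside the Trainer, the same setup looks roughly like this (a sketch: the world size, rank, and seed come from torch.distributed here rather than from TrainingArguments):

import torch.distributed as dist
from torch.utils.data import DistributedSampler

def build_train_sampler(dataset, seed: int = 42):
    # Only shard when more than one process is participating.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        return DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            seed=seed,  # same seed everywhere so ranks agree on the permutation
        )
    return None  # single process: no sharding needed

# Each epoch the caller should advance the sampler so the shuffle changes:
# sampler.set_epoch(epoch)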
    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
        if self.args.sample_packing:
            train_sampler = self._get_train_sampler()
            return self.accelerator.prepare(
                MultipackDistributedDataloader(
                    self.train_dataset,
                    batch_size=self._train_batch_size,
                    seq_max_length=self.args.max_seq_length,
                    collate_fn=self.data_collator,
                    sampler=train_sampler,
                )
            )
            return MultipackDistributedDataloader(
                self.train_dataset,
                batch_size=self._train_batch_size,
                seq_max_length=self.args.max_seq_length,
                collate_fn=self.data_collator,
                sampler=train_sampler,
            )
        return super().get_train_dataloader()
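Returning the MultipackDistributedDataloader directly, rather than wrapping it in accelerator.prepare, matters because its sampler already splits the data per rank; letting accelerate re-shard it could drop or duplicate batches. A rough sketch of the training loop this implies (generic accelerate usage under that assumption, not axolotl's trainer internals):

from accelerate import Accelerator

def train(model, optimizer, train_dataloader, train_sampler, num_epochs: int):
    accelerator = Accelerator()
    # Model and optimizer still go through accelerate as usual.
    model, optimizer = accelerator.prepare(model, optimizer)

    for epoch in range(num_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # keep per-epoch shuffling in sync across ranks
        # The dataloader is used as-is: its sampler already yields rank-local batches,
        # so it is deliberately not passed through accelerator.prepare().
        for batch in train_dataloader:
            loss = model(**batch).loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()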
@@ -278,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

    training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
        max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
        # max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
        per_device_train_batch_size=cfg.micro_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size
        if cfg.eval_batch_size is not None