use distributed sampler, avoid accelerate prepare
This commit is contained in:
@@ -109,7 +109,6 @@ class MultipackDistributedDataloader:
|
|||||||
seq_max_length: int = 2048,
|
seq_max_length: int = 2048,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
sampler: Union[Sampler, DistributedSampler] = None,
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
seed: int = 0,
|
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -127,19 +126,10 @@ class MultipackDistributedDataloader:
|
|||||||
self.num_replicas = 1
|
self.num_replicas = 1
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
|
|
||||||
# Seed
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
# Epoch
|
|
||||||
self.epoch = 0
|
|
||||||
|
|
||||||
# statistics
|
# statistics
|
||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
|
||||||
self.epoch = epoch
|
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
indices = [idx for idx in self.sampler]
|
indices = [idx for idx in self.sampler]
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import transformers
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
@@ -87,18 +87,26 @@ class AxolotlTrainer(Trainer):
|
|||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
if self.args.world_size > 1 and self.args.sample_packing:
|
||||||
|
return DistributedSampler(
|
||||||
|
self.train_dataset,
|
||||||
|
num_replicas=self.args.world_size,
|
||||||
|
rank=self.args.process_index,
|
||||||
|
seed=self.args.seed,
|
||||||
|
)
|
||||||
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
|
|
||||||
return self.accelerator.prepare(
|
return MultipackDistributedDataloader(
|
||||||
MultipackDistributedDataloader(
|
self.train_dataset,
|
||||||
self.train_dataset,
|
batch_size=self._train_batch_size,
|
||||||
batch_size=self._train_batch_size,
|
seq_max_length=self.args.max_seq_length,
|
||||||
seq_max_length=self.args.max_seq_length,
|
collate_fn=self.data_collator,
|
||||||
collate_fn=self.data_collator,
|
sampler=train_sampler,
|
||||||
sampler=train_sampler,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
@@ -278,7 +286,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
if cfg.eval_batch_size is not None
|
if cfg.eval_batch_size is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user