Fixes: go back to DistributedSampler, since a batch sampler won't work

This commit is contained in:
Wing Lian
2023-08-08 03:49:29 -04:00
parent 4f7c04bae0
commit 58e9dee204
2 changed files with 16 additions and 22 deletions

View File

@@ -427,7 +427,7 @@ def load_prepare_datasets(
+ "|"
+ "train"
+ "|"
+ cfg.seed
+ str(cfg.seed or 42)
)
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
@@ -436,7 +436,7 @@ def load_prepare_datasets(
+ "|"
+ "test"
+ "|"
+ cfg.seed
+ str(cfg.seed or 42)
)
train_fingerprint = hashlib.md5(
to_hash_train.encode(), usedforsecurity=False
@@ -449,7 +449,7 @@ def load_prepare_datasets(
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
@@ -458,7 +458,7 @@ def load_prepare_datasets(
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)

View File

@@ -16,7 +16,7 @@ import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -159,23 +159,20 @@ class AxolotlTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
# def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# if self.args.world_size > 1 and self.args.sample_packing:
# return DistributedSampler(
# self.train_dataset,
# num_replicas=self.args.world_size,
# rank=self.args.process_index,
# seed=self.args.seed,
# )
# return super()._get_train_sampler()
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
# If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare_data_loader(
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
@@ -197,10 +194,7 @@ class AxolotlTrainer(Trainer):
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
                # If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare_data_loader(
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,