Fixes; go back to the distributed sampler since the batch sampler won't work
@@ -427,7 +427,7 @@ def load_prepare_datasets(
             + "|"
             + "train"
             + "|"
-            + cfg.seed
+            + str(cfg.seed or 42)
         )
         to_hash_test = (
             dataset._fingerprint  # pylint: disable=protected-access
@@ -436,7 +436,7 @@ def load_prepare_datasets(
             + "|"
             + "test"
             + "|"
-            + cfg.seed
+            + str(cfg.seed or 42)
         )
         train_fingerprint = hashlib.md5(
             to_hash_train.encode(), usedforsecurity=False
@@ -449,7 +449,7 @@ def load_prepare_datasets(
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
-                seed=cfg.seed,
+                seed=cfg.seed or 42,
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
             )
@@ -458,7 +458,7 @@ def load_prepare_datasets(
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
-                seed=cfg.seed,
+                seed=cfg.seed or 42,
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
             )
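Note on the seed hunks above: a minimal sketch of the fingerprint computation they adjust, assuming only the pieces visible in the hunk context (dataset._fingerprint, cfg.seed, and the md5 call; the real string includes additional components not shown here). cfg.seed may be an int or None, so concatenating it directly into the hash string raises a TypeError; str(cfg.seed or 42) stringifies it and supplies a deterministic default, and the same cfg.seed or 42 value is then passed to train_test_split so the split and its fingerprints stay in sync.

    import hashlib

    # Hypothetical stand-ins for the values referenced in the hunks.
    dataset_fingerprint = "abc123def456"  # e.g. dataset._fingerprint
    seed = None                           # e.g. cfg.seed; may be None or an int

    # Old: "..." + cfg.seed -> TypeError (str + int/None).
    # New: always a string, with 42 as the deterministic fallback.
    to_hash_train = dataset_fingerprint + "|" + "train" + "|" + str(seed or 42)
    to_hash_test = dataset_fingerprint + "|" + "test" + "|" + str(seed or 42)

    train_fingerprint = hashlib.md5(
        to_hash_train.encode(), usedforsecurity=False
    ).hexdigest()
    test_fingerprint = hashlib.md5(
        to_hash_test.encode(), usedforsecurity=False
    ).hexdigest()

    print(train_fingerprint, test_fingerprint)

One side effect of the `or 42` idiom is that an explicit seed of 0 also falls back to 42; that mirrors the diff rather than being a recommendation.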
@@ -16,7 +16,7 @@ import transformers
 from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -159,23 +159,20 @@ class AxolotlTrainer(Trainer):
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler

-    # def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-    #     if self.args.world_size > 1 and self.args.sample_packing:
-    #         return DistributedSampler(
-    #             self.train_dataset,
-    #             num_replicas=self.args.world_size,
-    #             rank=self.args.process_index,
-    #             seed=self.args.seed,
-    #         )
-    #     return super()._get_train_sampler()
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                seed=self.args.seed,
+            )
+        return super()._get_train_sampler()

     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
-            # If set to True, the dataloader prepared is only iterated through on the
-            # main process and then the batches are split and broadcast to each process
-            self.accelerator.dispatch_batches = True
-            return self.accelerator.prepare_data_loader(
+            return self.accelerator.prepare(
                 MultipackDistributedDataloader(
                     self.train_dataset,
                     batch_size=self._train_batch_size,
@@ -197,10 +194,7 @@ class AxolotlTrainer(Trainer):
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            # If set to True, the datalaoder prepared is only iterated through on the
-            # main process and then the batches are split and broadcast to each process
-            self.accelerator.dispatch_batches = True
-            return self.accelerator.prepare_data_loader(
+            return self.accelerator.prepare(
                 MultipackDistributedDataloader(
                     eval_dataset,
                     batch_size=self.args.eval_batch_size,
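Note on the trainer hunks above: a minimal sketch of the behaviour this commit restores, where args, train_dataset, accelerator, and the MultipackDistributedDataloader keyword arguments are stand-ins mirroring what the hunks reference rather than the exact upstream API. With sample packing enabled on more than one process, _get_train_sampler once again returns a DistributedSampler so each rank iterates its own shard, and the custom dataloader is handed to accelerator.prepare (which dispatches on the object's type) instead of accelerator.prepare_data_loader and the dispatch_batches = True path, presumably because those assume a plain torch DataLoader rather than this custom iterable.

    from typing import Optional

    from torch.utils.data import DistributedSampler, Sampler


    def get_train_sampler(args, train_dataset) -> Optional[Sampler]:
        # Mirrors the restored _get_train_sampler: with sample packing on and
        # more than one process training, shard the packed dataset per rank,
        # shuffled with the run seed.
        if args.world_size > 1 and args.sample_packing:
            return DistributedSampler(
                train_dataset,
                num_replicas=args.world_size,
                rank=args.process_index,
                seed=args.seed,
            )
        return None  # stand-in for super()._get_train_sampler() in the real class


    # The dataloader hunks then pass the custom dataloader through
    # accelerator.prepare(), which dispatches on the object's type, rather than
    # prepare_data_loader(), which expects a torch DataLoader. Hypothetical usage:
    #
    #   train_loader = accelerator.prepare(
    #       MultipackDistributedDataloader(
    #           train_dataset,
    #           batch_size=batch_size,
    #           sampler=get_train_sampler(args, train_dataset),
    #       )
    #   )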